# Import segmentation model for ultrasound images of the marbling area
import torch
from model import ResAttnUNet
from PIL import Image
import torchvision.transforms as TF

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and restore its weights. `map_location` remaps the stored
# tensors onto `device`, so a checkpoint saved on a GPU also loads on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model = ResAttnUNet(in_channels=3, out_classes=1)
model.load_state_dict(
    torch.load("resattn_unet_inference_echo_persille.pth", map_location=device)
)
model.to(device)
# Switch to evaluation mode before inference. Without this, any dropout or
# batch-norm layers in the network would keep their training-time behaviour
# and the predicted mask could be wrong or non-deterministic.
model.eval()

# Load the input image as 3-channel RGB to match `in_channels=3`.
img = Image.open('example.png').convert("RGB")

# Resize to 320x320 and convert to a tensor.
# NOTE(review): 320x320 is presumably the resolution the model expects -
# confirm against the training pipeline.
transform = TF.Compose([
    TF.Resize((320, 320)),
    TF.ToTensor(),
])

# unsqueeze(0) adds the leading batch dimension before moving to `device`.
x = transform(img).unsqueeze(0).to(device)

# Gradients are not needed for inference; disabling them saves memory.
with torch.no_grad():
    logits = model(x)
    # Map raw logits to [0, 1], then threshold at 0.5 to get a binary mask.
    probs = torch.sigmoid(logits)
    mask = (probs > 0.5).float()

print(mask.shape)