# Run the segmentation model for ultrasound images of the marbling area
# on a single example image and print the shape of the predicted binary mask.

import torch
import torchvision.transforms as TF
from PIL import Image

from model import ResAttnUNet

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and restore the trained weights.  map_location remaps
# the stored tensors onto `device`, so a checkpoint saved on a GPU machine
# can still be loaded on a CPU-only one.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = ResAttnUNet(in_channels=3, out_classes=1)
model.load_state_dict(
    torch.load("resattn_unet_inference_echo_persille.pth", map_location=device)
)
model.to(device)
# Modules are created in training mode.  Switch to evaluation mode so that
# layers such as Dropout / BatchNorm (if the network has any) behave
# deterministically during inference.
model.eval()

# Force three channels to match in_channels=3 regardless of the file's mode.
img = Image.open('example.png').convert("RGB")

# Resize to a fixed 320x320 input and convert the PIL image to a tensor.
transform = TF.Compose([
    TF.Resize((320, 320)),
    TF.ToTensor(),
])

# unsqueeze(0) adds the leading batch dimension before moving to `device`.
x = transform(img).unsqueeze(0).to(device)

# Gradients are not needed for inference; disabling them saves memory.
with torch.no_grad():
    logits = model(x)
    # Map raw logits to probabilities, then binarize at 0.5.
    probs = torch.sigmoid(logits)
    mask = (probs > 0.5).float()

# Presumably (1, 1, 320, 320) given out_classes=1 -- depends on ResAttnUNet;
# TODO confirm.
print(mask.shape)