File size: 723 Bytes
e5f0362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Load the segmentation model for ultrasound images of the marbling area and run it on an example image
import torch
from model import ResAttnUNet
from PIL import Image
import torchvision.transforms as TF

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network (3 input channels, 1 output class) and restore its weights.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = ResAttnUNet(in_channels=3, out_classes=1)
model.load_state_dict(torch.load("resattn_unet_inference_echo_persille.pth", map_location=device))
model.to(device)
# Switch to evaluation mode before predicting. Modules are created in training
# mode, in which layers such as dropout and batch normalization (if the model
# contains any) behave differently and would make the prediction unreliable.
model.eval()

# Force three channels so the input matches in_channels=3 regardless of how
# the image file is stored (grayscale, palette, RGBA, ...).
img = Image.open('example.png').convert("RGB")
transform = TF.Compose([
    TF.Resize((320, 320)),  # fixed 320x320 input size
    TF.ToTensor(),          # PIL image -> float tensor scaled to [0, 1]
])
# unsqueeze(0) adds the leading batch dimension (a batch of one image).
x = transform(img).unsqueeze(0).to(device)

# No gradients are needed for inference; disabling them saves memory and time.
with torch.no_grad():
    logits = model(x)
    probs = torch.sigmoid(logits)     # raw scores -> values in (0, 1)
    mask = (probs > 0.5).float()      # binarize at 0.5: 1.0 = foreground

print(mask.shape)