File size: 1,864 Bytes
7b615ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
import torch
import numpy as np
import gradio as gr
from PIL import Image
import scripts.config as config
import torchvision.transforms as transforms
import scripts.Segmentation.augment as augment
from scripts.Segmentation.models import ResNetUNet

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE: this assigns onto the shared `config` module, so any code that reads
# `config.device` afterwards sees this value.
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint file name inside `config.checkpoints`. The leading slash is kept
# for backward compatibility; it is stripped before joining so os.path.join
# does not treat it as an absolute path. os.path.join also accepts a
# pathlib.Path for `config.checkpoints` and avoids doubled separators.
modelo = '/best_model.pt'
checkpoint_path = os.path.join(config.checkpoints, modelo.lstrip('/'))

model = ResNetUNet(num_classes=2).to(config.device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only host.
model.load_state_dict(torch.load(checkpoint_path, map_location=config.device))
# Evaluation mode: changes the behavior of layers such as dropout / batch norm, if present.
model.eval()

# Preprocessing applied to every uploaded image before it reaches the model:
# single-channel grayscale, resized to (config.height, config.width),
# converted to a float tensor in [0, 1], then normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((config.height, config.width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

def segment_image(input_img):
    """Segment one uploaded image and return it alongside its predicted mask.

    Args:
        input_img: PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading one.

    Returns:
        A ``(grayscale_input, mask_image)`` tuple of PIL images. The first is
        the upload converted to mode ``'L'``. The second is the argmax class
        map scaled by 255; with the 2-class model, class 0 is black and
        class 1 is white (assuming ``augment.refine_mask``, when enabled,
        keeps 0/1 values — TODO confirm). The mask is presumably at
        ``(config.height, config.width)`` resolution, not the upload's size.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a message in the
            UI instead of an internal traceback.
    """
    if input_img is None:
        raise gr.Error("Please upload an image to segment.")

    # `transform` also converts to grayscale. Converting here as well means the
    # image echoed back to the user matches what the model received
    # (before resizing).
    input_img = input_img.convert('L')
    img_tensor = transform(input_img).unsqueeze(0).to(config.device)  # (1, 1, H, W)

    # Pure inference: disable autograd so no computation graph is built
    # (and kept in memory) for every request.
    with torch.no_grad():
        if config.USE_TTA:
            output = augment.predict_with_tta(model, img_tensor)
        else:
            output = model(img_tensor)

        probs = torch.softmax(output, dim=1)  # (1, 2, H, W)
        mask = torch.argmax(probs, dim=1).squeeze(0)  # per-pixel class indices

        if config.USE_REFINEMENT:
            mask = augment.refine_mask(mask)

    mask = mask.cpu().numpy()
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    return input_img, mask_img

# Web UI wiring: a single PIL image goes in, and the two values returned by
# `segment_image` fill the two output panels, in order.
_input_panel = gr.Image(type="pil", label="Input image")
_output_panels = [
    gr.Image(type="pil", label="Original image"),
    gr.Image(type="pil", label="Segmented mask"),
]

demo = gr.Interface(
    fn=segment_image,
    inputs=_input_panel,
    outputs=_output_panels,
    title="ResNet-UNet Image Segmentator",
    description="Send an image and see the segmentation result generated by the trained model.",
)

if __name__ == "__main__":
    # share=True additionally asks Gradio for a temporary public link.
    demo.launch(share=True)