|
|
import os |
|
|
import sys |
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) |
|
|
import torch |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import scripts.config as config |
|
|
import torchvision.transforms as transforms |
|
|
import scripts.Segmentation.augment as augment |
|
|
from scripts.Segmentation.models import ResNetUNet |
|
|
|
|
|
# Run on CUDA when a GPU is visible, otherwise on the CPU. The device is stored
# on the shared ``config`` module; segment_image reads it when moving inputs.
_use_cuda = torch.cuda.is_available()
config.device = torch.device('cuda') if _use_cuda else torch.device('cpu')

# Checkpoint file name. It is appended to ``config.checkpoints`` by plain string
# concatenation, which is why it carries a leading slash.
modelo = '/best_model.pt'

# Build the two-class network on the selected device, then restore its weights.
# map_location makes torch.load place the stored tensors on config.device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = ResNetUNet(num_classes=2)
model = model.to(config.device)
model.load_state_dict(
    torch.load(config.checkpoints + modelo, map_location=config.device)
)

# Evaluation mode: layers such as Dropout/BatchNorm (if present) use inference behaviour.
model.eval()
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model.
# Presumably this must match the preprocessing used during training — confirm.
_preprocessing_steps = [
    # Single-channel input. This is redundant with segment_image's .convert('L')
    # but harmless.
    transforms.Grayscale(num_output_channels=1),
    # Resize to the fixed network input size from the project config.
    transforms.Resize((config.height, config.width)),
    # PIL image -> float tensor of shape (channels, height, width).
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: shifts ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
def segment_image(input_img):
    """Segment one uploaded image with the loaded ResNet-UNet.

    Args:
        input_img: PIL image from the Gradio input component, or ``None`` when
            the form is submitted without an upload.

    Returns:
        A ``(grayscale_input, mask_image)`` tuple of PIL images, or
        ``(None, None)`` when no image was supplied. ``mask_image`` has the
        network input size (``config.height`` x ``config.width``) and stores
        class index * 255. With the two-class model this means background is
        black and foreground is white.
    """
    if input_img is None:
        # Gradio passes None for an empty submission; clear both outputs
        # instead of crashing on ``None.convert``.
        return None, None

    # NOTE(review): this grayscale copy is what the UI shows as "Original image".
    input_img = input_img.convert('L')
    # Add a batch dimension and move to the device chosen at start-up.
    img_tensor = transform(input_img).unsqueeze(0).to(config.device)

    # Pure inference: disable autograd so no graph or activations are retained
    # for a backward pass that never happens.
    with torch.no_grad():
        if config.USE_TTA:
            output = augment.predict_with_tta(model, img_tensor)
        else:
            output = model(img_tensor)

        # Assumes output is (1, num_classes, H, W) — TODO confirm for the TTA path.
        # Softmax is monotonic, so an argmax over the raw scores picks the same
        # class per pixel; squeeze(0) drops the batch dimension.
        mask = torch.argmax(output, dim=1).squeeze(0)

        if config.USE_REFINEMENT:
            mask = augment.refine_mask(mask)

    mask = mask.cpu().numpy()
    # Scale class indices into the visible 0-255 range for an 8-bit image.
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))

    return input_img, mask_img
|
|
|
|
|
# Gradio UI wiring: one PIL image in, two PIL images out. The outputs follow
# the order of segment_image's return tuple (input image, mask).
# NOTE(review): segment_image returns the grayscale-converted input, so the
# "Original image" panel shows a grayscale version of the upload.
_input_component = gr.Image(type="pil", label="Input image")
_output_components = [
    gr.Image(type="pil", label="Original image"),
    gr.Image(type="pil", label="Segmented mask"),
]

demo = gr.Interface(
    fn=segment_image,
    inputs=_input_component,
    outputs=_output_components,
    title="ResNet-UNet Image Segmentator",
    description="Send an image and see the segmentation result generated by the trained model.",
)
|
|
|
|
|
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local
    # server. NOTE(review): this exposes the model to anyone holding the URL.
    launch_kwargs = dict(share=True)
    demo.launch(**launch_kwargs)
|
|
|