"""Gradio demo that segments an uploaded image with a trained ResNet-UNet."""
import os
import sys

# Make the project root importable so the `scripts.*` packages resolve when
# this file is run directly.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

import torch
import numpy as np
import gradio as gr
from PIL import Image
import scripts.config as config
import torchvision.transforms as transforms
import scripts.Segmentation.augment as augment
from scripts.Segmentation.models import ResNetUNet

config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint file name. It is appended to config.checkpoints by plain string
# concatenation, which is why it carries a leading slash.
modelo = '/best_model.pt'

model = ResNetUNet(num_classes=2).to(config.device)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Consider passing weights_only=True once the minimum supported torch allows it.
model.load_state_dict(torch.load(config.checkpoints + modelo, map_location=config.device))
model.eval()

# Preprocessing: single channel, resized to the configured model input size,
# scaled to [0, 1] by ToTensor, then normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((config.height, config.width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])


def segment_image(input_img):
    """Run the segmentation model on one image.

    Args:
        input_img: PIL image supplied by the Gradio input component, or None
            when the user submits without uploading anything.

    Returns:
        A ``(grayscale_input, mask_image)`` tuple of PIL images. The mask is
        produced at the model's output resolution (presumably
        ``config.height`` x ``config.width``). Class indices are mapped to
        pixel values by multiplying by 255. Returns ``(None, None)`` when no
        image was provided.
    """
    if input_img is None:
        # Gradio passes None when Submit is pressed with an empty input.
        return None, None

    input_img = input_img.convert('L')
    img_tensor = transform(input_img).unsqueeze(0).to(config.device)  # (1, 1, H, W)

    # Pure inference: disable autograd so no computation graph is built/kept.
    with torch.no_grad():
        if config.USE_TTA:
            output = augment.predict_with_tta(model, img_tensor)
        else:
            output = model(img_tensor)

        probs = torch.softmax(output, dim=1)  # (1, 2, H, W)
        mask = torch.argmax(probs, dim=1).squeeze(0)

        if config.USE_REFINEMENT:
            mask = augment.refine_mask(mask)

    mask = mask.cpu().numpy()
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    return input_img, mask_img


demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil", label="Input image"),
    outputs=[
        gr.Image(type="pil", label="Original image"),
        gr.Image(type="pil", label="Segmented mask"),
    ],
    title="ResNet-UNet Image Segmentator",
    description="Send an image and see the segmentation result generated by the trained model.",
)

if __name__ == "__main__":
    # share=True asks Gradio to create a public share link for the app.
    demo.launch(share=True)