"""Gradio demo for person segmentation with a custom or a pretrained U-Net."""

import os
from glob import glob
from typing import List, Tuple

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from segmentation_models_pytorch import Unet
from torchvision.transforms.functional import resize

from custom_unet import CustomUnet


class GradioApp:
    """Load both segmentation models and serve them through a Gradio interface."""

    def __init__(self) -> None:
        """Build both models on the available device and load their saved weights.

        Weights are read from ``models/custom_unet.pt`` and ``models/pretrained_unet.pt``
        (paths relative to the current working directory).
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        custom = self._load_weights(
            CustomUnet(in_channels=3, depth=3, start_channels=16),
            'custom_unet.pt',
        )
        pretrained = self._load_weights(
            Unet(encoder_name='timm-efficientnet-b0', in_channels=3, encoder_depth=5, classes=1),
            'pretrained_unet.pt',
        )

        # The keys double as the choices offered in the model selector in launch().
        self.models = {'Custom': custom, 'Pretrained': pretrained}

        # Every input is resized to 320x320 before inference; predict() resizes the mask back.
        self.transform = A.Compose(transforms=[A.Resize(320, 320)])

    def _load_weights(self, model: torch.nn.Module, filename: str) -> torch.nn.Module:
        """Move *model* to ``self.device``, put it in eval mode and load ``models/<filename>`` into it."""
        model = model.to(self.device).eval()
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects. That is only
        # acceptable because these checkpoints are trusted local files. If they contain plain
        # state dicts, weights_only=True would be the safer setting.
        state_dict = torch.load(
            os.path.join('models', filename),
            map_location=self.device,
            weights_only=False,
        )
        model.load_state_dict(state_dict)
        return model

    def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
        """Segment people in an image with the selected model.

        Args:
            img_file: Path to the input image (the Gradio image input uses ``type='filepath'``).
            model_name: Key into ``self.models``, i.e. ``'Custom'`` or ``'Pretrained'``.

        Returns:
            The input path and a one-element list pairing a {0, 1} mask at the original image
            resolution with the label ``'person'``. This (image, annotations) pair is what the
            ``gr.AnnotatedImage`` output consumes.

        Raises:
            gr.Error: If no image was provided.
        """
        if img_file is None:
            raise gr.Error('Please provide an image to segment.')

        # Force 3-channel RGB. Grayscale and palette images would otherwise be 2-D (breaking
        # the height/width unpacking), and RGBA images would feed 4 channels into models
        # built with in_channels=3. The context manager closes the underlying file handle.
        with Image.open(img_file) as img:
            image = np.asarray(img.convert('RGB'))
        h, w = image.shape[:2]

        resized = self.transform(image=image)['image']
        # HWC uint8 -> CHW float scaled to [0, 1].
        tensor = torch.from_numpy(resized).float().permute(2, 0, 1) / 255.

        with torch.inference_mode():
            logits = self.models[model_name](tensor.to(self.device).unsqueeze(0))
            # Drop the batch dimension, then round the sigmoid output to a hard {0, 1} mask.
            prediction = logits[0].sigmoid().round().cpu()

        # Nearest-neighbour keeps the mask binary while restoring the original resolution.
        # [0] takes the first output channel (presumably the only one for both models).
        mask = resize(
            img=prediction,
            size=(h, w),
            interpolation=transforms.InterpolationMode.NEAREST,
        )[0].numpy()
        return img_file, [(mask, 'person')]

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and start the server."""
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to segment'),
                # Offer exactly the models loaded in __init__ so the two cannot drift apart.
                gr.Radio(choices=list(self.models), label='Available models', value='Custom'),
            ],
            outputs=gr.AnnotatedImage(label='Model predictions'),
            # Sorted so the example gallery order does not depend on the filesystem.
            examples=[[example_path] for example_path in sorted(glob('examples/*.jpg'))],
            cache_examples=False,
            title='Person Segmentation',
            description=(
                'This model performs segmentation on people in images. '
                'A Unet neural network architecture is used. '
                'The dataset can be found [here](https://github.com/VikramShenoy97/Human-Segmentation-Dataset) '
                'and the source code is on [GitHub](https://github.com/i4ata/UnetSegmentation).'
            ),
        )
        demo.launch()


if __name__ == '__main__':
    app = GradioApp()
    app.launch()