Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| import torchvision.transforms as transforms | |
| from torchvision.transforms.functional import resize | |
| import albumentations as A | |
| from segmentation_models_pytorch import Unet | |
| from typing import Tuple, List | |
| import os | |
| from glob import glob | |
| from custom_unet import CustomUnet | |
class GradioApp:
    """Gradio demo that segments people in images with one of two U-Net models.

    Both models are loaded once at construction time and stored in
    ``self.models``, keyed by the names offered in the UI's model selector.
    """

    def __init__(self) -> None:
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Hand-written U-Net from the local custom_unet module.
        custom = CustomUnet(in_channels=3, depth=3, start_channels=16).to(self.device).eval()
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects.
        # That is tolerable only because these checkpoints ship with the app. Only a
        # state_dict is consumed here, so weights_only=True should also work; confirm
        # this against the saved files before switching.
        custom.load_state_dict(torch.load(
            os.path.join('models', 'custom_unet.pt'), map_location=self.device, weights_only=False
        ))

        # segmentation_models_pytorch U-Net with an EfficientNet-B0 encoder and one output class.
        pretrained = Unet(
            encoder_name='timm-efficientnet-b0', in_channels=3, encoder_depth=5, classes=1
        ).to(self.device).eval()
        pretrained.load_state_dict(torch.load(
            os.path.join('models', 'pretrained_unet.pt'), map_location=self.device, weights_only=False
        ))

        # These keys are the choices shown by the model selector in launch().
        self.models = {'Custom': custom, 'Pretrained': pretrained}
        # Inputs are resized to 320x320 for the models; predict() scales masks back afterwards.
        self.transform = A.Compose(transforms=[A.Resize(320, 320)])

    def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
        """Segment people in an image with the selected model.

        Args:
            img_file: Path to the uploaded image (the input component uses type='filepath').
                Gradio passes None when nothing was uploaded.
            model_name: Key into ``self.models`` ('Custom' or 'Pretrained').

        Returns:
            ``(img_file, [(mask, 'person')])`` in the form ``gr.AnnotatedImage`` displays:
            the original image path plus a 0/1 mask with the image's height and width.

        Raises:
            gr.Error: If no image was provided.
        """
        if img_file is None:
            raise gr.Error('Please upload an image to segment.')

        # Force 3-channel RGB. Grayscale images are 2-D and would break the (h, w)
        # unpacking below. RGBA or palette images would not match in_channels=3.
        # The context manager closes the underlying file handle.
        with Image.open(img_file) as pil_image:
            image = np.asarray(pil_image.convert('RGB'))
        h, w = image.shape[:2]

        # Resize to the model resolution, then convert HWC uint8 to CHW float in [0, 1].
        image = torch.from_numpy(self.transform(image=image)['image']).float().permute(2, 0, 1) / 255.

        with torch.inference_mode():
            # Add a batch dim, take the single result, then binarise with sigmoid + round.
            # This assumes the models output (batch, 1, H, W) logits; TODO confirm for CustomUnet.
            prediction = self.models[model_name](image.to(self.device).unsqueeze(0))[0].sigmoid().round().cpu()

        # Nearest-neighbour interpolation keeps the mask binary when scaling back to the
        # original size. The trailing [0] drops the channel dimension.
        mask = resize(img=prediction, size=(h, w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()
        return img_file, [(mask, 'person')]

    def launch(self) -> None:
        """Build the Gradio interface around predict() and start serving it."""
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to segment'),
                # Derive the choices from the loaded models so the two cannot drift apart.
                gr.Radio(choices=list(self.models), label='Available models', value='Custom'),
            ],
            outputs=gr.AnnotatedImage(label='Model predictions'),
            # Sorted so the example order does not depend on filesystem listing order.
            examples=[[example_path] for example_path in sorted(glob('examples/*.jpg'))],
            cache_examples=False,
            title='Person Segmentation',
            description=(
                'This model performs segmentation on people in images. A Unet neural network architecture is used. '
                'The dataset can be found [here](https://github.com/VikramShenoy97/Human-Segmentation-Dataset) '
                'and the source code is on [GitHub](https://github.com/i4ata/UnetSegmentation).'
            ),
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: constructing the app loads both models, then the UI is served.
    app = GradioApp()
    app.launch()