import gradio as gr
from PIL import Image
import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize
import albumentations as A
from segmentation_models_pytorch import Unet
from typing import Tuple, List
import os
from glob import glob
from custom_unet import CustomUnet
class GradioApp:
    """Gradio front end that segments people in an image with a U-Net model.

    Two models are loaded at construction time and exposed under the names
    ``'Custom'`` and ``'Pretrained'``; the user picks one per request.
    """

    def __init__(self) -> None:
        """Load both models onto the available device and build the preprocessing step."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects. That is only safe for checkpoints from a trusted source; if
        # these files are plain state dicts, consider weights_only=True.
        custom = CustomUnet(in_channels=3, depth=3, start_channels=16).to(self.device).eval()
        custom.load_state_dict(
            torch.load(os.path.join('models', 'custom_unet.pt'), map_location=self.device, weights_only=False)
        )

        pretrained = Unet(
            encoder_name='timm-efficientnet-b0', in_channels=3, encoder_depth=5, classes=1
        ).to(self.device).eval()
        pretrained.load_state_dict(
            torch.load(os.path.join('models', 'pretrained_unet.pt'), map_location=self.device, weights_only=False)
        )

        # Keys must match the choices offered by the radio button in launch().
        self.models = {'Custom': custom, 'Pretrained': pretrained}
        # Every input is resized to 320x320 before being passed to a model.
        self.transform = A.Compose(transforms=[A.Resize(320, 320)])

    def predict(self, img_file: str, model_name: str) -> Tuple[str, List[Tuple[np.ndarray, str]]]:
        """Segment the image at ``img_file`` with the selected model.

        Args:
            img_file: Path of the image to segment, as supplied by the
                ``gr.Image(type='filepath')`` input.
            model_name: Key into ``self.models`` (``'Custom'`` or ``'Pretrained'``).

        Returns:
            The unchanged image path and a one-element list holding the
            ``(mask, 'person')`` annotation, where ``mask`` is the rounded
            sigmoid output resized to the original height and width.

        Raises:
            gr.Error: If no image was provided.
        """
        if img_file is None:
            # Report a readable message in the UI instead of failing inside PIL.
            raise gr.Error('Please provide an image to segment.')

        # Force three channels: grayscale/palette images decode to 2-D arrays
        # and RGBA images to four channels, but both models expect
        # in_channels=3. The context manager closes the file handle.
        with Image.open(img_file) as img:
            image = np.asarray(img.convert('RGB'))
        h, w = image.shape[:2]

        # HWC uint8 -> CHW float scaled to [0, 1].
        image = torch.from_numpy(self.transform(image=image)['image']).float().permute(2, 0, 1) / 255.
        with torch.inference_mode():
            # Add a batch dimension, take the single result, then binarise.
            prediction = self.models[model_name](image.to(self.device).unsqueeze(0))[0].sigmoid().round().cpu()

        # Nearest-neighbour interpolation keeps the mask binary when scaling
        # back to the original resolution; [0] drops the channel dimension.
        mask = resize(img=prediction, size=(h, w), interpolation=transforms.InterpolationMode.NEAREST)[0].numpy()
        return img_file, [(mask, 'person')]

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and start serving it."""
        # Implicit string concatenation keeps source indentation out of the
        # rendered description (backslash continuations inside a literal embed
        # the next line's leading whitespace).
        description = (
            'This model performs segmentation on people in images. A Unet neural network architecture is used. '
            'The dataset can be found [here](https://github.com/VikramShenoy97/Human-Segmentation-Dataset) '
            'and the source code is on [GitHub](https://github.com/i4ata/UnetSegmentation).'
        )
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to segment'),
                gr.Radio(choices=('Custom', 'Pretrained'), label='Available models', value='Custom'),
            ],
            outputs=gr.AnnotatedImage(label='Model predictions'),
            # glob returns paths in arbitrary order; sort for a stable example list.
            examples=[[example_path] for example_path in sorted(glob('examples/*.jpg'))],
            cache_examples=False,
            title='Person Segmentation',
            description=description,
        )
        demo.launch()
def main() -> None:
    """Create the Gradio application and start serving it."""
    GradioApp().launch()


if __name__ == '__main__':
    main()
|