# segmenter / app.py
# schrilax's picture
# Update app.py
# 9cf9edd
import datasets
import gradio as gr
import numpy as np
import torch
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
# Image preprocessor, loaded from the local 'saved_model_files' directory.
extractor = AutoFeatureExtractor.from_pretrained('saved_model_files')

# Class id -> human-readable class name for the five segmentation classes.
labels = {
    0: 'road/sidewalk/path',
    1: 'human',
    2: 'vehicles',
    3: 'other objects',
    4: 'nature and background',
}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build both mappings from labels.items(): iterating the dict directly (or
# enumerating it) yields only the integer keys, which would drop the names.
model = SegformerForSemanticSegmentation.from_pretrained(
    'saved_model_files',
    num_labels=len(labels),
    id2label=dict(labels),
    label2id={name: idx for idx, name in labels.items()},
    # NOTE(review): presumably needed because the checkpoint's head size may
    # differ from num_labels -- confirm against the saved checkpoint.
    ignore_mismatched_sizes=True,
)
model.eval()  # inference mode (e.g. disables dropout)
model.to(device)
def classify(im):
    """Segment an image and return a colour-coded class map.

    Args:
        im: Image handed to the feature extractor via ``images=``.
            Presumably the RGB array supplied by the Gradio 'image'
            input -- TODO confirm.

    Returns:
        A ``uint8`` numpy array of shape ``(h, w, 3)`` where ``(h, w)`` is the
        spatial size of the model's logits; each pixel holds the RGB colour
        of its highest-scoring class.
    """
    inputs = extractor(images=im, return_tensors='pt').to(device)

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        logits = model(**inputs).logits

    # logits[0] is indexed (class, row, col) for the single input image;
    # argmax over axis 0 picks the best-scoring class id for every pixel.
    classes = logits[0].cpu().numpy().argmax(axis=0)

    # Row i is the RGB colour used for class id i (see `labels`).
    # NOTE(review): the map is at the logits' resolution, which may be lower
    # than the input image -- upsample here if a full-size mask is wanted.
    colors = np.array(
        [
            [128, 0, 0],
            [128, 128, 0],
            [0, 0, 128],
            [128, 0, 128],
            [0, 0, 0],
        ],
        dtype=np.uint8,
    )
    return colors[classes]
# Gradio UI: image in, colour-coded segmentation map out.
# (A bare-bones Interface used to be built first and immediately replaced by
# this one; only the fully configured instance is kept.)
interface = gr.Interface(
    fn=classify,
    inputs='image',
    outputs='image',
    examples=['1.png', '2.png', '3.png'],
    title='Image Segmentation App',
    description='Perform segmentation on pictures of outdoor scenes',
    flagging_dir='flagged_examples/',
)
interface.launch(debug=True)