# SegmentationGAN / app.py
# Author: Harsimran19 — commit dcd284c ("Update app.py"), 1.75 kB
import gradio as gr
import torch
import numpy as np
import os
from model import gen_model
import torchvision.transforms as T
# Per-channel normalisation mean and standard deviation for 3-channel images.
# Normalising with these maps values in [0, 1] onto [-1, 1]; de-normalising
# computes x * STD + MEAN.
MEAN = (0.5,) * 3
STD = (0.5,) * 3
# Load the model: gen_model() (from the local `model` module) returns the
# generator network and the transform applied to input images before inference.
gen, transform_gen = gen_model()

# Stand-alone tensor -> PIL image converter.
to_img = T.ToPILImage()
# Example inputs shown under the demo: one single-input row per file in the
# local "examples" directory. os.listdir() returns entries in arbitrary order,
# so sort them to keep the gallery order reproducible. A missing directory
# yields no examples (same as an empty one) instead of aborting start-up.
try:
    example_list = [
        [os.path.join("examples", name)] for name in sorted(os.listdir("examples"))
    ]
except FileNotFoundError:
    example_list = []
# Map a generator output tensor back to a PIL image.
#
# T.Normalize(mean=m, std=s) computes (x - m) / s per channel, so the inverse
# of normalising with MEAN/STD is another Normalize with mean=-m/s and
# std=1/s, i.e. x -> x * s + m. With MEAN = STD = 0.5 this maps [-1, 1]
# back onto [0, 1].
# NOTE(review): assumes the generator output is normalised with these
# MEAN/STD values — confirm against model.py.
inverse_transform = T.Compose([
    T.Normalize(
        mean=[-m / s for m, s in zip(MEAN, STD)],
        std=[1.0 / s for s in STD],
    ),
    # ToPILImage scales float tensors by 255 before casting to uint8, so
    # out-of-range values would wrap around; clamp to the valid range first.
    T.Lambda(lambda t: t.clamp(0.0, 1.0)),
    T.ToPILImage(),
])
def predict(img):
    """Run the generator on one input image and return the generated image.

    Args:
        img: PIL image from the Gradio ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without an image.

    Returns:
        The generator output converted back to a PIL image by
        ``inverse_transform``, or ``None`` if no image was given.
    """
    # Gradio calls the function with None when the input is empty; return an
    # empty output instead of letting transform_gen fail on it.
    if img is None:
        return None
    # Preprocess and add a leading batch dimension of size 1.
    batch = transform_gen(img).unsqueeze(0)
    # Switch to evaluation mode (affects layers such as dropout/batch-norm,
    # if the generator has any).
    gen.eval()
    with torch.inference_mode():
        y_gen = gen(batch)
    # Drop the batch dimension and map back to image space.
    return inverse_transform(y_gen[0])
# Gradio app: one PIL image in, one PIL image out, with the example gallery.
title = "Image Segmentation GAN"
description = "This segments a Normal Image"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title=title,
    examples=example_list,
    description=description,
)
demo.launch(debug=False)