Spaces:
Runtime error
Runtime error
File size: 1,386 Bytes
36cfcef e627ca3 fc65bc8 36cfcef 71a3e64 36cfcef 6fbbb75 dcd284c 4c665dd dcd284c 36cfcef b8a8801 dcd284c b8a8801 36cfcef 75554ed 36cfcef 5ea9633 36cfcef 57f85b0 09ce3ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import gradio as gr
import torch
import numpy as np
import os
from model import gen_model
from torchvision import transforms
# Per-channel normalization constants used to scale images to [-1, 1]
# ((x - 0.5) / 0.5 per channel).
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

# Load the trained generator and its preprocessing pipeline.
# NOTE(review): gen_model() is project-local; assumed to return
# (nn.Module, callable) — confirm against model.py.
gen, transform_gen = gen_model()

# One [input_path] row per file in examples/, the shape Gradio's
# `examples=` parameter expects.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Undo the (x - 0.5) / 0.5 normalization and convert back to a PIL image.
# Normalize(mean=-1, std=2) computes (x - (-1)) / 2 = x * 0.5 + 0.5,
# which is exactly the inverse of the forward normalization above.
inverse_transform = transforms.Compose([
    transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
    transforms.ToPILImage(),
])
def predict(img):
    """Run the generator on a PIL image and return the generated PIL image."""
    # Preprocess and prepend a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform_gen(img).unsqueeze(0)
    # Inference only: freeze batch-norm/dropout and disable autograd tracking.
    gen.eval()
    with torch.inference_mode():
        output = gen(batch)
    # Drop the batch dimension and map the tensor back to a PIL image.
    return inverse_transform(output[0])
# ---- Gradio App ----
title = "Image Segmentation GAN"
description = "This segments a Normal Image"

# Web UI: upload a PIL image, display the generator's PIL output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Image(type='pil'),
    title=title,
    examples=example_list,
    description=description,
)

# Launched at import time so the hosted Space serves immediately.
# (Removed a trailing "|" scrape artifact that made this line invalid Python.)
demo.launch(debug=False)