# Hugging Face Spaces app. (The original paste carried the Spaces
# "Runtime error" banner — the error was the Markdown-table pipes
# wrapping every line, which made the file unparseable; now removed.)
import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms

from model import gen_model

# Normalization constants matching the generator's preprocessing
# (inputs scaled to [-1, 1]); kept as module constants for reference.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

# Load the generator network and its companion input-transform pipeline.
gen, transform_gen = gen_model()

# One Gradio example entry per image file in the examples/ directory.
example_list = [[os.path.join("examples", example)] for example in os.listdir("examples")]

# Invert Normalize(mean=0.5, std=0.5): (x - (-1)) / 2 == (x + 1) / 2,
# mapping [-1, 1] back to [0, 1] before conversion to a PIL image.
inverse_transform = transforms.Compose([
    transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
    transforms.ToPILImage(),
])
def predict(img):
    """Run the generator on one input image and return the generated image.

    Args:
        img: input PIL image (Gradio supplies it with ``type='pil'``).

    Returns:
        PIL image produced by the generator, de-normalized from [-1, 1].
    """
    # Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W).
    img = transform_gen(img).unsqueeze(0)
    # Inference only: disable dropout/batch-norm updates and autograd.
    gen.eval()
    with torch.inference_mode():
        y_gen = gen(img)
    # Drop the batch dimension and map back to a displayable PIL image.
    return inverse_transform(y_gen[0])
# Gradio app: single image-in / image-out interface around predict().
title = "Image Segmentation GAN"
description = "This segments a Normal Image"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Image(type='pil'),
    title=title,
    description=description,
    examples=example_list,
)
demo.launch(debug=False)