"""Gradio demo that runs an image-to-image GAN generator on an uploaded picture."""

import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms

from model import gen_model

# Per-channel normalization statistics. ``inverse_transform`` below is derived
# from these so that it undoes a Normalize(MEAN, STD).
# NOTE(review): assumes ``transform_gen`` normalizes with these same values —
# confirm in model.py.
MEAN = (0.5, 0.5, 0.5,)
STD = (0.5, 0.5, 0.5,)

EXAMPLES_DIR = "examples"

# Model: generator network plus the input preprocessing it expects.
gen, transform_gen = gen_model()
# The app only does inference; switch to eval mode once at load time rather
# than on every request.
gen.eval()


def _list_examples(directory=EXAMPLES_DIR):
    """Return Gradio example rows (one file path per row) found in *directory*.

    Hidden files (e.g. ``.DS_Store``, ``.gitkeep``) and sub-directories are
    skipped so they do not show up as broken examples, and the result is
    sorted so the example gallery has a stable order. A missing directory
    yields an empty list instead of crashing the app at startup.
    """
    try:
        names = sorted(os.listdir(directory))
    except FileNotFoundError:
        return []
    return [
        [os.path.join(directory, name)]
        for name in names
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


example_list = _list_examples()

# Undo Normalize(MEAN, STD): x_orig = x * std + mean, which is itself
# Normalize(mean=-mean/std, std=1/std). For MEAN = STD = 0.5 this is
# Normalize(-1, 2), mapping [-1, 1] back to [0, 1].
inverse_transform = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(MEAN, STD)],
        std=[1.0 / s for s in STD],
    ),
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so any
    # value outside [0, 1] would wrap around; clamp first.
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
])


def predict(img):
    """Run the generator on one PIL image and return the result as a PIL image.

    Args:
        img: PIL image from the Gradio input component, or ``None`` when the
            user submits without uploading anything.

    Returns:
        The generated PIL image, or ``None`` if no input image was given.
    """
    if img is None:
        return None
    # Preprocess and add a leading batch dimension of size 1.
    batch = transform_gen(img).unsqueeze(0)
    with torch.inference_mode():
        y_gen = gen(batch)
    # Take the single item of the batch and convert it back to an image.
    return inverse_transform(y_gen[0])


# Gradio App
title = "Image Segmentation GAN"
description = "This segments a Normal Image"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Image(type='pil'),
    title=title,
    # Pass None rather than an empty list when no example files exist.
    examples=example_list or None,
    description=description,
)

if __name__ == "__main__":
    demo.launch(debug=False)