File size: 1,386 Bytes
36cfcef
 
 
 
 
e627ca3
fc65bc8
 
36cfcef
 
 
71a3e64
36cfcef
6fbbb75
 
dcd284c
 
 
 
 
4c665dd
dcd284c
 
36cfcef
b8a8801
 
 
 
 
 
 
 
dcd284c
b8a8801
 
 
36cfcef
75554ed
36cfcef
5ea9633
 
36cfcef
 
 
 
 
57f85b0
09ce3ca
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
import torch
import numpy as np
import os
from model import gen_model
from torchvision import transforms
# Per-channel normalization statistics. The inverse transform below is derived
# from them so the two cannot drift apart.
# NOTE(review): presumably these match the Normalize inside model.gen_model()'s
# transform_gen — confirm against model.py.
MEAN = (0.5, 0.5, 0.5,)
STD = (0.5, 0.5, 0.5,)

# File extensions accepted as example images for the Gradio demo.
_EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

# Model: generator network plus the preprocessing transform applied to its input,
# both supplied by model.gen_model().
gen, transform_gen = gen_model()

# Example inputs shown under the demo (one single-input row per file).
# Sorted because os.listdir order is arbitrary; filtered so stray files
# (e.g. .DS_Store) and subdirectories are not offered as images.
example_list = [
    [os.path.join("examples", name)]
    for name in sorted(os.listdir("examples"))
    if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    and os.path.isfile(os.path.join("examples", name))
]

# Undo Normalize(MEAN, STD): x = x_norm * std + mean, which is itself a
# Normalize with mean' = -mean / std and std' = 1 / std (-1.0 and 2.0 here).
# The clamp keeps values in [0, 1]: ToPILImage multiplies float tensors by 255
# and casts to uint8, so out-of-range outputs would otherwise wrap around.
inverse_transform = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(MEAN, STD)],
        std=[1.0 / s for s in STD],
    ),
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
])
def predict(img):
    """Run the generator on one image and return the result as a PIL image.

    Args:
        img: Image from the ``gr.Image(type='pil')`` input component, or
            ``None`` when the user submits with the image box empty.

    Returns:
        The generated image (output of ``inverse_transform``), or ``None`` if
        no input image was given.
    """
    # Gradio hands over None when no image was provided; bail out instead of
    # crashing inside transform_gen.
    if img is None:
        return None

    # Force 3 channels: PNG uploads/examples are often RGBA or palette-based,
    # and grayscale images are mode "L".
    # NOTE(review): assumes transform_gen/gen expect RGB input (the inverse
    # transform works on 3 channels) — confirm against model.py.
    img = img.convert("RGB")

    # Apply the model's preprocessing and add a leading batch dimension.
    batch = transform_gen(img).unsqueeze(0)

    # Predict with the network in eval mode and autograd tracking disabled.
    gen.eval()
    with torch.inference_mode():
        y_gen = gen(batch)

    # Take the single batch item and map it back to a PIL image.
    return inverse_transform(y_gen[0])



# Gradio App: a single image-in / image-out interface wrapped around predict().
title = "Image Segmentation GAN"
description = "This segments a Normal Image"

# Both components exchange PIL images with predict().
input_image = gr.Image(type='pil')
output_image = gr.Image(type='pil')

demo = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=output_image,
    examples=example_list,
    title=title,
    description=description,
)

# Start the Gradio web server for the interface.
demo.launch(debug=False)