devoppro commited on
Commit
33fa605
·
verified ·
1 Parent(s): cf3c3e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision.utils import save_image
4
+ from torchvision.transforms import ToPILImage
5
+ from model import Generator
6
+
7
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
+ NOISE_DIM = 256
9
+
10
+ G = Generator().to(DEVICE)
11
+ G.load_state_dict(torch.load("generator.pth", map_location=DEVICE))
12
+ G.eval()
13
+
14
+ to_pil = ToPILImage()
15
+
16
+ def generate_image():
17
+ noise = torch.randn(1, NOISE_DIM).to(DEVICE)
18
+
19
+ with torch.no_grad():
20
+ image = G(noise)
21
+
22
+ image = (image + 1) / 2
23
+ return to_pil(image.squeeze(0))
24
+
25
+ demo = gr.Interface(
26
+ fn=generate_image,
27
+ inputs=None,
28
+ outputs=gr.Image(),
29
+ title="GAN Image Generator API"
30
+ )
31
+
32
+ demo.launch()