Harsimran19 commited on
Commit
36cfcef
·
1 Parent(s): f2425b7

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +38 -0
  2. model.py +17 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ from model import gen_model
6
+ import torchvision.transforms as T
7
+
8
+ # Model
9
+ gen,transform_gen=gen_model()
10
+ # print(gen)
11
+ to_img=T.ToPILImage()
12
+ # examples=["examples/input_0.png","examples/input_9.png"]
13
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
14
+
15
+ # Predict Function
16
+ def predict(img):
17
+ # Apply Transformations
18
+ img=np.array(img)
19
+ img=transform_gen(image=img)
20
+ img=img['image'].unsqueeze(0)
21
+ # Predict
22
+ gen.eval()
23
+ with torch.inference_mode():
24
+ y_gen=gen(img)
25
+ y_gen=y_gen[0]
26
+ y_gen=to_img(y_gen)
27
+ return y_gen
28
+
29
+ # Gradio App
30
+ title="Pix2Pix GAN"
31
+ description="This is a Sattelite Image to Map converter"
32
+
33
+ demo=gr.Interface(fn=predict,
34
+ inputs=gr.Image(type='pil'),
35
+ outputs=gr.Image(type='pil'),
36
+ title=title ,
37
+ examples=example_list,
38
+ description=description)
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from torchvision import transforms
3
+
4
+
5
+ def gen_model():
6
+ # transform = transforms.Compose([
7
+ # transforms.Resize((256, 256)),
8
+ # transforms.RandomCrop((224, 224)),
9
+ # transforms.RandomHorizontalFlip(),
10
+ # transforms.ToTensor(),
11
+ # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
12
+ device="cpu"
13
+ with torch.no_grad():
14
+ model = torch.load('gen.pth.tar', map_location='cpu')
15
+ model = model['state_dict']
16
+ model=model['state_dict']
17
+ return model,transform
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.14.0
4
+ numpy==1.24.2