Harsimran19 commited on
Commit
19c0acc
·
1 Parent(s): 4c3b680

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. app.py +38 -0
  3. examples/input_0.png +0 -0
  4. examples/input_9.png +0 -0
  5. gen.pth.tar +3 -0
  6. model.py +17 -0
  7. requirements.txt +4 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ gen.pth.tar filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ from model import gen_model
6
+ import torchvision.transforms as T
7
+
8
+ # Model
9
+ gen,transform_gen=gen_model()
10
+ # print(gen)
11
+ to_img=T.ToPILImage()
12
+ # examples=["examples/input_0.png","examples/input_9.png"]
13
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
14
+
15
+ # Predict Function
16
+ def predict(img):
17
+ # Apply Transformations
18
+ img=np.array(img)
19
+ img=transform_gen(image=img)
20
+ img=img['image'].unsqueeze(0)
21
+ # Predict
22
+ gen.eval()
23
+ with torch.inference_mode():
24
+ y_gen=gen(img)
25
+ y_gen=y_gen[0]
26
+ y_gen=to_img(y_gen)
27
+ return y_gen
28
+
29
+ # Gradio App
30
+ title="Pix2Pix GAN"
31
+ description="This is a Sattelite Image to Map converter"
32
+
33
+ demo=gr.Interface(fn=predict,
34
+ inputs=gr.Image(type='pil'),
35
+ outputs=gr.Image(type='pil'),
36
+ title=title ,
37
+ examples=example_list,
38
+ description=description)
examples/input_0.png ADDED
examples/input_9.png ADDED
gen.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cc994a88c4a562a581e7dfdec837ee0fd7fc1a9046baae3b37e1750ed031bba
3
+ size 653073375
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+
4
+
5
+ def gen_model():
6
+ transform = transforms.Compose([
7
+ transforms.Resize((256, 256)),
8
+ transforms.RandomCrop((224, 224)),
9
+ transforms.RandomHorizontalFlip(),
10
+ transforms.ToTensor(),
11
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
12
+ device="cpu"
13
+ with torch.no_grad():
14
+ model = torch.load('gen.pth.tar', map_location='cpu')
15
+ model = model['state_dict']
16
+ model=model['state_dict']
17
+ return model,transform
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.14.0
4
+ numpy==1.24.2