# Pix2Pix-Maps / app.py
# Author: Harsimran19 — commit 9f23747 ("Update app.py"), 956 bytes.
import gradio as gr
import torch
import numpy as np
import os
from model import gen_model
import torchvision.transforms as T
# Model: the local ``model`` module supplies both the generator network and
# the preprocessing pipeline that turns an input image into its input tensor.
gen, transform_gen = gen_model()

# Turns a single generated image tensor back into a PIL image for display.
to_img = T.ToPILImage()
# Example inputs shown under the interface: one single-input row per image
# file found in ./examples.
# - sorted(): os.listdir returns entries in arbitrary order, so sort to keep
#   the examples gallery order stable between runs.
# - The filter skips hidden files (e.g. .gitkeep, .DS_Store) and anything that
#   is not an image, which Gradio would otherwise try to load as an image.
_EXAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
example_list = [
    [os.path.join("examples", name)]
    for name in sorted(os.listdir("examples"))
    if not name.startswith(".") and name.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS)
]
# Predict Function
def predict(img):
    """Translate a satellite image into a map-style image with the generator.

    Args:
        img: The input image as delivered by ``gr.Image(type='pil')`` (a PIL
            image), or ``None`` when the user submits with no image loaded.

    Returns:
        The generated map as a PIL image, or ``None`` when ``img`` is ``None``.
    """
    # Gradio calls the handler with None when the input component is empty;
    # return an empty output instead of crashing inside the transform.
    if img is None:
        return None

    # Apply the model's preprocessing, then add a leading batch dimension.
    batch = transform_gen(img).unsqueeze(0)

    # Run the generator in evaluation mode without autograd bookkeeping.
    gen.eval()
    with torch.inference_mode():
        y_gen = gen(batch)[0]  # drop the batch dimension again
        # ToPILImage multiplies float tensors by 255 before casting to uint8,
        # so values outside [0, 1] would not survive the cast intact; clamp.
        # NOTE(review): this assumes the generator output is meant to lie in
        # [0, 1]. If the model ends in tanh ([-1, 1]), map it with
        # (y_gen + 1) / 2 instead — confirm against model.py.
        y_gen = y_gen.clamp(0.0, 1.0)

    return to_img(y_gen)
# Gradio App
title = "Satellite-to-Map GAN"
description = "This is a Satellite Image to Map converter"

# One image in, one image out. Both sides use PIL images: predict() receives
# the PIL input and returns the PIL image produced by to_img.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.Image(type='pil'),
    title=title,
    examples=example_list,
    description=description,
)

# Start the web server as soon as this script runs.
demo.launch(debug=False)