maps_with_gans / app.py
ch0fas
en esta casa no somos fans de la comunidad hispanohablante, al parecer
8fc4c59
Raw
History Blame Contribute Delete
2.7 kB
import torch
import numpy as np
from PIL import Image
import gradio as gr
from models import networks
# Default settings for the generators. These should match the configuration
# the checkpoints under models/ were trained with.

# Inference runs on the CPU; checkpoints are mapped onto this device.
DEVICE = "cpu"

# Generator architecture: 3 input and 3 output channels (images are converted
# to RGB before inference).
INPUT_NC, OUTPUT_NC = 3, 3
NFG = 64  # base number of generator filters (passed to define_G as ``ngf``)
NETG = "resnet_9blocks"  # architecture name understood by models.networks
NORM = "instance"
USE_DROPOUT = False

# Weight-initialisation settings forwarded to define_G; the freshly
# initialised weights are replaced when the checkpoints are loaded.
INIT_TYPE, INIT_GAIN = "normal", 0.02
# Load the pretrained generators, one per map style.
def _load_generator(checkpoint_path):
    """Build a generator from the module settings and load its weights.

    Args:
        checkpoint_path: Path to a ``state_dict`` file saved with
            ``torch.save``.

    Returns:
        The generator with the checkpoint's weights, switched to eval mode.
    """
    generator = networks.define_G(
        input_nc=INPUT_NC,
        output_nc=OUTPUT_NC,
        ngf=NFG,
        netG=NETG,
        norm=NORM,
        use_dropout=USE_DROPOUT,
        init_type=INIT_TYPE,
        init_gain=INIT_GAIN,
    )
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(
        checkpoint_path,
        map_location=DEVICE,
        weights_only=True,
    )
    generator.load_state_dict(state_dict)
    # The app only runs inference: eval() turns off train-time layer
    # behaviour (e.g. dropout, normalisation statistic updates).
    generator.eval()
    return generator


oam_g = _load_generator("models/gen_OAM.pth")
gm_g = _load_generator("models/gen_googleMaps.pth")
sen_g = _load_generator("models/gen_sentinel.pth")
# Image preprocessing for the generators.
def preprocess(img, size=128):
    """Convert a PIL image into a normalised input batch for the generators.

    Args:
        img: A PIL image in any mode; it is converted to RGB first.
        size: Side length in pixels the image is resized to. Defaults to
            128, the size the app has always used. NOTE(review): other sizes
            presumably need to suit the generator's down/up-sampling
            (e.g. divisible by 4) — confirm against models.networks.

    Returns:
        A float32 tensor of shape ``(1, 3, size, size)`` with values in
        ``[-1, 1]``.
    """
    rgb = img.convert("RGB").resize((size, size))
    # Map 8-bit pixel values [0, 255] onto [-1, 1].
    arr = np.asarray(rgb, dtype=np.float32) / 127.5 - 1.0
    # HWC -> CHW, then add a leading batch dimension. from_numpy shares the
    # freshly computed array's memory instead of copying it again.
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
def postprocess(tensor):
    """Convert a generator output batch back into a displayable uint8 image.

    Args:
        tensor: Generator output of shape ``(1, C, H, W)``; a bare
            ``(C, H, W)`` tensor is also accepted. Values are presumably in
            ``[-1, 1]`` (the inverse of ``preprocess``'s normalisation).

    Returns:
        A ``uint8`` NumPy array of shape ``(H, W, C)``.
    """
    # Drop only the batch axis: a bare squeeze() would also remove singleton
    # C/H/W axes (e.g. single-channel output) and break the permute below.
    # detach() lets this work even on tensors that track gradients.
    chw = tensor.detach().squeeze(0).cpu()
    hwc = chw.permute(1, 2, 0).numpy()
    # Map [-1, 1] onto [0, 255]. Clip so out-of-range values saturate instead
    # of wrapping around in the uint8 cast, and round instead of truncating.
    pixels = np.clip((hwc + 1.0) * 127.5, 0, 255)
    return np.rint(pixels).astype(np.uint8)
# Run the uploaded image through every generator.
def translate(image):
    """Stylise *image* with each of the three pretrained generators.

    Args:
        image: The uploaded PIL image. Gradio passes ``None`` for an empty
            Image input, i.e. when the user submits without uploading.

    Returns:
        A list of three PIL images in the order OAM, Google Maps, Sentinel,
        matching the interface's output components.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    x = preprocess(image)
    with torch.no_grad():
        return [
            Image.fromarray(postprocess(generator(x)))
            for generator in (oam_g, gm_g, sen_g)
        ]
# Gradio interface: one input image, one output image per generator.
demo = gr.Interface(
    fn=translate,
    inputs=gr.Image(type="pil", height=900, width=900),
    outputs=[
        gr.Image(type="pil", height=300, width=300, label="OAM Model"),
        gr.Image(type="pil", height=300, width=300, label="Google Maps Model"),
        gr.Image(type="pil", height=300, width=300, label="Sentinel Model"),
    ],
    title="CycleGAN - Gans For Maps",
    description=(
        "Upload an image to turn it into a fantasy map!\n\n"
        "After entering an image, it can be stylized with 3 different models\n\n"
        "**WARNING:** The sentinel model works better for images with little zoom.\n\n"
        "Sofía Maldonado, Vivienne Toledo & Oscar Josué Rocha"
    ),
)

# Only start the web server when this file is run as a script
# (e.g. `python app.py`); importing the module no longer launches it.
if __name__ == "__main__":
    demo.launch()