coloring / app2.py
maripau22's picture
eleccion entre dos modelos
ad4b2b3
# import gradio as gr
# import torch
# from torchvision import transforms
# from PIL import Image
# from generator2 import UNetGeneratorImproved
# # Rutas de los modelos, todos los colores y 20 colores
# all_colors_model_path = "/home/maripau/Documents/ITESO/Semestre6/Deep/COLOURING/APP/Coloring2/unet_generator.pth"
# limited_colors_model_path = "/home/maripau/Documents/ITESO/Semestre6/Deep/COLOURING/APP/Coloring2/20color_generator.pth"
# # Selección de dispositivo (GPU si está disponible)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Transformación de la imagen para el modelo
# transform = transforms.Compose([
# transforms.Resize((512, 512)),
# transforms.ToTensor(),
# ])
# def load_model(path):
# """
# Carga el modelo desde una ruta específica.
# """
# model = UNetGeneratorImproved().to(device)
# model.load_state_dict(torch.load(path, map_location=device))
# model.eval()
# return model
# def process_image(image, model_type):
# """
# Convierte la imagen a escala de grises, la muestra,
# y luego genera una versión colorizada usando el modelo elegido.
# """
# # Convertir a escala de grises
# gray_image = image.convert("L").convert("RGB") # convertimos de L (1 canal) a RGB (3 canales) para que PIL la pueda mostrar bien
# # Seleccionar el modelo según el tipo
# model_path = all_colors_model_path if model_type == "Todos los colores" else limited_colors_model_path
# generator = load_model(model_path)
# # Preprocesar la imagen en escala de grises para el modelo
# gray_tensor = transform(gray_image.convert("L")).unsqueeze(0).to(device)
# # Generar la imagen colorizada
# with torch.no_grad():
# output = generator(gray_tensor)
# output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
# colorized_image = Image.fromarray((output * 255).astype('uint8'))
# # Devolver la imagen en escala de grises y la colorizada
# return gray_image, colorized_image
# # Interfaz con dos salidas: escala de grises y color
# interface = gr.Interface(
# fn=process_image,
# inputs=[
# gr.Image(type="pil", label="Imagen original"),
# gr.Dropdown(
# choices=["Todos los colores", "20 colores"],
# value="Todos los colores",
# label="Selecciona el modelo"
# )
# ],
# outputs=[
# gr.Image(type="pil", label="Imagen en blanco y negro"),
# gr.Image(type="pil", label="Imagen colorizada")
# ],
# title="Colorizador de Imágenes"
# )
# interface.launch()
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
# Ruta a tus modelos exportados como TorchScript (.pt)
model_paths = {
"Todos los colores": "unet_generator.pt",
"Solo 20 colores": "20color_generator.pt"
}
# Verifica si hay GPU disponible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Transformaciones de la imagen (redimensionar y convertir a tensor)
transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor(),
])
# Función que carga el modelo seleccionado
def load_model(path):
model = torch.jit.load(path, map_location=device)
model.eval()
return model
# Función principal de colorización
def colorize(image, modelo_seleccionado):
"""
Convierte la imagen a escala de grises, la muestra, y genera la imagen colorizada
con el modelo seleccionado.
"""
# Mostrar la imagen en blanco y negro
gray = image.convert("L")
# Preprocesar para el modelo
gray_tensor = transform(gray).unsqueeze(0).to(device)
# Cargar el modelo según la selección
model = load_model(model_paths[modelo_seleccionado])
# Generar la imagen colorizada
with torch.no_grad():
output = model(gray_tensor)
# Procesar salida y convertir a imagen PIL
output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
output_image = Image.fromarray((output * 255).astype('uint8'))
return gray, output_image # Regresa la imagen gris y la colorizada
# Crear interfaz con Gradio
gr.Interface(
fn=colorize,
inputs=[
gr.Image(type="pil", label="Imagen de entrada"),
gr.Radio(choices=["Todos los colores", "Solo 20 colores"], label="Modelo")
],
outputs=[
gr.Image(type="pil", label="Imagen en blanco y negro"),
gr.Image(type="pil", label="Imagen colorizada")
],
title="Colorización de Imágenes"
).launch()