# import gradio as gr # import torch # from torchvision import transforms # from PIL import Image # from generator2 import UNetGeneratorImproved # # Rutas de los modelos, todos los colores y 20 colores # all_colors_model_path = "/home/maripau/Documents/ITESO/Semestre6/Deep/COLOURING/APP/Coloring2/unet_generator.pth" # limited_colors_model_path = "/home/maripau/Documents/ITESO/Semestre6/Deep/COLOURING/APP/Coloring2/20color_generator.pth" # # Selección de dispositivo (GPU si está disponible) # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # # Transformación de la imagen para el modelo # transform = transforms.Compose([ # transforms.Resize((512, 512)), # transforms.ToTensor(), # ]) # def load_model(path): # """ # Carga el modelo desde una ruta específica. # """ # model = UNetGeneratorImproved().to(device) # model.load_state_dict(torch.load(path, map_location=device)) # model.eval() # return model # def process_image(image, model_type): # """ # Convierte la imagen a escala de grises, la muestra, # y luego genera una versión colorizada usando el modelo elegido. # """ # # Convertir a escala de grises # gray_image = image.convert("L").convert("RGB") # convertimos de L (1 canal) a RGB (3 canales) para que PIL la pueda mostrar bien # # Seleccionar el modelo según el tipo # model_path = all_colors_model_path if model_type == "Todos los colores" else limited_colors_model_path # generator = load_model(model_path) # # Preprocesar la imagen en escala de grises para el modelo # gray_tensor = transform(gray_image.convert("L")).unsqueeze(0).to(device) # # Generar la imagen colorizada # with torch.no_grad(): # output = generator(gray_tensor) # output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy() # colorized_image = Image.fromarray((output * 255).astype('uint8')) # # Devolver la imagen en escala de grises y la colorizada # return gray_image, colorized_image # # Interfaz con dos salidas: escala de grises y color # interface = gr.Interface( # fn=process_image, # inputs=[ # gr.Image(type="pil", label="Imagen original"), # gr.Dropdown( # choices=["Todos los colores", "20 colores"], # value="Todos los colores", # label="Selecciona el modelo" # ) # ], # outputs=[ # gr.Image(type="pil", label="Imagen en blanco y negro"), # gr.Image(type="pil", label="Imagen colorizada") # ], # title="Colorizador de Imágenes" # ) # interface.launch() import gradio as gr import torch from torchvision import transforms from PIL import Image # Ruta a tus modelos exportados como TorchScript (.pt) model_paths = { "Todos los colores": "unet_generator.pt", "Solo 20 colores": "20color_generator.pt" } # Verifica si hay GPU disponible device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Transformaciones de la imagen (redimensionar y convertir a tensor) transform = transforms.Compose([ transforms.Resize((512, 512)), transforms.ToTensor(), ]) # Función que carga el modelo seleccionado def load_model(path): model = torch.jit.load(path, map_location=device) model.eval() return model # Función principal de colorización def colorize(image, modelo_seleccionado): """ Convierte la imagen a escala de grises, la muestra, y genera la imagen colorizada con el modelo seleccionado. """ # Mostrar la imagen en blanco y negro gray = image.convert("L") # Preprocesar para el modelo gray_tensor = transform(gray).unsqueeze(0).to(device) # Cargar el modelo según la selección model = load_model(model_paths[modelo_seleccionado]) # Generar la imagen colorizada with torch.no_grad(): output = model(gray_tensor) # Procesar salida y convertir a imagen PIL output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy() output_image = Image.fromarray((output * 255).astype('uint8')) return gray, output_image # Regresa la imagen gris y la colorizada # Crear interfaz con Gradio gr.Interface( fn=colorize, inputs=[ gr.Image(type="pil", label="Imagen de entrada"), gr.Radio(choices=["Todos los colores", "Solo 20 colores"], label="Modelo") ], outputs=[ gr.Image(type="pil", label="Imagen en blanco y negro"), gr.Image(type="pil", label="Imagen colorizada") ], title="Colorización de Imágenes" ).launch()