# Coloring2 / app.py
# Author: maripau22
# Commit message: "espero que ya color2"
# Commit: 70764fd
from functools import lru_cache
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
# Paths to the exported TorchScript generator models (.pt), keyed by the label
# shown in the UI. They are resolved relative to this file's directory (not the
# current working directory) so the app works no matter where it is launched from.
_MODEL_DIR = Path(__file__).resolve().parent
model_paths = {
    "All colors": str(_MODEL_DIR / "unet_generator.pt"),
    "20 colors only": str(_MODEL_DIR / "20color_generator.pt"),
}
# Pick the compute device once at import time: the CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Preprocessing pipeline applied to every input: resize to exactly 512x512
# (a (h, w) size does not preserve aspect ratio), then convert the PIL image
# into a tensor for the model.
transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
    ]
)
# Load (and cache) a TorchScript model by path.
@lru_cache(maxsize=8)
def load_model(path):
    """
    Load a TorchScript model onto ``device`` and put it in eval mode.

    The result is memoized per ``path``, so each model file is read from disk
    and mapped to the device only once per process. Previously, every
    colorization request reloaded it.

    Args:
        path: Filesystem path to an exported TorchScript (``.pt``) file.

    Returns:
        The loaded TorchScript module, already switched to eval mode.
    """
    model = torch.jit.load(path, map_location=device)
    model.eval()
    return model
# Main colorization function
def colorize(image, selected_model):
    """
    Convert the input image to grayscale and colorize it with the selected model.

    Args:
        image: PIL image from the Gradio input (``None`` when nothing was
            uploaded).
        selected_model: Label chosen in the model radio button; must be a key
            of ``model_paths`` (``None`` when nothing is selected).

    Returns:
        A tuple ``(gray, colorized)`` of PIL images:
        - the grayscale ("L" mode) copy of the input at its original size;
        - the model output, at whatever spatial size the model produces.

    Raises:
        gr.Error: if no image was provided or the model label is missing or
            unknown. Gradio shows the message in the UI instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if selected_model not in model_paths:
        raise gr.Error("Please select a model.")

    # Single-channel copy: this is both what the model consumes and what is
    # displayed as the grayscale result.
    gray = image.convert("L")

    # Preprocess, add a batch dimension, and move to the model's device.
    gray_tensor = transform(gray).unsqueeze(0).to(device)

    model = load_model(model_paths[selected_model])

    with torch.no_grad():
        output = model(gray_tensor)

    # NOTE(review): assumes the model returns (1, C, H, W) with C == 3 and
    # values meant to lie in [0, 1]. Confirm against the exported generators.
    output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    # Round (rather than truncate) to the nearest 8-bit level.
    output_image = Image.fromarray((output * 255).round().astype("uint8"))
    return gray, output_image  # grayscale and colorized images
# Create the Gradio interface. The radio choices come from ``model_paths`` so
# the UI and the model table cannot drift apart. The first model is
# preselected so ``colorize`` is never called without a model.
demo = gr.Interface(
    fn=colorize,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=list(model_paths),
            value=next(iter(model_paths)),
            label="Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Grayscale Image"),
        gr.Image(type="pil", label="Colorized Image"),
    ],
    title="Image Colorization",
    description=(
        "Upload a color image and choose a model to see it colorized from a grayscale version. "
        "The system first converts the input image to black and white, then uses a trained deep learning model "
        "to generate a colorized version. You can experiment with two models: one trained on a full color palette "
        "and another limited to just 20 colors."
    ),
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()