File size: 2,216 Bytes
70764fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
# Exported TorchScript generator checkpoints (.pt), keyed by the label that
# the UI shows for each model.
model_paths = dict(
    [
        ("All colors", "unet_generator.pt"),
        ("20 colors only", "20color_generator.pt"),
    ]
)

# Prefer CUDA when it is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing pipeline: stretch every image to the fixed 512x512 model input
# size, then convert the PIL image into a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)
# Models already loaded from disk, keyed by file path. ``colorize`` calls
# ``load_model`` on every request, so caching avoids re-reading and
# re-deserializing the TorchScript file for each inference.
_model_cache = {}


def load_model(path):
    """
    Load a TorchScript model onto ``device`` and switch it to eval mode.

    Each ``path`` is loaded from disk only once; later calls return the same
    cached module instance.

    Args:
        path: Filesystem path of the exported TorchScript (.pt) file.

    Returns:
        The loaded TorchScript module, in evaluation mode.
    """
    model = _model_cache.get(path)
    if model is None:
        model = torch.jit.load(path, map_location=device)
        model.eval()
        _model_cache[path] = model
    return model
# Main colorization function
def colorize(image, selected_model):
    """
    Convert the input image to grayscale and colorize it with the chosen model.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.
        selected_model: Key of ``model_paths`` chosen in the Radio input, or
            ``None`` when no option is selected; ``None`` falls back to the
            first entry of ``model_paths``.

    Returns:
        A ``(grayscale, colorized)`` pair of PIL images. The colorized image is
        resized back to the input's dimensions so its aspect ratio matches the
        original. Returns ``(None, None)`` when no image was provided.

    Raises:
        KeyError: if ``selected_model`` is not ``None`` and not a key of
            ``model_paths``.
    """
    if image is None:
        # Nothing uploaded: clear both outputs instead of crashing on .convert().
        return None, None
    if selected_model is None:
        # No model chosen: fall back to the first configured model.
        selected_model = next(iter(model_paths))

    # Convert to grayscale (this is also the first returned output).
    gray = image.convert("L")
    # ``transform`` stretches to a fixed 512x512; add a batch dimension.
    gray_tensor = transform(gray).unsqueeze(0).to(device)

    model = load_model(model_paths[selected_model])
    with torch.no_grad():
        output = model(gray_tensor)

    # Drop the batch dim and move channels last for PIL. This assumes the
    # model returns (1, C, H, W) with C == 3 (RGB) in roughly [0, 1] —
    # NOTE(review): confirm against the exported generators.
    output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    output_image = Image.fromarray((output * 255).astype('uint8'))

    # Undo the 512x512 stretch so the result matches the input's size and
    # aspect ratio (and the size of the returned grayscale image).
    if output_image.size != image.size:
        output_image = output_image.resize(image.size, Image.BICUBIC)
    return gray, output_image  # Return grayscale and colorized images
# Create the Gradio interface and launch it. The model selector's options are
# taken from ``model_paths`` so they cannot drift out of sync with the
# configured checkpoints.
gr.Interface(
    fn=colorize,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=list(model_paths),
            # Pre-select the first model so ``colorize`` never receives None
            # (which would raise KeyError on the ``model_paths`` lookup).
            value=next(iter(model_paths)),
            label="Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Grayscale Image"),
        gr.Image(type="pil", label="Colorized Image"),
    ],
    title="Image Colorization",
    description=(
        "Upload a color image and choose a model to see it colorized from a grayscale version. "
        "The system first converts the input image to black and white, then uses a trained deep learning model "
        "to generate a colorized version. You can experiment with two models: one trained on a full color palette "
        "and another limited to just 20 colors."
    ),
).launch()
|