import torch
from torchvision.transforms import v2
import gradio as gr
from PIL import Image

from colorizer import ColorComicNet, MODEL_CFG
from utils import smart_padding, remove_padding

# Define the transformation pipeline for the input image.
# ToDtype(scale=True) yields float32 values in [0, 1]; Normalize(0.5, 0.5)
# then maps them to [-1, 1]. postprocess_output() undoes this with (x + 1) / 2.
TRANSFORM = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5])
])


# Image preprocessing and postprocessing functions
def preprocess_image(image: Image.Image, divisor=16):
    """
    Preprocess the input PIL image for the model.

    The image is converted to 3-channel RGB (so grayscale, palette or RGBA
    uploads all end up with the same channel layout), normalized with
    TRANSFORM, given a leading batch dimension and padded with
    ``smart_padding``.

    Args:
        image: Input PIL image.
        divisor: Forwarded to ``smart_padding``; presumably the padded height
            and width become multiples of this value — confirm in utils.

    Returns:
        An ``(image_tensor, padding)`` tuple. ``padding`` is whatever
        ``smart_padding`` returns and must be handed unchanged to
        ``postprocess_output`` so the padding can be removed again.
    """
    image = image.convert('RGB')
    image_tensor = TRANSFORM(image).unsqueeze(0)  # Shape: (1, 3, H, W)
    image_tensor, padding = smart_padding(image_tensor, divisor=divisor)
    return image_tensor, padding


def postprocess_output(output_tensor, padding):
    """
    Postprocess the model output tensor to an image array.

    Args:
        output_tensor: Model output; assumed to be a (1, 3, H, W) tensor in
            [-1, 1], mirroring the input normalization — TODO confirm against
            ColorComicNet.
        padding: The padding value returned by ``preprocess_image``.

    Returns:
        A NumPy array of shape (H, W, C) with values clamped to [0, 1]
        (not a PIL image).
    """
    output_tensor = remove_padding(output_tensor, padding)
    output_tensor = (output_tensor + 1) / 2  # Scale back to [0, 1]
    output_tensor = output_tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0)
    # detach()/cpu() are no-ops for the grad-free CPU path used below, but keep
    # .numpy() valid if this is called with grad enabled or on another device.
    output_image = output_tensor.detach().cpu().numpy()  # Shape: (H, W, C)
    return output_image


# Define the colorization function
def colorize_image(gray_image: Image.Image):
    """
    Colorize a single grayscale image using the model.

    Args:
        gray_image: The uploaded image. Gradio passes ``None`` when the button
            is clicked before any image has been loaded.

    Returns:
        The colorized image as an (H, W, C) NumPy array in [0, 1].

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI
            instead of failing with an opaque AttributeError on ``None``.
    """
    if gray_image is None:
        raise gr.Error("Please upload an image to colorize.")

    with torch.no_grad():
        # Preprocess; divisor=64 presumably matches the network's total
        # downsampling factor — confirm against ColorComicNet.
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        # Inference
        output = model(input_tensor)
        # Postprocess
        output_image = postprocess_output(output, padding)
    return output_image


# Initialize the model (CPU only).
model = ColorComicNet(**MODEL_CFG)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs, so a tampered checkpoint cannot run arbitrary code.
model.load_state_dict(
    torch.load(
        "./weights/colorizer.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# ColorComicNet-specific; presumably fuses layers for inference — see colorizer.
model.fuse()
model.eval()

# Create the Gradio interface
with gr.Blocks() as demo:
    # Header
    gr.Markdown("# 🎨 Comic Colorization")
    gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📥 Upload Grayscale Image",
                type="pil",
            )
            colorize_button = gr.Button(
                "✨ Colorize Image",
                elem_classes="button-primary"
            )

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="📤 Colorized Result",
                type="numpy",
            )

    # Example section
    gr.Markdown("### 🖼️ Try an example")
    examples = gr.Examples(
        examples=[
            ["./examples/gray.jpg"],
            ["./examples/gray_2.jpg"],
            ["./examples/gray_4.jpg"],
        ],
        inputs=input_image
    )

    # Footer
    gr.Markdown("---")

    # Interaction
    colorize_button.click(
        fn=colorize_image,
        inputs=input_image,
        outputs=output_image
    )

demo.launch()