import functools

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Paths to the exported TorchScript models (.pt), keyed by the label shown in the UI.
model_paths = {
    "All colors": "unet_generator.pt",
    "20 colors only": "20color_generator.pt",
}

# Radio choices are derived from model_paths so the two can never drift apart.
model_choices = list(model_paths)

# Use a GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image transformations (resize to the model's input size and convert to a tensor).
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


@functools.lru_cache(maxsize=None)
def load_model(path):
    """Load the TorchScript model at *path* onto ``device`` in eval mode.

    Results are cached per path, so each model file is read from disk once
    rather than on every request. The cache cannot grow without bound because
    the only paths passed in come from ``model_paths``.
    """
    model = torch.jit.load(path, map_location=device)
    model.eval()
    return model


def colorize(image, selected_model):
    """Convert *image* to grayscale and colorize it with the selected model.

    Args:
        image: PIL image from the upload widget, or ``None`` if the user
            submitted without uploading anything.
        selected_model: Label chosen in the radio widget; expected to be a
            key of ``model_paths``.

    Returns:
        A ``(grayscale, colorized)`` tuple of PIL images, both at the input
        image's size.

    Raises:
        gr.Error: If no image was uploaded or the model choice is unknown.
            Gradio shows the message in the UI instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if selected_model not in model_paths:
        raise gr.Error("Please choose a model.")

    # Convert to grayscale (the model's only input).
    gray = image.convert("L")

    # Single-channel image -> (1, 1, 512, 512) tensor on the target device.
    gray_tensor = transform(gray).unsqueeze(0).to(device)

    # Cached after the first call for each model.
    model = load_model(model_paths[selected_model])

    # Generate the colorized image.
    with torch.no_grad():
        output = model(gray_tensor)

    # NOTE(review): assumes the model returns (1, 3, H, W) RGB in [0, 1] — confirm.
    output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    # Round (not truncate) when quantizing to 8-bit.
    output_image = Image.fromarray((output * 255).round().astype("uint8"))
    # The model works at a fixed 512x512; scale back so non-square inputs keep
    # their aspect ratio and the result lines up with the grayscale preview.
    output_image = output_image.resize(gray.size, Image.BICUBIC)

    return gray, output_image  # Return grayscale and colorized images


# Gradio interface. It is exposed as `demo` so `gradio app.py` reload mode can
# find it; it is launched only when this file is run as a script.
demo = gr.Interface(
    fn=colorize,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(choices=model_choices, value=model_choices[0], label="Model"),
    ],
    outputs=[
        gr.Image(type="pil", label="Grayscale Image"),
        gr.Image(type="pil", label="Colorized Image"),
    ],
    title="Image Colorization",
    description=(
        "Upload a color image and choose a model to see it colorized from a grayscale version. "
        "The system first converts the input image to black and white, then uses a trained deep learning model "
        "to generate a colorized version. You can experiment with two models: one trained on a full color palette "
        "and another limited to just 20 colors."
    ),
)

if __name__ == "__main__":
    demo.launch()