Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| from colorization_model import ColorizationModel # Import your model class | |
# Checkpoint location and compute device for the trained generator.
model_path = "generator.pth"

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Define model options (replace with your configuration)
class Options:
    """Plain attribute bag handed to ``ColorizationModel`` to build the generator.

    NOTE(review): the field names look like the pix2pix option names; their
    exact meaning is defined by ``colorization_model`` -- confirm there.
    """

    # Generator architecture and channel counts.
    netG = "unet_256"
    ngf = 64
    input_nc = 1
    output_nc = 2

    # Normalisation / regularisation switches.
    norm = "batch"
    no_dropout = False

    # Weight-initialisation settings.
    init_type = "normal"
    init_gain = 0.02

    # Use GPU 0 when CUDA is available, otherwise no GPUs at all.
    if torch.cuda.is_available():
        gpu_ids = [0]
    else:
        gpu_ids = []
# Build the generator from the options above and restore its trained weights.
opt = Options()
generator = ColorizationModel(opt).netG

# NOTE(review): presumably netG can come back wrapped in torch.nn.DataParallel
# when gpu_ids is non-empty -- confirm against colorization_model. Unwrapping
# is a no-op otherwise and keeps the parameter names free of a "module." prefix.
if isinstance(generator, torch.nn.DataParallel):
    generator = generator.module

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)

# A checkpoint saved from a DataParallel-wrapped network prefixes every key
# with "module."; strip it only when the bare generator does not use that
# prefix itself, so a genuine submodule named "module" is left alone.
_prefix = "module."
if (
    isinstance(state_dict, dict)
    and state_dict
    and all(key.startswith(_prefix) for key in state_dict)
    and not any(key.startswith(_prefix) for key in generator.state_dict())
):
    state_dict = {key[len(_prefix):]: value for key, value in state_dict.items()}

generator.load_state_dict(state_dict)
# Inference only: move to the target device and switch to eval mode.
generator.to(device).eval()
# Define preprocessing and postprocessing steps
def preprocess_image(image):
    """Turn an input image into a normalised single-channel batch for the generator.

    Args:
        image: The image to convert (the Gradio input is configured to pass
            a PIL image).

    Returns:
        A tensor of shape (1, 1, 256, 256) on ``device``. Pixel values are
        mapped from [0, 1] to [-1, 1] by the mean=0.5 / std=0.5 normalisation.
    """
    # Apply the steps one by one instead of bundling them in a Compose.
    gray = transforms.Grayscale(num_output_channels=1)(image)
    resized = transforms.Resize((256, 256))(gray)
    tensor = transforms.ToTensor()(resized)
    normalized = transforms.Normalize(mean=[0.5], std=[0.5])(tensor)

    # Add the leading batch dimension the network expects.
    batch = normalized.unsqueeze(0)
    return batch.to(device)
def lab_to_rgb(lab):
    """Convert a CIE L*a*b* tensor (D65 white point) to sRGB.

    Args:
        lab: Tensor of shape (3, H, W) holding L in [0, 100] and a/b in
            roughly [-110, 110].

    Returns:
        Float tensor of shape (3, H, W) with sRGB values clamped to [0, 1].
    """
    lab = lab.float()
    lightness, a, b = lab[0], lab[1], lab[2]

    # Lab -> XYZ (relative to the white point).
    fy = (lightness + 16.0) / 116.0
    fx = fy + a / 500.0
    fz = (fy - b / 200.0).clamp(min=0.0)
    f = torch.stack([fx, fy, fz], dim=0)
    delta = 6.0 / 29.0
    xyz = torch.where(f > delta, f ** 3, 3.0 * delta ** 2 * (f - 4.0 / 29.0))
    white_point = torch.tensor([0.95047, 1.0, 1.08883], device=lab.device)
    xyz = xyz * white_point.view(3, 1, 1)

    # XYZ -> linear sRGB.
    xyz_to_rgb = torch.tensor(
        [
            [3.240479, -1.537150, -0.498535],
            [-0.969256, 1.875992, 0.041556],
            [0.055648, -0.204043, 1.057311],
        ],
        device=lab.device,
    )
    linear = torch.einsum("ij,jhw->ihw", xyz_to_rgb, xyz)
    # Out-of-gamut colours can go negative; clamp before the fractional power
    # so the gamma branch never produces NaN.
    linear = linear.clamp(min=0.0)

    # Linear sRGB -> gamma-encoded sRGB.
    srgb = torch.where(
        linear > 0.0031308,
        1.055 * linear.pow(1.0 / 2.4) - 0.055,
        12.92 * linear,
    )
    return srgb.clamp(0.0, 1.0)


def postprocess_image(output, lightness=None):
    """Convert normalised Lab network output into an RGB PIL image.

    The previous version only rescaled the channels and handed the result to
    ``ToPILImage``: the Lab values were never converted to RGB, and for a
    two-channel (a, b) prediction the first chroma channel was treated as L.

    Args:
        output: Generator output, optionally with a leading batch dimension
            of size 1. Either 3 channels (L, a, b) or 2 channels (a, b).
            L is expected in [-1, 1] (scaled by 50, offset by 50) and a/b in
            [-1, 1] (scaled by 110), matching the original rescaling.
        lightness: Normalised L channel in [-1, 1] with the same spatial size
            as ``output``. Required when ``output`` has only the two chroma
            channels; ignored when ``output`` already has three channels.

    Returns:
        An RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If ``output`` has 2 channels and ``lightness`` is missing,
            or if the channel count is neither 2 nor 3.
    """
    predicted = output.squeeze(0).cpu().detach().float()
    channels = predicted.shape[0]

    if channels == 3:
        l_norm, ab_norm = predicted[0:1], predicted[1:]
    elif channels == 2:
        if lightness is None:
            raise ValueError(
                "A 2-channel (a, b) output needs the lightness channel to "
                "rebuild a colour image."
            )
        height, width = predicted.shape[-2:]
        l_norm = lightness.detach().cpu().float().reshape(1, height, width)
        ab_norm = predicted
    else:
        raise ValueError(f"Expected 2 or 3 output channels, got {channels}.")

    # Undo the [-1, 1] normalisation: L -> [0, 100], a/b -> [-110, 110].
    lab = torch.cat([l_norm * 50.0 + 50.0, ab_norm * 110.0], dim=0)
    rgb = lab_to_rgb(lab)
    # A float (3, H, W) tensor in [0, 1] becomes an 8-bit RGB image.
    return transforms.ToPILImage()(rgb)


# Gradio interface function
def colorize(grayscale_image):
    """Colorize an uploaded image with the trained generator.

    Args:
        grayscale_image: PIL image from the Gradio input, or ``None`` when
            nothing was uploaded.

    Returns:
        The colorized image as an RGB PIL image (256x256, the size the input
        is resized to during preprocessing).

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message.
    """
    if grayscale_image is None:
        raise gr.Error("Please upload an image to colorize.")

    input_tensor = preprocess_image(grayscale_image)
    with torch.no_grad():
        colorized = generator(input_tensor)

    # NOTE(review): the normalised grayscale input is reused as the L channel;
    # this assumes the generator predicts only the a/b channels (output_nc = 2)
    # and that grayscale luminance is an acceptable stand-in for Lab L -- confirm
    # against how the model was trained.
    return postprocess_image(colorized, lightness=input_tensor)
# Define Gradio interface
def _build_interface():
    """Assemble the Gradio UI that sends an uploaded image through ``colorize``."""
    image_in = gr.Image(type="pil", label="Grayscale Image")
    image_out = gr.Image(type="pil", label="Colorized Image")
    return gr.Interface(
        fn=colorize,
        inputs=image_in,
        outputs=image_out,
        title="Pix2Pix Image Colorization",
        description="Upload a grayscale image, and the model will colorize it using Pix2Pix GAN.",
    )


interface = _build_interface()

# Launch the app
if __name__ == "__main__":
    interface.launch()