| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| import torchvision.transforms as transforms |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
| import os |
|
|
| |
| from networks import ResnetGenerator |
|
|
class CycleGANInference:
    """Bidirectional CycleGAN image translation using pretrained generators.

    Downloads generator state dicts from the Hugging Face Hub and exposes
    :meth:`transform_image` to translate a PIL image A->B (and optionally
    B->A when a second checkpoint is provided).
    """

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        """Load the generator(s) and build the pre/post-processing pipelines.

        Args:
            model_repo_id: Hugging Face Hub repo id hosting the checkpoints.
            checkpoint_filename_AtoB: filename of the A->B generator state dict.
            checkpoint_filename_BtoA: optional filename of the B->A generator
                state dict; when omitted, only A->B translation is available.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # A->B generator is mandatory; B->A is optional.  A single helper
        # replaces the previously duplicated download/build/load/eval logic.
        self.netG_A2B = self._load_generator(model_repo_id, checkpoint_filename_AtoB)
        self.netG_B2A = (
            self._load_generator(model_repo_id, checkpoint_filename_BtoA)
            if checkpoint_filename_BtoA
            else None
        )

        # PIL image -> 256x256 tensor normalized to [-1, 1], the range the
        # generator was trained on.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Inverse of the normalization above: (x - (-1)) / 2 maps [-1, 1]
        # back to [0, 1] before conversion to a PIL image.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),
            transforms.ToPILImage()
        ])

    def _load_generator(self, repo_id, filename):
        """Download *filename* from *repo_id* and return an eval-mode generator on self.device."""
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
        net = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)
        net.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        net.eval()
        return net.to(self.device)

    def transform_image(self, image, direction="A_to_B"):
        """Translate *image* with the generator selected by *direction*.

        Args:
            image: input PIL image (any size; resized to 256x256).
            direction: "A_to_B" or "B_to_A".

        Returns:
            The translated PIL image.

        Raises:
            ValueError: if *direction* is unknown, or "B_to_A" is requested
                but no B->A checkpoint was loaded.
        """
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        return self.inverse_transform(output_tensor.squeeze(0).cpu())
|
|
| |
| |
# Hugging Face Hub repo and checkpoint filenames for the two generators.
MODEL_REPO_ID = "profmatthew/decgan"
CHECKPOINT_A2B = "200_net_G_A.pth"  # generator A -> B
CHECKPOINT_B2A = "200_net_G_B.pth"  # generator B -> A

# Build the inference wrapper once at import time so every request
# reuses the already-downloaded weights.
cyclegan_model = CycleGANInference(
    MODEL_REPO_ID,
    CHECKPOINT_A2B,
    CHECKPOINT_B2A,
)
|
|
def generate_image(input_image, direction):
    """Gradio callback: translate *input_image* in *direction*.

    Args:
        input_image: PIL image from the input component (None if none uploaded).
        direction: "A_to_B" or "B_to_A" from the dropdown.

    Returns:
        The translated PIL image.

    Raises:
        gr.Error: surfaced in the UI on missing input or inference failure.
            The previous version returned an error *string* to the
            gr.Image output, which cannot render text and failed opaquely.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e
|
|
| |
# Assemble the web UI: image + direction controls on the left,
# generated result on the right.
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            source_image = gr.Image(type="pil", label="Input Image")
            direction_selector = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction",
            )
            run_button = gr.Button("Generate", variant="primary")

        with gr.Column():
            result_image = gr.Image(type="pil", label="Generated Image")

    # Wire the button to the inference callback.
    run_button.click(
        fn=generate_image,
        inputs=[source_image, direction_selector],
        outputs=result_image,
    )
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Start the Gradio web server when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()