|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import numpy as np |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
|
|
|
|
|
|
from networks import ResnetGenerator |
|
|
|
|
|
class CycleGANInference:
    """Inference wrapper around a pair of CycleGAN generators.

    Downloads generator checkpoints from the Hugging Face Hub, builds the
    A->B generator (and optionally the B->A generator), and exposes a single
    ``transform_image`` entry point that maps a PIL image through the chosen
    direction.
    """

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        """Download checkpoints and prepare the generator network(s).

        Args:
            model_repo_id: Hugging Face Hub repo id that hosts the checkpoints.
            checkpoint_filename_AtoB: filename of the A->B generator state dict.
            checkpoint_filename_BtoA: optional filename of the B->A generator
                state dict; when None, only the A_to_B direction is available.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Both directions go through one helper — the original code repeated
        # the download / construct / load / eval / to-device sequence twice.
        self.netG_A2B = self._load_generator(model_repo_id, checkpoint_filename_AtoB)
        if checkpoint_filename_BtoA:
            self.netG_B2A = self._load_generator(model_repo_id, checkpoint_filename_BtoA)
        else:
            self.netG_B2A = None

        # Map a PIL image to a 3x256x256 tensor scaled to [-1, 1]
        # (the standard CycleGAN input convention).
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Invert the normalization: (x - (-1)) / 2 == (x + 1) / 2,
        # i.e. [-1, 1] -> [0, 1], then back to a PIL image.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),
            transforms.ToPILImage()
        ])

    def _load_generator(self, model_repo_id, checkpoint_filename):
        """Fetch one checkpoint and return a device-placed generator in eval mode."""
        checkpoint_path = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename
        )
        net = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted repos (consider weights_only=True on
        # torch >= 1.13).
        net.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        net.eval()
        net.to(self.device)
        return net

    def transform_image(self, image, direction="A_to_B"):
        """Translate a PIL image through the requested generator.

        Args:
            image: input PIL image (any mode; converted to RGB internally).
            direction: "A_to_B" or "B_to_A"; the latter requires the optional
                B->A checkpoint to have been supplied at construction time.

        Returns:
            The translated image as a 256x256 RGB PIL image.

        Raises:
            ValueError: if ``direction`` is unrecognized, or "B_to_A" is
                requested but no B->A generator was loaded.
        """
        # Uploads may be RGBA or grayscale; the 3-channel Normalize above
        # requires an RGB tensor, so convert up front.
        image = image.convert("RGB")
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A is not None:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Clamp guards against generator outputs slightly outside [-1, 1],
        # which would wrap around during the uint8 conversion in ToPILImage.
        output_tensor = output_tensor.squeeze(0).cpu().clamp(-1, 1)
        return self.inverse_transform(output_tensor)
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub location of the pretrained CycleGAN generator checkpoints.
MODEL_REPO_ID = "profmatthew/decgan"
CHECKPOINT_A2B = "200_net_G_A.pth"
CHECKPOINT_B2A = "200_net_G_B.pth"

# Download the weights and build both generators once, at import time,
# so every UI request reuses the same loaded model.
cyclegan_model = CycleGANInference(
    MODEL_REPO_ID,
    CHECKPOINT_A2B,
    checkpoint_filename_BtoA=CHECKPOINT_B2A,
)
|
|
|
|
|
def generate_image(input_image, direction):
    """Gradio click handler: translate ``input_image`` in ``direction``.

    Failures are surfaced via ``gr.Error`` so they appear as a UI
    notification. The original code returned an error *string* into the
    ``gr.Image`` output component, which cannot render text and would
    fail downstream instead of showing the message.
    """
    # The button can be clicked before any image is uploaded.
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")
|
|
|
|
|
|
|
|
# Assemble the Gradio UI: upload + direction selector on the left,
# the generated result on the right.
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            source_image = gr.Image(type="pil", label="Input Image")
            direction_choice = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction",
            )
            run_button = gr.Button("Generate", variant="primary")

        with gr.Column():
            result_image = gr.Image(type="pil", label="Generated Image")

    # Wire the button to the inference handler.
    run_button.click(
        fn=generate_image,
        inputs=[source_image, direction_choice],
        outputs=result_image,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()
|
|
|