# Att-DeCGAN / app.py -- Gradio demo for CycleGAN-style image translation.
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from huggingface_hub import hf_hub_download
import os
# Import your networks (you'll need to upload networks.py to your Space)
from networks import ResnetGenerator # Adjust this import based on your networks.py structure
class CycleGANInference:
    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Download model checkpoints from the Hugging Face Hub
        checkpoint_path_AtoB = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename_AtoB
        )

        # Initialize generators
        # Adjust these parameters based on your model architecture
        self.netG_A2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # A to B

        if checkpoint_filename_BtoA:
            checkpoint_path_BtoA = hf_hub_download(
                repo_id=model_repo_id,
                filename=checkpoint_filename_BtoA
            )
            self.netG_B2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # B to A
        else:
            self.netG_B2A = None

        # Load model weights
        self.netG_A2B.load_state_dict(torch.load(checkpoint_path_AtoB, map_location=self.device))
        if self.netG_B2A is not None:
            self.netG_B2A.load_state_dict(torch.load(checkpoint_path_BtoA, map_location=self.device))

        # Set to evaluation mode
        self.netG_A2B.eval()
        if self.netG_B2A is not None:
            self.netG_B2A.eval()

        # Move to device
        self.netG_A2B.to(self.device)
        if self.netG_B2A is not None:
            self.netG_B2A.to(self.device)

        # Define transforms
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Adjust size based on your model
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),  # Denormalize
            transforms.ToPILImage()
        ])
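        # Explanatory note: Normalize((-1, -1, -1), (2, 2, 2)) above undoes the
        # forward Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):
        #   forward:  y = (x - 0.5) / 0.5   maps [0, 1] -> [-1, 1]
        #   inverse:  x = (y - (-1)) / 2    maps [-1, 1] -> [0, 1]
        # so the tensor is back in the [0, 1] range that ToPILImage expects.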
    def transform_image(self, image, direction="A_to_B"):
        # Ensure a 3-channel RGB input (uploads may be RGBA or grayscale)
        image = image.convert("RGB")

        # Preprocess
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Run the selected generator
        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A is not None:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Postprocess
        output_image = self.inverse_transform(output_tensor.squeeze(0).cpu())
        return output_image
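
# Example (illustrative only): using the wrapper outside Gradio. The file names
# "input.jpg" and "output.jpg" are placeholders, not files shipped with this Space.
#
#   model = CycleGANInference("profmatthew/decgan", "200_net_G_A.pth", "200_net_G_B.pth")
#   result = model.transform_image(Image.open("input.jpg"), direction="A_to_B")
#   result.save("output.jpg")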
# Initialize your model
# Replace these with your actual Hugging Face repo ID and checkpoint filenames
MODEL_REPO_ID = "profmatthew/decgan" # Replace with your repo
CHECKPOINT_A2B = "200_net_G_A.pth" # Replace with your checkpoint filename
CHECKPOINT_B2A = "200_net_G_B.pth" # Replace with your checkpoint filename (optional)
cyclegan_model = CycleGANInference(
    model_repo_id=MODEL_REPO_ID,
    checkpoint_filename_AtoB=CHECKPOINT_A2B,
    checkpoint_filename_BtoA=CHECKPOINT_B2A  # Set to None if you only have one direction
)
def generate_image(input_image, direction):
    if input_image is None:
        raise gr.Error("Please upload an input image first.")
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        # Surface the error in the UI instead of returning a string to an Image output
        raise gr.Error(str(e))
# Create Gradio interface
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            direction = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction"
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, direction],
        outputs=output_image
    )
    # Add some examples if you have them
    # gr.Examples(
    #     examples=[
    #         # Add paths to example images here
    #         # ["example1.jpg", "A_to_B"],
    #         # ["example2.jpg", "B_to_A"],
    #     ],
    #     inputs=[input_image, direction],
    #     outputs=output_image,
    #     fn=generate_image,
    # )
if __name__ == "__main__":
    demo.launch()
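
# Likely Space dependencies (an assumption; pin versions in requirements.txt as needed):
#   gradio, torch, torchvision, huggingface_hub, Pillow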