# decgan / app.py — Gradio Space entry point for CycleGAN image translation.
# (Uploaded by RAIL-KNUST, "adding files for the application", commit 9054e98.)
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from huggingface_hub import hf_hub_download
import os
# Import your networks (you'll need to upload networks.py to your Space)
from networks import ResnetGenerator # Adjust this import based on your networks.py structure
class CycleGANInference:
    """CycleGAN inference wrapper.

    Downloads generator checkpoints from the Hugging Face Hub, builds the
    ResNet generators, and exposes image-to-image translation in one or
    both directions (A->B always; B->A only when a second checkpoint is
    given).
    """

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        """Download weights and prepare the generator network(s).

        Args:
            model_repo_id: Hugging Face Hub repo id that hosts the weights.
            checkpoint_filename_AtoB: state-dict filename for the A->B generator.
            checkpoint_filename_BtoA: optional state-dict filename for the
                B->A generator; when None, only A->B translation is available.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        checkpoint_path_AtoB = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename_AtoB,
        )

        # Architecture hyper-parameters must match the training run:
        # 3-channel RGB in/out, 64 base filters, 9 residual blocks.
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — fine for your own checkpoints, but consider
        # weights_only=True if the repo is not fully trusted.
        self.netG_A2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # A to B
        self.netG_A2B.load_state_dict(torch.load(checkpoint_path_AtoB, map_location=self.device))

        if checkpoint_filename_BtoA:
            checkpoint_path_BtoA = hf_hub_download(
                repo_id=model_repo_id,
                filename=checkpoint_filename_BtoA,
            )
            self.netG_B2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # B to A
            self.netG_B2A.load_state_dict(torch.load(checkpoint_path_BtoA, map_location=self.device))
        else:
            self.netG_B2A = None

        # Inference only: eval mode (freezes batch-norm/dropout) and move to device.
        self.netG_A2B.eval()
        self.netG_A2B.to(self.device)
        if self.netG_B2A is not None:
            self.netG_B2A.eval()
            self.netG_B2A.to(self.device)

        # Preprocessing: resize to the 256x256 resolution the generators were
        # trained on, then normalise to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Inverse of the normalisation above: (x + 1) / 2 maps [-1, 1] -> [0, 1].
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),
            transforms.ToPILImage(),
        ])

    def transform_image(self, image, direction="A_to_B"):
        """Translate a PIL image through the selected generator.

        Args:
            image: input PIL image. Converted to RGB internally so RGBA or
                grayscale uploads do not crash the 3-channel pipeline.
            direction: "A_to_B" or "B_to_A".

        Returns:
            The translated image as a PIL image.

        Raises:
            ValueError: unknown direction, or B->A requested when no B->A
                generator was loaded.
        """
        # Force 3 channels: Normalize((0.5,)*3, (0.5,)*3) fails on RGBA/L inputs.
        input_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Clamp to the generator's nominal [-1, 1] range before denormalising:
        # ToPILImage does not clamp, so out-of-range floats would wrap around
        # when multiplied by 255 and cast to uint8, producing speckle artifacts.
        output = output_tensor.squeeze(0).cpu().clamp(-1.0, 1.0)
        return self.inverse_transform(output)
# Initialize your model
# Replace these with your actual Hugging Face repo ID and checkpoint filenames
# NOTE: this runs at import time — the Space downloads both checkpoints from
# the Hub and builds the generators before the UI is defined below.
MODEL_REPO_ID = "profmatthew/decgan" # Replace with your repo
CHECKPOINT_A2B = "200_net_G_A.pth" # Replace with your checkpoint filename
CHECKPOINT_B2A = "200_net_G_B.pth" # Replace with your checkpoint filename (optional)
cyclegan_model = CycleGANInference(
model_repo_id=MODEL_REPO_ID,
checkpoint_filename_AtoB=CHECKPOINT_A2B,
checkpoint_filename_BtoA=CHECKPOINT_B2A # Set to None if you only have one direction
)
def generate_image(input_image, direction):
    """Gradio callback: run CycleGAN translation on the uploaded image.

    Args:
        input_image: PIL image from the gr.Image input (None when the user
            clicks Generate without uploading anything).
        direction: "A_to_B" or "B_to_A", from the dropdown.

    Returns:
        The translated PIL image, rendered by the gr.Image output.

    Raises:
        gr.Error: shown as a message in the UI when the input is missing or
            the translation fails. (The previous version returned an error
            *string* to the gr.Image output, which Gradio treats as a file
            path and cannot render.)
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e
# ---------------------------------------------------------------------------
# Gradio interface: input column (image + direction + button) on the left,
# generated image on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            direction = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    # Wire the button to the inference callback.
    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, direction],
        outputs=output_image,
    )

    # Optional: uncomment and fill in paths to show clickable examples.
    # gr.Examples(
    #     examples=[
    #         # ["example1.jpg", "A_to_B"],
    #         # ["example2.jpg", "B_to_A"],
    #     ],
    #     inputs=[input_image, direction],
    #     outputs=output_image,
    #     fn=generate_image,
    # )

if __name__ == "__main__":
    demo.launch()