import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Local module shipped with this Space; provides the generator architecture.
from networks import ResnetGenerator  # Adjust this import based on your networks.py structure


class CycleGANInference:
    """Loads CycleGAN generator checkpoints from the Hugging Face Hub and
    runs single-image domain translation (A->B and, optionally, B->A).

    Args:
        model_repo_id: Hub repo id holding the checkpoint files.
        checkpoint_filename_AtoB: state-dict filename for the A->B generator.
        checkpoint_filename_BtoA: optional state-dict filename for the B->A
            generator; when None only A->B translation is available.
    """

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Download and load the A->B generator.
        # Adjust the generator hyper-parameters to match your training config.
        checkpoint_path_AtoB = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename_AtoB,
        )
        self.netG_A2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # A to B
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from repos you trust.
        self.netG_A2B.load_state_dict(
            torch.load(checkpoint_path_AtoB, map_location=self.device)
        )
        self.netG_A2B.eval()
        self.netG_A2B.to(self.device)

        # Optionally download and load the B->A generator. Download + load are
        # kept together so no half-initialized state exists if this is skipped.
        if checkpoint_filename_BtoA:
            checkpoint_path_BtoA = hf_hub_download(
                repo_id=model_repo_id,
                filename=checkpoint_filename_BtoA,
            )
            self.netG_B2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # B to A
            self.netG_B2A.load_state_dict(
                torch.load(checkpoint_path_BtoA, map_location=self.device)
            )
            self.netG_B2A.eval()
            self.netG_B2A.to(self.device)
        else:
            self.netG_B2A = None

        # Model expects 256x256 inputs normalized to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Adjust size based on your model
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Maps [-1, 1] back to [0, 1] ((x + 1) / 2), then to a PIL image.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),  # Denormalize
            transforms.ToPILImage(),
        ])

    def transform_image(self, image, direction="A_to_B"):
        """Translate a PIL image through the selected generator.

        Args:
            image: input PIL.Image (RGB).
            direction: "A_to_B" or "B_to_A".

        Returns:
            The translated PIL.Image.

        Raises:
            ValueError: if direction is unknown or the B->A model is absent.
        """
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Clamp to the generator's nominal range so values slightly outside
        # [-1, 1] don't wrap around during the uint8 conversion in ToPILImage.
        output_tensor = output_tensor.squeeze(0).cpu().clamp(-1.0, 1.0)
        return self.inverse_transform(output_tensor)


# Initialize your model
# Replace these with your actual Hugging Face repo ID and checkpoint filenames
MODEL_REPO_ID = "profmatthew/decgan"  # Replace with your repo
CHECKPOINT_A2B = "200_net_G_A.pth"  # Replace with your checkpoint filename
CHECKPOINT_B2A = "200_net_G_B.pth"  # Replace with your checkpoint filename (optional)

cyclegan_model = CycleGANInference(
    model_repo_id=MODEL_REPO_ID,
    checkpoint_filename_AtoB=CHECKPOINT_A2B,
    checkpoint_filename_BtoA=CHECKPOINT_B2A,  # Set to None if you only have one direction
)


def generate_image(input_image, direction):
    """Gradio callback: translate input_image in the requested direction.

    Raises gr.Error on failure so the UI shows the message — returning a
    plain string into a gr.Image output would be misread as a file path.
    """
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        raise gr.Error(f"Error: {e}")


# Create Gradio interface
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            direction = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, direction],
        outputs=output_image,
    )

    # Add some examples if you have them
    # gr.Examples(
    #     examples=[
    #         # Add paths to example images here
    #         # ["example1.jpg", "A_to_B"],
    #         # ["example2.jpg", "B_to_A"],
    #     ],
    #     inputs=[input_image, direction],
    #     outputs=output_image,
    #     fn=generate_image,
    # )

if __name__ == "__main__":
    demo.launch()