"""Gradio demo: sample an image with StyleGAN2 and upscale it with SRGAN.

At import time this module downloads two checkpoints from the Hugging Face
Hub, builds both generators, and defines the Gradio UI as ``demo``. The UI is
only launched when the file is run as a script.
"""

import random

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from generator_1 import Generator as StyleGANGenerator  # StyleGAN2 generator
from generator_2 import Generator as SRGANGenerator  # SRGAN generator

# Hub repository that both checkpoints are downloaded from.
HF_REPO_ID = "keysun89/image_generation"

# StyleGAN2 checkpoints to choose from; one is picked at random per process
# start. Entries previously listed here but currently disabled:
# 'trial_0_G (1).pth', 'trial_0_G (3).pth', 'trial_0_G (4).pth'.
wts = ['trial_0_G (2).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
random_wt = random.choice(wts)

# Load trained model weights from Hugging Face Hub
weights_path = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=random_wt,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Configure your generator parameters
z_dim = 512
w_dim = 512
img_resolution = 256  # Adjust to your training resolution
img_channels = 3

model = StyleGANGenerator(
    z_dim=z_dim,
    w_dim=w_dim,
    img_resolution=img_resolution,
    img_channels=img_channels,
)

# NOTE(review): torch.load unpickles the downloaded file. On PyTorch versions
# where ``weights_only`` defaults to False, a tampered checkpoint could run
# arbitrary code. Consider passing ``weights_only=True`` once it is confirmed
# that the checkpoints are plain state dicts and the installed PyTorch
# supports the argument. The same applies to the SRGAN load below.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()

# The filename is spelled exactly as it is requested from the Hub.
wt_2 = 'genrator.pth'
srgan_weights = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=wt_2,
)

# Initialize SRGAN with scale=2 (256 -> 512)
srgan_model = SRGANGenerator(
    img_feat=3,
    n_feats=64,
    kernel_size=3,
    num_block=16,
    scale=2,
)
srgan_model.load_state_dict(torch.load(srgan_weights, map_location=device))
srgan_model.to(device)
srgan_model.eval()

# Maps an 8-bit PIL image to a float tensor normalized to [-1, 1] per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Adjust if needed
])


def _tensor_to_pil(batch, offset):
    """Convert a one-image generator output batch to an 8-bit PIL image.

    Args:
        batch: Tensor whose leading dimension of size 1 is squeezed away,
            leaving a channels-first image (assumed to lie roughly in
            [-1, 1] -- TODO confirm for the StyleGAN2 generator).
        offset: Constant added after scaling by 127.5. The two call sites
            deliberately use different offsets (128 and 127.5), so it is a
            parameter rather than a fixed value.

    Returns:
        PIL.Image.Image built from the clipped ``uint8`` HWC array.
    """
    chw = batch.squeeze(0).cpu().numpy()
    hwc = np.transpose(chw, (1, 2, 0))  # CHW to HWC
    # Clip before the cast so out-of-range values saturate instead of wrapping.
    pixels = (hwc * 127.5 + offset).clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def generate():
    """Generate 256x256 image and upscale to 512x512.

    Samples a random latent vector, runs the StyleGAN2 generator with
    truncation (psi=0.7), converts the result to an 8-bit image, then feeds
    that image through the SRGAN generator.

    Returns:
        tuple: ``(img_256_pil, img_512_pil)`` -- the StyleGAN2 output and
        its SRGAN-upscaled version, both as PIL images.
    """
    with torch.no_grad():
        # Step 1: Generate 256x256 image with StyleGAN2
        z = torch.randn(1, z_dim, device=device)
        img_256 = model(z, use_truncation=True, truncation_psi=0.7)
        img_256_pil = _tensor_to_pil(img_256, offset=128)

        # Step 2: Upscale to 512x512 with SRGAN. The input is rebuilt from
        # the 8-bit PIL image, so SRGAN sees the quantized picture.
        img_256_tensor = transform(img_256_pil).unsqueeze(0).to(device)

        # SRGAN returns tuple: image, features
        img_512, _ = srgan_model(img_256_tensor)

        # Denormalize from tanh output [-1, 1] to [0, 255]
        img_512_pil = _tensor_to_pil(img_512, offset=127.5)

    return img_256_pil, img_512_pil


# Gradio interface
with gr.Blocks(title="StyleGAN2 + SRGAN Generator") as demo:
    gr.Markdown("# 🎨 StyleGAN2 + SRGAN Image Generator")
    gr.Markdown("Generate a 256x256 image and upscale it to 512x512 using SRGAN")

    with gr.Row():
        generate_btn = gr.Button("🎲 Generate New Image", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            output_256 = gr.Image(label="Generated 256x256", type="pil", height=300)
        with gr.Column():
            output_512 = gr.Image(label="Upscaled 512x512 (SRGAN)", type="pil", height=300)

    # Event listeners are registered inside the Blocks context.
    # Generate on button click
    generate_btn.click(
        fn=generate,
        inputs=None,
        outputs=[output_256, output_512],
    )

    # Generate on page load
    demo.load(
        fn=generate,
        inputs=None,
        outputs=[output_256, output_512],
    )

if __name__ == "__main__":
    demo.launch()