Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| import random | |
| from huggingface_hub import hf_hub_download | |
| from generator_1 import Generator as StyleGANGenerator # Import your StyleGAN2 generator | |
| from generator_2 import Generator as SRGANGenerator # Import your SRGAN generator | |
# Candidate StyleGAN2 checkpoints; one is drawn at random every time this
# module is loaded. A longer list was considered earlier and is kept here
# for reference only:
#   'trial_0_G (1).pth', 'trial_0_G (2).pth', 'trial_0_G (3).pth',
#   'trial_0_G (4).pth', 'trial_0_G (5).pth', 'trial_0_G.pth'
wts = ['trial_0_G (2).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
random_wt = random.choice(wts)

# Fetch the selected StyleGAN2 checkpoint from the Hugging Face Hub.
weights_path = hf_hub_download(repo_id="keysun89/image_generation", filename=random_wt)

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# StyleGAN2 generator hyper-parameters. img_resolution has to agree with the
# resolution the checkpoint was trained at.
z_dim, w_dim = 512, 512
img_resolution, img_channels = 256, 3

model = StyleGANGenerator(
    z_dim=z_dim,
    w_dim=w_dim,
    img_resolution=img_resolution,
    img_channels=img_channels,
)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoints
# are plain state dicts, consider passing weights_only=True - confirm against
# the installed PyTorch version first.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device).eval()  # inference only: move to the device, switch to eval mode

# SRGAN super-resolution generator, pulled from the same Hub repository.
# NOTE(review): the filename is spelled 'genrator.pth'; presumably that is
# the name of the file stored in the repo - verify before "fixing" it.
wt_2 = 'genrator.pth'
srgan_weights = hf_hub_download(repo_id="keysun89/image_generation", filename=wt_2)

# scale=2 doubles the spatial size (256 -> 512).
srgan_model = SRGANGenerator(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=2)
srgan_model.load_state_dict(torch.load(srgan_weights, map_location=device))
srgan_model.to(device).eval()

# Preprocessing applied to the PIL image before it is handed to SRGAN:
# convert to a tensor, then normalise each channel as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
def _tensor_to_pil(batch, offset):
    """Convert a one-image CHW batch tensor into an 8-bit PIL image.

    The leading batch axis is squeezed away, the array is reordered from CHW
    to HWC, and each value is mapped as ``value * 127.5 + offset`` before
    being clipped to [0, 255] and cast to uint8.
    """
    hwc = batch.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    pixels = (hwc * 127.5 + offset).clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def generate():
    """Sample an image with StyleGAN2 and upscale it with SRGAN.

    A random latent of size ``z_dim`` is rendered by ``model`` with
    truncation enabled (psi 0.7). The result is converted to an 8-bit PIL
    image, re-normalised through ``transform`` and passed to ``srgan_model``.

    Returns:
        A ``(low_res, high_res)`` pair of PIL images: the StyleGAN2 output
        and its SRGAN-upscaled version (nominally 256x256 and 512x512).
    """
    with torch.no_grad():
        # Stage 1: synthesise the low-resolution image from a random latent.
        latent = torch.randn(1, z_dim, device=device)
        generated = model(latent, use_truncation=True, truncation_psi=0.7)
        low_res = _tensor_to_pil(generated, 128)

        # Stage 2: feed the 8-bit image back through the SRGAN preprocessing
        # and upscale it. srgan_model returns a pair; only the first element
        # (the image) is used here.
        sr_input = transform(low_res).unsqueeze(0).to(device)
        upscaled, _ = srgan_model(sr_input)

        # This stage maps its output with a 127.5 offset, unlike the 128 used
        # for the StyleGAN2 image above.
        high_res = _tensor_to_pil(upscaled, 127.5)

    return low_res, high_res
# Gradio UI: a title, one button, and two image panels shown side by side.
with gr.Blocks(title="StyleGAN2 + SRGAN Generator") as demo:
    gr.Markdown("# 🎨 StyleGAN2 + SRGAN Image Generator")
    gr.Markdown("Generate a 256x256 image and upscale it to 512x512 using SRGAN")

    with gr.Row():
        generate_btn = gr.Button("🎲 Generate New Image", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            output_256 = gr.Image(label="Generated 256x256", type="pil", height=300)
        with gr.Column():
            output_512 = gr.Image(label="Upscaled 512x512 (SRGAN)", type="pil", height=300)

    # The same callback runs when the button is clicked and when the page
    # first loads, so an image pair is visible immediately.
    for register_event in (generate_btn.click, demo.load):
        register_event(fn=generate, inputs=None, outputs=[output_256, output_512])

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()