# img_generation / app.py
# NOTE: Hugging Face Spaces page chrome ("keysun89's picture / Update app.py /
# d86dcd4 verified") was scraped into this source; it is preserved here as a
# comment so the file remains valid Python.
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import random
from huggingface_hub import hf_hub_download
from generator_1 import Generator as StyleGANGenerator # Import your StyleGAN2 generator
from generator_2 import Generator as SRGANGenerator # Import your SRGAN generator
# Candidate StyleGAN2 checkpoints hosted on the Hub; one is chosen at random
# each time the app starts, so restarts can show a different generator variant.
CHECKPOINT_CHOICES = ['trial_0_G (2).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']

# Download the randomly selected StyleGAN2 generator weights from the Hub.
weights_path = hf_hub_download(
    repo_id="keysun89/image_generation",
    filename=random.choice(CHECKPOINT_CHOICES),
)
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# StyleGAN2 generator hyperparameters — these must match the values the
# checkpoint was trained with, or load_state_dict below will fail.
z_dim = 512            # latent z-space dimensionality
w_dim = 512            # intermediate w-space dimensionality
img_resolution = 256   # base generator output resolution
img_channels = 3       # RGB output

# Build the base generator, load the downloaded checkpoint, and freeze it
# in inference mode.
model = StyleGANGenerator(
    z_dim=z_dim,
    w_dim=w_dim,
    img_resolution=img_resolution,
    img_channels=img_channels,
)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device).eval()
# SRGAN upscaler checkpoint. NOTE(review): 'genrator.pth' appears to be the
# actual (misspelled) filename of the uploaded asset on the Hub repo — do not
# "correct" the spelling here without renaming the remote file. TODO confirm.
srgan_weights = hf_hub_download(
    repo_id="keysun89/image_generation",
    filename='genrator.pth',
)

# SRGAN generator with scale=2, i.e. it upscales 256x256 -> 512x512.
srgan_model = SRGANGenerator(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=2)
srgan_model.load_state_dict(torch.load(srgan_weights, map_location=device))
srgan_model.to(device).eval()
# Preprocessing for the SRGAN input: PIL image -> float tensor normalized
# to roughly [-1, 1] (mean/std of 0.5 per channel).
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
def generate():
    """Sample one image: a 256x256 StyleGAN2 output plus its SRGAN 2x upscale.

    Returns:
        tuple: ``(img_256_pil, img_512_pil)`` — the base 256x256 PIL image
        and its 512x512 super-resolved version.
    """
    with torch.no_grad():
        # -- Stage 1: StyleGAN2 sample at base resolution ------------------
        latent = torch.randn(1, z_dim, device=device)
        base = model(latent, use_truncation=True, truncation_psi=0.7)

        # Tensor (1, C, H, W) -> uint8 HWC array. The `* 127.5 + 128` offset
        # follows the conversion used by the official StyleGAN code
        # (intentionally not the symmetric 127.5 used for the SRGAN stage).
        base_hwc = base.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        base_u8 = (base_hwc * 127.5 + 128).clip(0, 255).astype(np.uint8)
        img_256_pil = Image.fromarray(base_u8)

        # -- Stage 2: SRGAN 2x super-resolution -----------------------------
        sr_input = transform(img_256_pil).unsqueeze(0).to(device)
        # The SRGAN forward pass returns (image, features); only the image
        # is needed here.
        sr_out, _ = srgan_model(sr_input)

        # Denormalize the tanh-range output [-1, 1] back to uint8 HWC.
        sr_hwc = sr_out.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        sr_u8 = (sr_hwc * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
        img_512_pil = Image.fromarray(sr_u8)

    return img_256_pil, img_512_pil
# ---------------------------------------------------------------------------
# Gradio UI: one generate button above two image panes (base + upscaled).
# ---------------------------------------------------------------------------
with gr.Blocks(title="StyleGAN2 + SRGAN Generator") as demo:
    gr.Markdown("# 🎨 StyleGAN2 + SRGAN Image Generator")
    gr.Markdown("Generate a 256x256 image and upscale it to 512x512 using SRGAN")

    with gr.Row():
        generate_btn = gr.Button("🎲 Generate New Image", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            output_256 = gr.Image(label="Generated 256x256", type="pil", height=300)
        with gr.Column():
            output_512 = gr.Image(label="Upscaled 512x512 (SRGAN)", type="pil", height=300)

    # Wire the same generator callback to both triggers: run once when the
    # page loads, and again on every button press.
    image_outputs = [output_256, output_512]
    generate_btn.click(fn=generate, inputs=None, outputs=image_outputs)
    demo.load(fn=generate, inputs=None, outputs=image_outputs)
# Standard script entry point: start the Gradio server only when this file
# is executed directly (not when imported as a module).
if __name__ == "__main__":
    demo.launch()