# CarGANDemo / app.py
# Author: Parsa2025AI — commit 10ae698 "main app" (verified)
"""
Gradio app – Car GAN image generator.
Deployed on HuggingFace Spaces (SDK: gradio).

Directory structure expected on the Space:
    app.py
    src/                   ← project package (provides src.models.generator)
    generator_weights.pth  ← uploaded separately or loaded from HF Hub
"""
import os
import io
import random
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Add project root to path so `src` is importable
import sys
sys.path.insert(0, os.path.dirname(__file__))
from src.models.generator import Generator
# ─────────────────────────────────────────────
# Config (keep in sync with your config.yaml)
# ─────────────────────────────────────────────
# These values are handed directly to the Generator constructor below, so they
# need to match the settings the loaded weights were trained with.
LATENT_DIM = 128    # length of each noise vector sampled in generate_cars
IMAGE_SIZE = 64
FEATURES = 64
CHANNELS = 3
# HuggingFace Hub repository holding the weights; overridable via $MODEL_REPO.
MODEL_REPO = os.getenv("MODEL_REPO", "Parsa2025AI/car-gan")
WEIGHT_FILE = "generator_weights.pth"

# ─────────────────────────────────────────────
# Load model
# ─────────────────────────────────────────────
# Use the GPU when the Space provides one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

generator = Generator(
    latent_dim=LATENT_DIM,
    image_size=IMAGE_SIZE,
    features=FEATURES,
    channels=CHANNELS,
)
generator = generator.to(device)
def _load_weights():
    """Load trained weights into the module-level ``generator``.

    Lookup order for ``WEIGHT_FILE``:
      1. the current working directory (original behaviour),
      2. the directory containing this script,
      3. otherwise it is downloaded from ``MODEL_REPO`` on the HuggingFace Hub.

    The file may hold either a raw ``state_dict`` or a checkpoint dict that
    stores the generator weights under the ``"generator_state"`` key. After
    loading, the generator is switched to eval mode.

    Raises:
        TypeError: if the loaded object is not a dict (state_dict/checkpoint).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (WEIGHT_FILE, os.path.join(script_dir, WEIGHT_FILE))
    # Checking the script directory too means launching the app from another
    # working directory still finds a copy uploaded next to app.py.
    local_path = next((p for p in candidates if os.path.exists(p)), None)
    if local_path is None:
        print(f"[App] Downloading weights from {MODEL_REPO} …")
        local_path = hf_hub_download(repo_id=MODEL_REPO, filename=WEIGHT_FILE)

    # NOTE(security): depending on the installed PyTorch version, torch.load
    # may use the full pickle loader (the default before 2.6), which can run
    # arbitrary code. MODEL_REPO is overridable via an environment variable, so
    # only point it at a trusted repository, or pass weights_only=True once the
    # checkpoint is confirmed to contain only tensors / plain containers.
    ckpt = torch.load(local_path, map_location=device)

    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected a state_dict or checkpoint dict in {local_path!r}, "
            f"got {type(ckpt).__name__}"
        )
    # Support both raw state_dict and full checkpoint
    state = ckpt.get("generator_state", ckpt)
    generator.load_state_dict(state)
    generator.eval()
    print("[App] Generator ready.")


_load_weights()
# ─────────────────────────────────────────────
# Generation logic
# ─────────────────────────────────────────────
def generate_cars(n_images: int, seed: int) -> list[Image.Image]:
    """Generate ``n_images`` car images from random latent vectors.

    Args:
        n_images: Number of images to generate. Float values (as delivered
            by a Gradio slider) are truncated to int; values below 1 return
            an empty list.
        seed: Seed for reproducible output. ``-1`` (or any other negative
            value, or ``None`` when the input field is left empty) picks a
            fresh random seed. The seed actually used is printed so a random
            result can be reproduced later.

    Returns:
        A list of PIL images, one per generated sample.
    """
    n_images = int(n_images)
    if n_images < 1:
        return []

    if seed is None or int(seed) < 0:
        seed = random.randint(0, 2**31)
    seed = int(seed)
    print(f"[App] Generating {n_images} image(s) with seed {seed}")

    torch.manual_seed(seed)
    # np.random.seed only accepts 0 <= seed < 2**32; fold larger seeds into range.
    np.random.seed(seed % 2**32)

    with torch.no_grad():
        z = torch.randn(n_images, LATENT_DIM, device=device)
        imgs = generator(z).cpu()

    # De-normalise [-1, 1] → [0, 255]. Round before the uint8 cast; a plain
    # cast truncates and would bias every pixel down by half a level.
    imgs = ((imgs + 1) / 2 * 255).round().clamp(0, 255).byte()

    pil_images = []
    for img in imgs:
        arr = img.permute(1, 2, 0).numpy()  # CHW -> HWC, the layout PIL expects
        if arr.shape[2] == 1:
            arr = arr[:, :, 0]  # PIL wants (H, W) rather than (H, W, 1) for grayscale
        pil_images.append(Image.fromarray(arr))
    return pil_images
# ─────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────
# Upper bound for the "Number of images" slider; also quoted in the tips text
# so the two can't drift apart.
MAX_IMAGES = 16

CSS = """
#generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.1rem;
padding: 0.75rem 2rem;
border-radius: 12px;
border: none;
cursor: pointer;
transition: opacity 0.2s;
font-weight: 600;
}
#generate-btn:hover { opacity: 0.88; }
.gallery-item img { border-radius: 8px; }
"""

with gr.Blocks(css=CSS, title="Car GAN Generator") as demo:
    # Heading text was previously mis-encoded (UTF-8 emoji / em dash decoded
    # with the wrong codepage); restored to the intended characters.
    gr.Markdown(
        """
# 🚗 Car GAN — AI Image Generator
Generate realistic car images using a Deep Convolutional GAN trained on the Stanford Cars dataset.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            n_slider = gr.Slider(
                minimum=1, maximum=MAX_IMAGES, step=1, value=4,
                label="Number of images",
            )
            seed_input = gr.Number(
                value=-1, label="Seed (-1 = random)", precision=0,
            )
            gen_btn = gr.Button("✨ Generate Cars", elem_id="generate-btn")
            gr.Markdown(
                f"""
**Tips**
- Use a fixed seed to reproduce the same images
- Generate up to {MAX_IMAGES} images at once
"""
            )
        with gr.Column(scale=3):
            gallery = gr.Gallery(
                label="Generated cars",
                columns=4,
                rows=2,
                height="auto",
                object_fit="contain",
            )
    gen_btn.click(
        fn=generate_cars,
        inputs=[n_slider, seed_input],
        outputs=gallery,
    )
    # cache_examples=True runs generate_cars for each example when the app
    # builds, so clicking an example returns a pre-computed gallery.
    gr.Examples(
        examples=[[4, 42], [9, 123], [16, 999]],
        inputs=[n_slider, seed_input],
        fn=generate_cars,
        outputs=gallery,
        cache_examples=True,
    )
    gr.Markdown(
        """
---
Model architecture: DCGAN | Dataset: Stanford Cars | Framework: PyTorch
"""
    )

if __name__ == "__main__":
    demo.launch()