"""
Gradio app — Car GAN image generator.
Deployed on HuggingFace Spaces (SDK: gradio).

Directory structure expected on the Space:
    app.py
    src/
    generator_weights.pth   ← uploaded separately or loaded from HF Hub
"""
import os
import io
import random
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Add project root to path so `src` is importable
import sys
sys.path.insert(0, os.path.dirname(__file__))
from src.models.generator import Generator
# ─────────────────────────────────────────────
# Config (keep in sync with your config.yaml)
# ─────────────────────────────────────────────
# Values handed to the Generator constructor further down in this file.
LATENT_DIM = 128   # length of the noise vector sampled per image
IMAGE_SIZE = 64
FEATURES = 64
CHANNELS = 3

# Hub repository that holds the weights; can be overridden through the
# MODEL_REPO environment variable.
MODEL_REPO = os.environ.get("MODEL_REPO", "Parsa2025AI/car-gan")
# File name of the weights, both locally and inside the Hub repository.
WEIGHT_FILE = "generator_weights.pth"
# ─────────────────────────────────────────────
# Load model
# ─────────────────────────────────────────────
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the (still untrained) generator from the config constants, then move
# it to the selected device. Weights are loaded separately afterwards.
generator = Generator(
    latent_dim=LATENT_DIM,
    features=FEATURES,
    channels=CHANNELS,
    image_size=IMAGE_SIZE,
)
generator = generator.to(device)
def _load_weights():
    """Load the generator weights and put the model into eval mode.

    ``WEIGHT_FILE`` is looked up relative to the current working directory.
    When it does not exist there, the file of the same name is downloaded
    from the ``MODEL_REPO`` repository on the Hugging Face Hub.

    The loaded object may be either a raw state dict or a checkpoint dict
    that keeps the generator state under the ``"generator_state"`` key.
    """
    weights_path = WEIGHT_FILE
    if not os.path.exists(weights_path):
        print(f"[App] Downloading weights from {MODEL_REPO} …")
        weights_path = hf_hub_download(repo_id=MODEL_REPO, filename=WEIGHT_FILE)

    # NOTE(review): torch.load unpickles the file. On PyTorch versions where
    # `weights_only` defaults to False this can execute arbitrary code from a
    # downloaded checkpoint. Consider passing `weights_only=True` once it is
    # confirmed that the checkpoint contains only tensors/plain containers.
    ckpt = torch.load(weights_path, map_location=device)

    # Support both a raw state_dict and a full training checkpoint.
    state = ckpt.get("generator_state", ckpt)
    generator.load_state_dict(state)
    generator.eval()
    print("[App] Generator ready.")


# Weights are loaded once, at import time, so the UI callbacks can use the
# model immediately.
_load_weights()
# ─────────────────────────────────────────────
# Generation logic
# ─────────────────────────────────────────────
def generate_cars(n_images: int, seed: int) -> list[Image.Image]:
    """Generate ``n_images`` car images and return them as PIL images.

    Args:
        n_images: Number of images to sample. Converted with ``int()``, so a
            float coming from a slider is accepted.
        seed: Seed for the random number generators. ``-1`` (or any other
            negative value, or ``None`` when the input field is empty)
            selects a random seed. Larger values are wrapped into the
            32-bit range that ``np.random.seed`` accepts.

    Returns:
        A list with one ``PIL.Image.Image`` per generated sample.
    """
    # An empty or negative seed means "pick one for me". Previously only -1
    # was handled and other such values raised inside the seeding calls.
    if seed is None or int(seed) < 0:
        seed = random.randint(0, 2**31)
    # np.random.seed rejects values outside [0, 2**32 - 1]; wrap instead of
    # failing. Seeds already inside that range are left unchanged.
    seed = int(seed) % 2**32

    torch.manual_seed(seed)
    np.random.seed(seed)

    with torch.no_grad():
        z = torch.randn(int(n_images), LATENT_DIM, device=device)
        imgs = generator(z).cpu()

    # De-normalise [-1, 1] → [0, 255]
    # (assumes the generator output lies in [-1, 1]; the clamp guards the
    # uint8 conversion against anything outside that range)
    imgs = ((imgs + 1) / 2 * 255).clamp(0, 255).byte()

    # Each sample is channel-first; PIL expects height x width x channels.
    return [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in imgs]
# ─────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks(css=...) below. It styles the element
# with id "generate-btn" (gradient background, hover fade) and rounds the
# corners of images inside ".gallery-item".
CSS = """
#generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.1rem;
padding: 0.75rem 2rem;
border-radius: 12px;
border: none;
cursor: pointer;
transition: opacity 0.2s;
font-weight: 600;
}
#generate-btn:hover { opacity: 0.88; }
.gallery-item img { border-radius: 8px; }
"""
# Page layout: a header, a row with the controls (left) and the result gallery
# (right), a set of pre-defined examples, and a footer.
with gr.Blocks(css=CSS, title="Car GAN Generator") as demo:
    # Header. The Markdown text is spelled out line by line so that its
    # content does not depend on the indentation of this source file.
    gr.Markdown(
        "\n"
        "# 🚗 Car GAN — AI Image Generator\n"
        "Generate realistic car images using a Deep Convolutional GAN "
        "trained on the Stanford Cars dataset.\n"
    )

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            n_slider = gr.Slider(
                minimum=1,
                maximum=16,
                step=1,
                value=4,
                label="Number of images",
            )
            seed_input = gr.Number(
                value=-1,
                label="Seed (-1 = random)",
                precision=0,
            )
            # elem_id ties this button to the "#generate-btn" rules in CSS.
            gen_btn = gr.Button("✨ Generate Cars", elem_id="generate-btn")
            gr.Markdown(
                "\n"
                "**Tips**\n"
                "- Use a fixed seed to reproduce the same images\n"
                "- Generate up to 16 images at once\n"
            )

        # Right column: generated images.
        with gr.Column(scale=3):
            gallery = gr.Gallery(
                label="Generated cars",
                columns=4,
                rows=2,
                height="auto",
                object_fit="contain",
            )

    # Clicking the button runs the generator and shows the result in the gallery.
    gen_btn.click(
        fn=generate_cars,
        inputs=[n_slider, seed_input],
        outputs=gallery,
    )

    # Pre-defined (n_images, seed) pairs; cache_examples=True asks Gradio to
    # cache their outputs instead of recomputing them on every click.
    gr.Examples(
        examples=[[4, 42], [9, 123], [16, 999]],
        inputs=[n_slider, seed_input],
        fn=generate_cars,
        outputs=gallery,
        cache_examples=True,
    )

    # Footer.
    gr.Markdown(
        "\n"
        "---\n"
        "Model architecture: DCGAN | Dataset: Stanford Cars | Framework: PyTorch\n"
    )
# Start the Gradio server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()