Spaces:
Paused
Paused
| """ | |
| Gradio app β Car GAN image generator. | |
| Deployed on HuggingFace Spaces (SDK: gradio). | |
| Directory structure expected on the Space: | |
| app.py | |
| src/ | |
| generator_weights.pth β uploaded separately or loaded from HF Hub | |
| """ | |
| import os | |
| import io | |
| import random | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| # Add project root to path so `src` is importable | |
| import sys | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from src.models.generator import Generator | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Config (keep in sync with your config.yaml) | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| LATENT_DIM = 128 | |
| IMAGE_SIZE = 64 | |
| FEATURES = 64 | |
| CHANNELS = 3 | |
| MODEL_REPO = os.getenv("MODEL_REPO", "Parsa2025AI/car-gan") | |
| WEIGHT_FILE = "generator_weights.pth" | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Load model | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| generator = Generator( | |
| latent_dim=LATENT_DIM, | |
| features=FEATURES, | |
| channels=CHANNELS, | |
| image_size=IMAGE_SIZE, | |
| ).to(device) | |
| def _load_weights(): | |
| """Load generator weights from local file or HF Hub.""" | |
| local_path = WEIGHT_FILE | |
| if not os.path.exists(local_path): | |
| print(f"[App] Downloading weights from {MODEL_REPO} β¦") | |
| local_path = hf_hub_download(repo_id=MODEL_REPO, filename=WEIGHT_FILE) | |
| ckpt = torch.load(local_path, map_location=device) | |
| # Support both raw state_dict and full checkpoint | |
| state = ckpt.get("generator_state", ckpt) | |
| generator.load_state_dict(state) | |
| generator.eval() | |
| print("[App] Generator ready.") | |
| _load_weights() | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Generation logic | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| def generate_cars(n_images: int, seed: int) -> list[Image.Image]: | |
| """Generate n_images car images. Returns list of PIL Images.""" | |
| if seed == -1: | |
| seed = random.randint(0, 2**31) | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| with torch.no_grad(): | |
| z = torch.randn(int(n_images), LATENT_DIM, device=device) | |
| imgs = generator(z).cpu() | |
| # De-normalise [-1, 1] β [0, 255] | |
| imgs = ((imgs + 1) / 2 * 255).clamp(0, 255).byte() | |
| pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in imgs] | |
| return pil_images | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Gradio UI | |
| # βββββββββββββββββββββββββββββββββββββββββββββ | |
| CSS = """ | |
| #generate-btn { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| color: white; | |
| font-size: 1.1rem; | |
| padding: 0.75rem 2rem; | |
| border-radius: 12px; | |
| border: none; | |
| cursor: pointer; | |
| transition: opacity 0.2s; | |
| font-weight: 600; | |
| } | |
| #generate-btn:hover { opacity: 0.88; } | |
| .gallery-item img { border-radius: 8px; } | |
| """ | |
| with gr.Blocks(css=CSS, title="Car GAN Generator") as demo: | |
| gr.Markdown( | |
| """ | |
| # π Car GAN β AI Image Generator | |
| Generate realistic car images using a Deep Convolutional GAN trained on the Stanford Cars dataset. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| n_slider = gr.Slider( | |
| minimum=1, maximum=16, step=1, value=4, | |
| label="Number of images", | |
| ) | |
| seed_input = gr.Number( | |
| value=-1, label="Seed (-1 = random)", precision=0, | |
| ) | |
| gen_btn = gr.Button("β¨ Generate Cars", elem_id="generate-btn") | |
| gr.Markdown( | |
| """ | |
| **Tips** | |
| - Use a fixed seed to reproduce the same images | |
| - Generate up to 16 images at once | |
| """ | |
| ) | |
| with gr.Column(scale=3): | |
| gallery = gr.Gallery( | |
| label="Generated cars", | |
| columns=4, | |
| rows=2, | |
| height="auto", | |
| object_fit="contain", | |
| ) | |
| gen_btn.click( | |
| fn=generate_cars, | |
| inputs=[n_slider, seed_input], | |
| outputs=gallery, | |
| ) | |
| gr.Examples( | |
| examples=[[4, 42], [9, 123], [16, 999]], | |
| inputs=[n_slider, seed_input], | |
| fn=generate_cars, | |
| outputs=gallery, | |
| cache_examples=True, | |
| ) | |
| gr.Markdown( | |
| """ | |
| --- | |
| Model architecture: DCGAN | Dataset: Stanford Cars | Framework: PyTorch | |
| """ | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |