"""Gradio web UI for sampling faces from the StyleGAN model in `inference`."""

import os
import traceback

import gradio as gr

# Select the backend before `inference` is imported; presumably that module
# imports Keras, which reads this variable at import time — keep the order.
os.environ["KERAS_BACKEND"] = "jax"

from inference import load_weights, generate_images, WEIGHTS_H5, CHECKPOINT_PKL  # noqa: E402

# Single source of truth for values that were previously repeated inline.
_RESOLUTION = 256
_MAX_IMAGES = 16

# Flipped to True after the first successful `load_weights` call.
_model_loaded = False


def load_model_once():
    """Load the model weights on first use and remember that it succeeded.

    Returns:
        ``True`` if the weights were already loaded; otherwise whatever
        ``load_weights`` returned (truthy on success). A failed load is not
        cached, so the next call tries again.
    """
    global _model_loaded
    if _model_loaded:
        return True
    ok = load_weights(resolution=_RESOLUTION)
    if ok:
        _model_loaded = True
    return ok


def _parse_seed(seed):
    """Convert the raw seed input to an ``int``, or ``None`` when it is blank.

    Raises:
        ValueError, TypeError, OverflowError: if the value is not blank and
            ``int()`` cannot convert it.
    """
    if seed is None or not str(seed).strip():
        return None
    return int(seed)


def generate(n_images, seed):
    """Click handler: generate faces and report the outcome.

    Args:
        n_images: Requested number of images; clamped to 1.._MAX_IMAGES.
        seed: Raw seed value from the UI. Blank or ``None`` means "no seed".

    Returns:
        A ``(images, status_message)`` pair, in the order expected by the
        ``outputs=[gallery, status]`` wiring in ``main``. On any failure the
        image list is empty and the message describes the problem.
    """
    n_images = max(1, min(int(n_images), _MAX_IMAGES))

    # The seed comes from a free-text box, so it may be anything. Report a bad
    # value through the status field instead of letting int() raise out of the
    # handler.
    try:
        seed = _parse_seed(seed)
    except (TypeError, ValueError, OverflowError):
        return [], f"❌ Invalid seed {seed!r}. Enter a whole number or leave it blank."

    if not load_model_once():
        return [], f"❌ No weights found. Put checkpoint.pkl in:\n{os.path.dirname(WEIGHTS_H5)}"

    try:
        images = generate_images(n_images, resolution=_RESOLUTION, seed=seed)
        # Explicit length check rather than truthiness: the return type of
        # generate_images is not visible here and may be array-like.
        if images is None or len(images) == 0:
            return [], "No images generated."
        return images, f"✅ Generated {len(images)} face(s) at {_RESOLUTION}×{_RESOLUTION}."
    except Exception as e:
        # UI boundary: surface the failure (with traceback) in the status box.
        return [], f"Error: {e}\n{traceback.format_exc()}"


# Stylesheet handed to gr.Blocks. It is split into adjacent string literals
# purely for readability; they concatenate to one string.
css = (
    " @import url('https://fonts.googleapis.com/css2"
    "?family=Syne:wght@400;700;800&family=DM+Mono:wght@300;400&display=swap'); "
    "* { font-family: 'Syne', sans-serif; box-sizing: border-box; } "
    "/* Dark background */ "
    ".gradio-container { background: #0a0a0f !important; min-height: 100vh; } "
    "/* Hero header */ "
    ".hero { text-align: center; padding: 2.5rem 1rem 1rem; } "
    ".hero-title { "
    "font-size: 3rem; "
    "font-weight: 800; "
    "letter-spacing: -0.03em; "
    "background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 50%, #475569 100%); "
    "-webkit-background-clip: text; "
    "-webkit-text-fill-color: transparent; "
    "margin: 0; "
    "line-height: 1.1; "
    "} "
    ".hero-sub { "
    "font-family: 'DM Mono', monospace; "
    "font-size: 0.78rem; "
    "color: #475569; "
    "letter-spacing: 0.15em; "
    "text-transform: uppercase; "
    "margin-top: 0.6rem; "
    "} "
    "/* Panel cards */ "
    ".gr-block, .gr-box, .panel { "
    "background: #111118 !important; "
    "border: 1px solid #1e1e2e !important; "
    "border-radius: 12px !important; "
    "} "
    "/* Labels */ "
    "label span { "
    "font-family: 'DM Mono', monospace !important; "
    "font-size: 0.72rem !important; "
    "letter-spacing: 0.1em !important; "
    "text-transform: uppercase !important; "
    "color: #64748b !important; "
    "} "
    "/* Slider \n*/ "
    "input[type=range] { accent-color: #6366f1; } "
    "/* Generate button */ "
    "#gen-btn { "
    "background: #6366f1 !important; "
    "border: none !important; "
    "border-radius: 8px !important; "
    "font-family: 'Syne', sans-serif !important; "
    "font-weight: 700 !important; "
    "font-size: 0.95rem !important; "
    "letter-spacing: 0.05em !important; "
    "color: white !important; "
    "padding: 0.75rem 2rem !important; "
    "transition: background 0.2s, transform 0.1s !important; "
    "width: 100% !important; "
    "} "
    "#gen-btn:hover { background: #4f46e5 !important; transform: translateY(-1px); } "
    "#gen-btn:active { transform: translateY(0); } "
    "/* Status box */ "
    "#status textarea { "
    "font-family: 'DM Mono', monospace !important; "
    "font-size: 0.8rem !important; "
    "background: #0a0a0f !important; "
    "color: #64748b !important; "
    "border: 1px solid #1e1e2e !important; "
    "border-radius: 8px !important; "
    "} "
    "/* Gallery */ "
    ".gallery-item img { border-radius: 8px !important; transition: transform 0.2s !important; } "
    ".gallery-item img:hover { transform: scale(1.03); } "
    "/* Footer */ "
    ".footer { "
    "text-align: center; "
    "padding: 2rem; "
    "font-family: 'DM Mono', monospace; "
    "font-size: 0.72rem; "
    "color: #1e293b; "
    "letter-spacing: 0.08em; "
    "} "
)


def main():
    """Build the Blocks UI, wire the generate button, and start the server.

    The server listens on all interfaces, on the port given by the ``PORT``
    environment variable (default 7860).
    """
    with gr.Blocks(css=css, title="StyleGAN v1 — Face Synthesis") as demo:
        # NOTE(review): the stylesheet targets .hero / .hero-title / .hero-sub
        # and .footer, but neither HTML payload below contains elements with
        # those classes — the markup may have been lost; confirm.
        gr.HTML("\nStyleGAN v1\nProgressive Face Synthesis · 256 × 256\n")

        with gr.Row():
            with gr.Column(scale=1):
                n_images = gr.Slider(
                    minimum=1,
                    maximum=_MAX_IMAGES,
                    value=4,
                    step=1,
                    label="Number of faces",
                )
                seed = gr.Textbox(
                    value="",
                    label="Seed (leave blank for random)",
                    placeholder="e.g. 42",
                )
                gen_btn = gr.Button("Generate ↗", variant="primary", elem_id="gen-btn")
                status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    elem_id="status",
                    lines=2,
                )

            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Generated faces",
                    columns=4,
                    rows=4,  # allow up to 4 rows for 16 images
                    object_fit="contain",  # "contain" shows the full image; "cover" zooms in
                    height="auto",  # a fixed 520px height cut off rows
                    show_label=False,
                )

        gr.HTML('')

        gen_btn.click(fn=generate, inputs=[n_images, seed], outputs=[gallery, status])

    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )


if __name__ == "__main__":
    main()