Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| os.environ["KERAS_BACKEND"] = "jax" | |
| from inference import load_weights, generate_images, WEIGHTS_H5, CHECKPOINT_PKL | |
# Becomes True after the first successful load_weights() call; never reset.
_model_loaded = False


def load_model_once():
    """Load the 256x256 generator weights on first use and skip reloading afterwards.

    Returns:
        True if the weights were already loaded by an earlier call. Otherwise
        the raw result of ``load_weights(resolution=256)``; the caller treats
        a falsy value as "no weights found". A failed attempt is not
        remembered, so the next call tries to load again.
    """
    global _model_loaded
    if not _model_loaded:
        result = load_weights(resolution=256)
        # Only flip the flag on success; a falsy result leaves it False.
        _model_loaded = bool(result)
        return result
    return True
def _parse_seed(seed):
    """Return ``seed`` as an int, or None when it is None/blank (random seed).

    Raises:
        ValueError / TypeError / OverflowError if ``int(seed)`` cannot convert it.
    """
    if seed is None or not str(seed).strip():
        return None
    return int(seed)


def generate(n_images, seed):
    """Generate a batch of 256x256 faces for the Gradio UI.

    Args:
        n_images: Requested number of faces (slider value); coerced to int
            and clamped to the range [1, 16].
        seed: Seed input. None or blank means a random seed; otherwise it must
            be convertible with ``int()`` (e.g. "42").

    Returns:
        A ``(images, status_message)`` pair for the gallery and status box.
        An invalid seed, missing weights, or an exception during model loading
        or generation yields an empty image list plus an explanatory message
        instead of raising, so the UI always shows a readable status.
    """
    n_images = max(1, min(int(n_images), 16))

    # Parse the seed up front so bad text produces a status message rather
    # than an unhandled ValueError in the event handler.
    try:
        seed = _parse_seed(seed)
    except (ValueError, TypeError, OverflowError):
        return [], f"❌ Invalid seed {seed!r}: use a whole number or leave it blank."

    try:
        # Loading is inside the try so a crashing loader is reported in the UI too.
        if not load_model_once():
            return [], f"❌ No weights found. Put checkpoint.pkl in:\n{os.path.dirname(WEIGHTS_H5)}"
        images = generate_images(n_images, resolution=256, seed=seed)
        # len() rather than truthiness: images may be an array, whose truth value is ambiguous.
        if images is None or len(images) == 0:
            return [], "No images generated."
        return images, f"✅ Generated {len(images)} face(s) at 256×256."
    except Exception as e:
        import traceback
        return [], f"Error: {e}\n{traceback.format_exc()}"
# Stylesheet for the dark "StyleGAN v1" theme, passed to gr.Blocks(css=...).
# #gen-btn and #status target the elem_ids given to the Generate button and the
# status textbox; .hero* and .footer style the custom HTML header/footer.
# NOTE(review): .gr-block / .gr-box / .gallery-item are Gradio-internal class
# names that differ between Gradio releases — confirm they still match the
# installed version, otherwise those rules silently do nothing.
# The footer uses the muted label colour (#64748b): darker slate tones such as
# #1e293b are nearly invisible on the #0a0a0f background (~1.35:1 contrast).
css = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;700;800&family=DM+Mono:wght@300;400&display=swap');
* { font-family: 'Syne', sans-serif; box-sizing: border-box; }
/* Dark background */
.gradio-container {
    background: #0a0a0f !important;
    min-height: 100vh;
}
/* Hero header */
.hero {
    text-align: center;
    padding: 2.5rem 1rem 1rem;
}
.hero-title {
    font-size: 3rem;
    font-weight: 800;
    letter-spacing: -0.03em;
    background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 50%, #475569 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin: 0;
    line-height: 1.1;
}
.hero-sub {
    font-family: 'DM Mono', monospace;
    font-size: 0.78rem;
    color: #475569;
    letter-spacing: 0.15em;
    text-transform: uppercase;
    margin-top: 0.6rem;
}
/* Panel cards */
.gr-block, .gr-box, .panel {
    background: #111118 !important;
    border: 1px solid #1e1e2e !important;
    border-radius: 12px !important;
}
/* Labels */
label span {
    font-family: 'DM Mono', monospace !important;
    font-size: 0.72rem !important;
    letter-spacing: 0.1em !important;
    text-transform: uppercase !important;
    color: #64748b !important;
}
/* Slider */
input[type=range] {
    accent-color: #6366f1;
}
/* Generate button */
#gen-btn {
    background: #6366f1 !important;
    border: none !important;
    border-radius: 8px !important;
    font-family: 'Syne', sans-serif !important;
    font-weight: 700 !important;
    font-size: 0.95rem !important;
    letter-spacing: 0.05em !important;
    color: white !important;
    padding: 0.75rem 2rem !important;
    transition: background 0.2s, transform 0.1s !important;
    width: 100% !important;
}
#gen-btn:hover {
    background: #4f46e5 !important;
    transform: translateY(-1px);
}
#gen-btn:active { transform: translateY(0); }
/* Status box */
#status textarea {
    font-family: 'DM Mono', monospace !important;
    font-size: 0.8rem !important;
    background: #0a0a0f !important;
    color: #64748b !important;
    border: 1px solid #1e1e2e !important;
    border-radius: 8px !important;
}
/* Gallery */
.gallery-item img {
    border-radius: 8px !important;
    transition: transform 0.2s !important;
}
.gallery-item img:hover {
    transform: scale(1.03);
}
/* Footer */
.footer {
    text-align: center;
    padding: 2rem;
    font-family: 'DM Mono', monospace;
    font-size: 0.72rem;
    color: #64748b;
    letter-spacing: 0.08em;
}
"""
# Static HTML for the page header and footer (styled by the .hero* / .footer CSS rules).
_HERO_HTML = """
<div class="hero">
<div class="hero-title">StyleGAN v1</div>
<div class="hero-sub">Progressive Face Synthesis · 256 × 256</div>
</div>
"""
_FOOTER_HTML = '<div class="footer">Built with Keras · JAX · Gradio</div>'


def main():
    """Build the Gradio interface and serve it.

    Layout: a narrow control column (face-count slider, optional seed text,
    Generate button, read-only status box) beside a wide gallery. Clicking
    the button calls ``generate(count, seed)`` and routes its two return
    values to the gallery and the status box. The server listens on
    0.0.0.0 at the port given by the PORT environment variable (default 7860).
    """
    # Gallery options: a 4x4 grid holds the 16-image maximum; "contain" shows
    # each face whole instead of cropping it; "auto" height lets the grid grow
    # rather than clipping rows.
    gallery_opts = dict(
        label="Generated faces",
        columns=4,
        rows=4,
        object_fit="contain",
        height="auto",
        show_label=False,
    )

    with gr.Blocks(css=css, title="StyleGAN v1 — Face Synthesis") as app:
        gr.HTML(_HERO_HTML)
        with gr.Row():
            with gr.Column(scale=1):
                face_count = gr.Slider(
                    minimum=1, maximum=16, value=4, step=1, label="Number of faces"
                )
                seed_box = gr.Textbox(
                    value="", label="Seed (leave blank for random)", placeholder="e.g. 42"
                )
                run_button = gr.Button("Generate ↗", variant="primary", elem_id="gen-btn")
                status_box = gr.Textbox(
                    label="Status", interactive=False, elem_id="status", lines=2
                )
            with gr.Column(scale=3):
                face_gallery = gr.Gallery(**gallery_opts)
        gr.HTML(_FOOTER_HTML)
        # Event wiring must happen inside the Blocks context.
        run_button.click(
            fn=generate,
            inputs=[face_count, seed_box],
            outputs=[face_gallery, status_box],
        )

    port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=port)


if __name__ == "__main__":
    main()