StyleGAN / app.py
masterofaudio2077's picture
Upload 3 files
7048b74 verified
Raw
History Blame Contribute Delete
5.44 kB
import os
import gradio as gr
os.environ["KERAS_BACKEND"] = "jax"
from inference import load_weights, generate_images, WEIGHTS_H5, CHECKPOINT_PKL
# Set to True after the first successful load_weights() call so that later
# requests skip reloading; it stays False after a failed attempt so the next
# request tries again.
_model_loaded = False


def load_model_once():
    """Make sure the generator weights are loaded, loading them at most once.

    Returns:
        True if the weights were loaded by an earlier call; otherwise the
        value returned by ``load_weights(resolution=256)`` (truthy on success,
        falsy when loading failed, in which case a later call retries).
    """
    global _model_loaded
    if not _model_loaded:
        loaded = load_weights(resolution=256)
        if loaded:
            _model_loaded = True
        return loaded
    return True
def generate(n_images, seed):
    """Generate a batch of faces for the Gradio UI.

    Args:
        n_images: Requested number of faces (the slider value). Converted with
            ``int()`` and clamped to the range 1..16.
        seed: Seed from the textbox. ``None`` or a blank / whitespace-only
            value means "no seed" and is passed on as ``None``; anything else
            must be convertible with ``int()``.

    Returns:
        ``(images, status)``: the images returned by ``generate_images`` and a
        status message. Every failure path (invalid input, missing weights,
        empty result, exception during generation) returns an empty list plus
        an explanatory message instead of raising, so the UI always gets a
        status string.
    """
    # Validate user input before touching the model: previously an
    # unparsable value raised out of the handler with no status message.
    try:
        n_images = max(1, min(int(n_images), 16))
    except (TypeError, ValueError, OverflowError):
        return [], f"❌ Invalid number of faces: {n_images!r}"

    if seed is None or not str(seed).strip():
        seed = None
    else:
        try:
            seed = int(seed)
        except (TypeError, ValueError, OverflowError):
            return [], f"❌ Seed must be a whole number (got {str(seed).strip()!r})."

    if not load_model_once():
        return [], f"❌ No weights found. Put checkpoint.pkl in:\n{os.path.dirname(WEIGHTS_H5)}"

    try:
        images = generate_images(n_images, resolution=256, seed=seed)
        # len() rather than truthiness, in case generate_images returns an
        # array-like whose truth value is ambiguous.
        if images is None or len(images) == 0:
            return [], "No images generated."
        return images, f"✅ Generated {len(images)} face(s) at 256×256."
    except Exception as e:
        # Top-level UI boundary: report the error (with traceback) in the
        # status box instead of failing the request.
        import traceback
        return [], f"Error: {e}\n{traceback.format_exc()}"
# Custom stylesheet passed to gr.Blocks(css=...) in main(). It sets a dark
# page background, loads the "Syne" and "DM Mono" fonts via an @import from
# fonts.googleapis.com (fetched by the browser, so it needs network access),
# and styles the hero header, panels, labels, slider, the #gen-btn button,
# the #status textbox, gallery thumbnails and the footer.
# NOTE(review): the `.gr-block` / `.gr-box` selectors look like class names
# from older Gradio releases and may match nothing in the installed version
# — confirm before relying on the panel styling.
css = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;700;800&family=DM+Mono:wght@300;400&display=swap');
* { font-family: 'Syne', sans-serif; box-sizing: border-box; }
/* Dark background */
.gradio-container {
background: #0a0a0f !important;
min-height: 100vh;
}
/* Hero header */
.hero {
text-align: center;
padding: 2.5rem 1rem 1rem;
}
.hero-title {
font-size: 3rem;
font-weight: 800;
letter-spacing: -0.03em;
background: linear-gradient(135deg, #e2e8f0 0%, #94a3b8 50%, #475569 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin: 0;
line-height: 1.1;
}
.hero-sub {
font-family: 'DM Mono', monospace;
font-size: 0.78rem;
color: #475569;
letter-spacing: 0.15em;
text-transform: uppercase;
margin-top: 0.6rem;
}
/* Panel cards */
.gr-block, .gr-box, .panel {
background: #111118 !important;
border: 1px solid #1e1e2e !important;
border-radius: 12px !important;
}
/* Labels */
label span {
font-family: 'DM Mono', monospace !important;
font-size: 0.72rem !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
color: #64748b !important;
}
/* Slider */
input[type=range] {
accent-color: #6366f1;
}
/* Generate button */
#gen-btn {
background: #6366f1 !important;
border: none !important;
border-radius: 8px !important;
font-family: 'Syne', sans-serif !important;
font-weight: 700 !important;
font-size: 0.95rem !important;
letter-spacing: 0.05em !important;
color: white !important;
padding: 0.75rem 2rem !important;
transition: background 0.2s, transform 0.1s !important;
width: 100% !important;
}
#gen-btn:hover {
background: #4f46e5 !important;
transform: translateY(-1px);
}
#gen-btn:active { transform: translateY(0); }
/* Status box */
#status textarea {
font-family: 'DM Mono', monospace !important;
font-size: 0.8rem !important;
background: #0a0a0f !important;
color: #64748b !important;
border: 1px solid #1e1e2e !important;
border-radius: 8px !important;
}
/* Gallery */
.gallery-item img {
border-radius: 8px !important;
transition: transform 0.2s !important;
}
.gallery-item img:hover {
transform: scale(1.03);
}
/* Footer */
.footer {
text-align: center;
padding: 2rem;
font-family: 'DM Mono', monospace;
font-size: 0.72rem;
color: #1e293b;
letter-spacing: 0.08em;
}
"""
def main():
    """Build the StyleGAN demo UI and serve it.

    Layout: a hero header, then a row with a narrow control column (face-count
    slider, optional seed textbox, Generate button, read-only status box) next
    to a wider gallery, followed by a footer. Clicking the button calls
    ``generate(count, seed)`` and fills the gallery and status box. The server
    listens on 0.0.0.0 at the port given by the ``PORT`` environment variable
    (default 7860).
    """
    hero_markup = (
        "\n"
        '<div class="hero">\n'
        '<div class="hero-title">StyleGAN v1</div>\n'
        '<div class="hero-sub">Progressive Face Synthesis · 256 × 256</div>\n'
        "</div>\n"
    )
    footer_markup = '<div class="footer">Built with Keras · JAX · Gradio</div>'

    with gr.Blocks(css=css, title="StyleGAN v1 — Face Synthesis") as demo:
        gr.HTML(hero_markup)

        with gr.Row():
            # Controls: narrow left column.
            with gr.Column(scale=1):
                face_count = gr.Slider(
                    minimum=1,
                    maximum=16,
                    value=4,
                    step=1,
                    label="Number of faces",
                )
                seed_input = gr.Textbox(
                    value="",
                    label="Seed (leave blank for random)",
                    placeholder="e.g. 42",
                )
                run_button = gr.Button(
                    "Generate ↗", variant="primary", elem_id="gen-btn"
                )
                status_box = gr.Textbox(
                    label="Status",
                    interactive=False,
                    elem_id="status",
                    lines=2,
                )

            # Output: wide right column.
            with gr.Column(scale=3):
                # A 4x4 grid holds the slider's maximum of 16 faces;
                # "contain" shows each whole image rather than a cropped zoom,
                # and an automatic height keeps rows from being cut off.
                face_gallery = gr.Gallery(
                    label="Generated faces",
                    columns=4,
                    rows=4,
                    object_fit="contain",
                    height="auto",
                    show_label=False,
                )

        gr.HTML(footer_markup)

        run_button.click(
            fn=generate,
            inputs=[face_count, seed_input],
            outputs=[face_gallery, status_box],
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )
# Entry point: build the UI and start the server only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()