# minigenie / app.py — HuggingFace Spaces entry point
# Author: BrutalCaesar
# Last commit 5b4f0d0: "fix: UI polish — filmstrip height, how-it-works visible, cleaner CSS"
"""
MiniGenie — Interactive Gradio demo for HuggingFace Spaces.
Users click actions to step through a CoinRun game frame-by-frame.
The dynamics model predicts the next frame given 4 context frames + action.
This is the HuggingFace Spaces entry point. It:
1. Downloads the model checkpoint from HuggingFace Hub on first launch
2. Loads seed episode frames bundled in the Space repo
3. Serves the Gradio UI on the Space's public URL
Runs on CPU (free tier). Each frame takes ~30–60 seconds to generate.
"""
import os
import random
import time
from typing import Dict, List, Tuple
import gradio as gr
import numpy as np
import torch
import yaml
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# HuggingFace Hub model repository — change this to your own repo
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "BrutalCaesar/minigenie-dynamics")
# Checkpoint filename inside the Hub repo (overridable via env var).
HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "step_0080000.pt")
# Local directory where the checkpoint is cached / downloaded to.
CKPT_DIR = "checkpoints/dynamics"
# Bundled CoinRun episode .npz files used to draw seed frames from.
DATA_DIR = "data/coinrun/episodes"
# Optional YAML config describing the model architecture (defaults used if absent).
CONFIG_PATH = "configs/dynamics.yaml"
# Prefer GPU when available; the HF free tier runs on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Euler ODE integration steps per generated frame.
NUM_STEPS = 15
# Classifier-free guidance scale used at sampling time.
CFG_SCALE = 2.0
# Number of past frames the dynamics model conditions on.
CONTEXT_LENGTH = 4
# Maximum number of frames rendered in the filmstrip history.
MAX_FILMSTRIP = 20
# CoinRun action mapping (Procgen 15-action space, only ~6 distinct in CoinRun)
# Keys are the exact button labels shown in the UI; values are Procgen action ids.
COINRUN_ACTIONS: Dict[str, int] = {
    "\u2b05\ufe0f Left": 1,
    "\u27a1\ufe0f Right": 7,
    "\u2b06\ufe0f Jump": 5,
    "\u2197\ufe0f Jump Right": 8,
    "\u2196\ufe0f Jump Left": 2,
    "\u23f8\ufe0f No-op": 4,
}
# Reverse lookup: Procgen action id -> UI button label.
ACTION_NAMES: Dict[int, str] = {v: k for k, v in COINRUN_ACTIONS.items()}
# ---------------------------------------------------------------------------
# Model + Data loading
# ---------------------------------------------------------------------------
def download_checkpoint() -> str:
    """Download the model checkpoint from HuggingFace Hub if not already cached.

    Resolution order:
      1. Reuse ``CKPT_DIR/HF_MODEL_FILENAME`` if it already exists on disk.
      2. Download ``HF_MODEL_FILENAME`` from ``HF_MODEL_REPO`` via the Hub.
      3. Fall back to the newest local ``step_*.pt`` found in ``CKPT_DIR``
         (e.g. a checkpoint bundled with the Space repo).

    Returns:
        Path to the checkpoint file.

    Raises:
        FileNotFoundError: If the Hub download fails and no local checkpoint exists.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    local_path = os.path.join(CKPT_DIR, HF_MODEL_FILENAME)
    # Already downloaded?
    if os.path.exists(local_path):
        print(f"Checkpoint already exists: {local_path}")
        return local_path
    # Try HuggingFace Hub download
    try:
        from huggingface_hub import hf_hub_download
        print(f"Downloading checkpoint from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}...")
        # NOTE: `local_dir_use_symlinks` is deprecated (ignored since
        # huggingface_hub 0.23). With `local_dir` set, the file is materialized
        # directly in CKPT_DIR, so the argument is omitted here.
        downloaded = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            local_dir=CKPT_DIR,
        )
        print(f"Checkpoint downloaded: {downloaded}")
        return downloaded
    except Exception as e:
        # Best-effort fallback: any checkpoint already present in CKPT_DIR
        # is better than crashing the Space at startup.
        import glob
        existing = sorted(glob.glob(os.path.join(CKPT_DIR, "step_*.pt")))
        if existing:
            print(f"Hub download failed ({e}), using existing: {existing[-1]}")
            return existing[-1]
        raise FileNotFoundError(
            f"Could not download checkpoint from {HF_MODEL_REPO} and no local "
            f"checkpoint found in {CKPT_DIR}. Error: {e}"
        )
def load_model(
    ckpt_dir: str,
    config_path: str = CONFIG_PATH,
    device: str = DEVICE,
) -> Tuple[torch.nn.Module, int, dict]:
    """Build the dynamics U-Net from config and restore the latest checkpoint.

    Args:
        ckpt_dir: Directory scanned by the checkpoint manager.
        config_path: Optional YAML config; built-in defaults are used if missing.
        device: Torch device string the model is moved to.

    Returns:
        Tuple of (model in eval mode, training step of the checkpoint, config dict).

    Raises:
        FileNotFoundError: If no checkpoint can be found in ``ckpt_dir``.
    """
    from src.models.unet import UNet
    from src.training.checkpoint import CheckpointManager

    # The YAML config is optional; every field below has a fallback default.
    config = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = yaml.safe_load(f)

    mcfg = config.get("model", {})
    model = UNet(
        in_channels=mcfg.get("in_channels", 15),
        out_channels=mcfg.get("out_channels", 3),
        channel_mult=mcfg.get("channel_mult", [64, 128, 256, 512]),
        cond_dim=mcfg.get("cond_dim", 512),
        num_actions=mcfg.get("num_actions", 15),
        num_groups=mcfg.get("num_groups", 32),
        cfg_dropout=0.0,  # No dropout at inference
    ).to(device)

    state = CheckpointManager(ckpt_dir).load_latest()
    if state is None:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_dir}")
    model.load_state_dict(state["model"])
    model.eval()
    return model, state["step"], config
def load_seed_frames(
    data_dir: str,
    num_seeds: int = 20,
    context_length: int = CONTEXT_LENGTH,
) -> List[np.ndarray]:
    """Load seed frame sequences from bundled episodes for the Reset button.

    Args:
        data_dir: Directory containing episode ``.npz`` files with a "frames" array.
        num_seeds: Maximum number of seed sequences to extract.
        context_length: Frames per seed sequence.

    Returns:
        List of ``[context_length, H, W, 3]`` uint8 arrays.

    Raises:
        FileNotFoundError: If no ``.npz`` files exist in ``data_dir``.
        ValueError: If no episode is long enough to yield a seed sequence.
    """
    from glob import glob
    # Skip macOS AppleDouble junk ("._*") that can sneak into the repo.
    npz_paths = sorted([
        p for p in glob(os.path.join(data_dir, "*.npz"))
        if not os.path.basename(p).startswith("._")
    ])
    if not npz_paths:
        raise FileNotFoundError(f"No .npz files found in {data_dir}")
    rng = random.Random(42)  # fixed seed -> identical seed pool every launch
    seeds = []
    sampled_paths = rng.sample(npz_paths, min(num_seeds, len(npz_paths)))
    for path in sampled_paths:
        # FIX: np.load on an .npz returns an NpzFile that keeps the archive
        # open; use it as a context manager so the file handle is released.
        with np.load(path) as data:
            frames = data["frames"]  # [T, H, W, 3] uint8
            T = len(frames)
            if T < context_length + 1:
                continue
            # Together with the check above this requires T >= context_length + 2,
            # i.e. at least one real frame beyond the extracted context window.
            max_start = T - context_length - 1
            if max_start <= 0:
                continue
            start = rng.randint(0, max_start)
            seed_frames = frames[start : start + context_length].copy()
        seeds.append(seed_frames)
        if len(seeds) >= num_seeds:
            break
    if not seeds:
        raise ValueError("Could not extract any seed frames from episodes")
    return seeds
# ---------------------------------------------------------------------------
# Frame generation
# ---------------------------------------------------------------------------
@torch.no_grad()
def predict_next_frame(
    model: torch.nn.Module,
    context_frames: List[np.ndarray],
    action: int,
    num_steps: int = NUM_STEPS,
    cfg_scale: float = CFG_SCALE,
    device: str = DEVICE,
) -> np.ndarray:
    """Generate the next frame given context frames and an action.

    Each HxWx3 context frame is normalized to [0, 1], converted to CHW, and
    channel-stacked into a single [1, 3*len(context), H, W] tensor that is fed
    to the flow-matching sampler together with the action id.

    Returns:
        Predicted frame as an HxWx3 uint8 array.
    """
    from src.training.train_dynamics import generate_next_frame

    def to_chw(frame: np.ndarray) -> torch.Tensor:
        # uint8 frames are scaled to [0, 1]; float frames are assumed to be
        # normalized already (matches how episodes are stored upstream).
        t = torch.from_numpy(frame.copy()).float()
        if frame.dtype == np.uint8:
            t = t / 255.0
        return t.permute(2, 0, 1)  # HWC -> CHW

    context = torch.cat([to_chw(f) for f in context_frames], dim=0)
    context = context.unsqueeze(0).to(device)  # [1, 12, 64, 64]
    act = torch.tensor([action], dtype=torch.long, device=device)

    pred = generate_next_frame(
        model, context, act,
        num_steps=num_steps,
        cfg_scale=cfg_scale,
    )  # [1, 3, 64, 64]

    out = pred[0].cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    return (out * 255).astype(np.uint8)
# ---------------------------------------------------------------------------
# UI Helpers
# ---------------------------------------------------------------------------
def _make_filmstrip(frames: List[np.ndarray]) -> np.ndarray:
"""Stitch frames into a horizontal strip with subtle borders."""
if not frames:
return np.zeros((68, 68, 3), dtype=np.uint8)
bordered = []
for i, f in enumerate(frames):
h, w = f.shape[:2]
# 2px border: white for real frames (first 4), subtle gray for generated
border_color = 255 if i < CONTEXT_LENGTH else 140
b = np.full((h + 4, w + 4, 3), border_color, dtype=np.uint8)
b[2:-2, 2:-2] = f
bordered.append(b)
strip = np.concatenate(bordered, axis=1)
return strip
def _upscale_frame(frame: np.ndarray, size: int = 320) -> np.ndarray:
    """Upscale a 64x64 frame to display size with nearest-neighbor sampling.

    Nearest-neighbor preserves hard pixel edges (pixel-art look) instead of
    blurring them the way bilinear resampling would.
    """
    from PIL import Image

    scaled = Image.fromarray(frame).resize((size, size), Image.NEAREST)
    return np.array(scaled)
# ---------------------------------------------------------------------------
# Gradio app
# ---------------------------------------------------------------------------
def create_demo(
    model: torch.nn.Module,
    seed_frames: List[np.ndarray],
    model_step: int,
) -> gr.Blocks:
    """Build the Gradio interface with loading indicators and visual polish.

    Args:
        model: Trained dynamics model, already on the target device in eval mode.
        seed_frames: Pool of [CONTEXT_LENGTH, H, W, 3] uint8 sequences for Reset.
        model_step: Training step of the loaded checkpoint (shown in the banner).

    Returns:
        A fully wired gr.Blocks app, not yet launched.
    """
    # Custom CSS: pixel-art rendering for frames, hidden footer, layout polish.
    custom_css = """
    /* --- Global --- */
    .gradio-container {
        max-width: 1100px !important;
        margin: 0 auto !important;
    }
    footer { display: none !important; }
    /* --- Main frame --- */
    .main-frame img {
        image-rendering: pixelated;
        border-radius: 8px;
    }
    /* --- Action panel --- */
    .action-panel {
        border-radius: 12px;
        padding: 16px;
    }
    /* --- Filmstrip: full height, no clipping --- */
    .filmstrip-wrap {
        margin-top: 4px;
    }
    .filmstrip-wrap img {
        image-rendering: pixelated;
        object-fit: contain;
    }
    /* --- Status text --- */
    .status-box textarea {
        font-weight: 600 !important;
        font-size: 0.95em !important;
        border: none !important;
        background: transparent !important;
    }
    /* --- How it works section --- */
    .how-section {
        margin-top: 16px;
        border-radius: 12px;
        padding: 18px 22px;
        line-height: 1.7;
        font-size: 0.93em;
    }
    .how-section ul {
        padding-left: 22px;
        margin-top: 6px;
    }
    .how-section li {
        margin-bottom: 4px;
    }
    /* --- Generating pulse --- */
    @keyframes pulse {
        0%, 100% { opacity: 1; }
        50% { opacity: 0.7; }
    }
    """

    def reset_state():
        """Pick a random seed and reset the frame buffer.

        Returns the 5-tuple (current frame image, filmstrip image,
        frame buffer, step count, status string) expected by `all_outputs`.
        """
        # Module-level RNG (not the fixed-seed one) -> a fresh seed per reset.
        seed = random.choice(seed_frames)
        frame_buffer = [seed[i] for i in range(CONTEXT_LENGTH)]
        current = _upscale_frame(frame_buffer[-1])
        filmstrip = _make_filmstrip(frame_buffer)
        status = "\U0001f7e2 Ready \u2014 pick an action to start generating!"
        return current, filmstrip, frame_buffer, 0, status

    def take_action(action_name, frame_buffer, step_count):
        """Generate next frame for the chosen action, with timing info.

        Returns the same 5-tuple shape as `reset_state` so both handlers can
        share the `all_outputs` list.
        """
        # Guard: if session state was lost (e.g. page reload), start fresh.
        if frame_buffer is None or len(frame_buffer) < CONTEXT_LENGTH:
            return reset_state()
        action_idx = COINRUN_ACTIONS[action_name]
        # Only the most recent CONTEXT_LENGTH frames condition the model.
        context = frame_buffer[-CONTEXT_LENGTH:]
        start_time = time.time()
        next_frame = predict_next_frame(
            model, context, action_idx,
            num_steps=NUM_STEPS,
            cfg_scale=CFG_SCALE,
            device=DEVICE,
        )
        elapsed = time.time() - start_time
        # NOTE(review): the buffer grows without bound within a session; only
        # the last MAX_FILMSTRIP frames are rendered below.
        frame_buffer = frame_buffer + [next_frame]
        step_count += 1
        current = _upscale_frame(next_frame)
        filmstrip = _make_filmstrip(frame_buffer[-MAX_FILMSTRIP:])
        # Warn once autoregressive rollout is long enough for errors to compound.
        quality_note = ""
        if step_count >= 5:
            quality_note = " \u26a0\ufe0f Quality may degrade \u2014 try resetting!"
        status = (
            f"\U0001f7e2 Step {step_count} \u2014 {action_name} "
            f"({elapsed:.1f}s){quality_note}"
        )
        return current, filmstrip, frame_buffer, step_count, status

    # --- Build UI ---
    with gr.Blocks(
        title="MiniGenie \U0001f9de \u2014 World Model Demo",
        theme=gr.themes.Soft(
            primary_hue="violet",
            secondary_hue="indigo",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
        ),
        css=custom_css,
    ) as demo:
        # --- Header ---
        gr.HTML("""
        <div style="text-align: center; padding: 20px 0 10px 0;">
            <h1 style="font-size: 2.5em; margin: 0;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
                \U0001f9de MiniGenie
            </h1>
            <p style="font-size: 1.15em; color: #64748b; margin-top: 6px;">
                Interactive World Model \u2014 Play CoinRun one frame at a time
            </p>
        </div>
        """)
        # Info banner
        device_label = "GPU 🚀" if DEVICE == "cuda" else "CPU 🐢"
        gr.HTML(f"""
        <div style="background: linear-gradient(135deg, #f0f4ff 0%, #f5f0ff 100%);
            border: 1px solid #e0d4f5; border-radius: 12px;
            padding: 14px 20px; margin-bottom: 12px;">
            <div style="display: flex; justify-content: center; gap: 32px;
                flex-wrap: wrap; font-size: 0.92em; color: #475569;">
                <span>🧠 <strong>42M-param U-Net</strong></span>
                <span>🌊 <strong>Flow Matching</strong> · 15 Euler steps</span>
                <span>🎯 <strong>PSNR 26.75 dB</strong> · SSIM 0.84</span>
                <span>📊 <strong>Trained {model_step:,} steps</strong></span>
                <span>💻 <strong>Running on {device_label}</strong></span>
            </div>
        </div>
        """)
        # --- State ---
        # Per-session state: list of HxWx3 frames, and how many were generated.
        frame_buffer_state = gr.State(value=None)
        step_count_state = gr.State(value=0)
        with gr.Row(equal_height=True):
            # === Left: Main frame display ===
            with gr.Column(scale=3):
                current_frame_display = gr.Image(
                    label="Current Frame",
                    height=384,
                    width=384,
                    show_label=False,
                    interactive=False,
                    show_download_button=False,
                    elem_classes=["main-frame"],
                )
                status_text = gr.Textbox(
                    label="",
                    interactive=False,
                    value="\u23f3 Loading \u2014 please wait...",
                    show_label=False,
                    container=False,
                    elem_classes=["status-box"],
                )
            # === Right: Action panel ===
            with gr.Column(scale=1, min_width=200, elem_classes=["action-panel"]):
                gr.HTML("""
                <div style="text-align: center; margin-bottom: 8px;">
                    <span style="font-size: 1.3em; font-weight: 700;">\U0001f3ae Actions</span>
                    <p style="font-size: 0.82em; color: #94a3b8; margin: 4px 0 0 0;">
                        Click to generate the next frame
                    </p>
                </div>
                """)
                # Button labels must match COINRUN_ACTIONS keys exactly — they
                # are used as dict lookups in take_action.
                action_buttons = {}
                btn_jump_left = gr.Button("\u2196\ufe0f Jump Left", variant="secondary", size="lg")
                action_buttons["\u2196\ufe0f Jump Left"] = btn_jump_left
                btn_jump = gr.Button("\u2b06\ufe0f Jump", variant="secondary", size="lg")
                action_buttons["\u2b06\ufe0f Jump"] = btn_jump
                btn_jump_right = gr.Button("\u2197\ufe0f Jump Right", variant="secondary", size="lg")
                action_buttons["\u2197\ufe0f Jump Right"] = btn_jump_right
                with gr.Row():
                    btn_left = gr.Button("\u2b05\ufe0f Left", variant="secondary", size="lg", scale=1)
                    action_buttons["\u2b05\ufe0f Left"] = btn_left
                    btn_right = gr.Button("\u27a1\ufe0f Right", variant="secondary", size="lg", scale=1)
                    action_buttons["\u27a1\ufe0f Right"] = btn_right
                btn_noop = gr.Button("\u23f8\ufe0f No-op", variant="secondary", size="lg")
                action_buttons["\u23f8\ufe0f No-op"] = btn_noop
                gr.HTML('<hr style="border-color: #e2e8f0; margin: 12px 0;">')
                reset_btn = gr.Button(
                    "\U0001f504 Reset / New Seed",
                    variant="primary",
                    size="lg",
                )
        # --- Filmstrip ---
        gr.HTML("""
        <div style="margin-top: 20px; margin-bottom: 4px;">
            <span style="font-size: 1.15em; font-weight: 700;">\U0001f4fd\ufe0f Frame History</span>
            <span style="font-size: 0.82em; color: #94a3b8; margin-left: 8px;">
                White border = real frames \u00b7 Gray border = model-generated \u00b7 Most recent on the right
            </span>
        </div>
        """)
        filmstrip_display = gr.Image(
            label="Filmstrip",
            height=120,
            show_label=False,
            interactive=False,
            show_download_button=False,
            elem_classes=["filmstrip-wrap"],
        )
        # --- How it works (open by default, always visible) ---
        gr.HTML(f"""
        <details class="how-section" open>
            <summary style="cursor: pointer; font-weight: 700; font-size: 1.05em;">
                \u2139\ufe0f How it works & tips
            </summary>
            <div style="margin-top: 10px;">
                <p>
                    <strong>Each click generates one frame</strong> using 15 steps of ODE integration
                    with classifier-free guidance (scale {CFG_SCALE}).
                    The model sees the <strong>last 4 frames</strong> as context.
                </p>
                <p>
                    The first 4 frames (white borders) are real CoinRun frames from the dataset.
                    All subsequent frames (gray borders) are entirely model-generated.
                </p>
                <ul>
                    <li>\U0001f4a1 <strong>Best experience:</strong> Try 3\u20135 actions from a reset, then reset again</li>
                    <li>\u23f1\ufe0f <strong>CPU inference:</strong> ~30\u201360 seconds per frame \u2014 be patient!</li>
                    <li>\u26a0\ufe0f Quality degrades after ~5 generated steps (autoregressive error accumulation)</li>
                    <li>\U0001f3b2 Action conditioning is still learning \u2014 different actions may look similar</li>
                </ul>
                <p style="margin-top: 8px; font-size: 0.85em; color: #94a3b8;">
                    Built entirely from scratch in PyTorch \u2014 no pretrained models or diffusion libraries.
                    Model checkpoint: step {model_step:,} | Game: CoinRun |
                    <a href="https://github.com/BrutalCaesar/minigenie" target="_blank"
                        style="color: #7c3aed;">GitHub</a>
                </p>
            </div>
        </details>
        """)
        # --- Wire events ---
        # Both handlers return the same 5-tuple, so every event shares one
        # output list.
        all_outputs = [
            current_frame_display,
            filmstrip_display,
            frame_buffer_state,
            step_count_state,
            status_text,
        ]
        reset_btn.click(
            fn=reset_state,
            inputs=[],
            outputs=all_outputs,
        )
        for name, btn in action_buttons.items():
            # Two-stage click: an instant status update first, then the slow
            # generation step chained via .then(). The per-button gr.State
            # carries the action label as a constant input, sidestepping the
            # late-binding-closure pitfall of capturing `name` in a lambda.
            btn.click(
                fn=lambda: "\U0001f7e3 Generating next frame... (this takes ~30\u201360s on CPU)",
                inputs=[],
                outputs=[status_text],
            ).then(
                fn=take_action,
                inputs=[
                    gr.State(value=name),
                    frame_buffer_state,
                    step_count_state,
                ],
                outputs=all_outputs,
            )
        # Seed the UI with an initial frame as soon as the page loads.
        demo.load(fn=reset_state, inputs=[], outputs=all_outputs)
    return demo
# ---------------------------------------------------------------------------
# App startup
# ---------------------------------------------------------------------------
def main():
    """Entry point for HuggingFace Spaces.

    Startup sequence: fetch the checkpoint, restore the model, sample seed
    frames from the bundled episodes, then serve the Gradio app on port 7860.
    """
    divider = "=" * 60
    print(divider)
    print("\U0001f9de MiniGenie \u2014 Starting HuggingFace Spaces demo")
    print(f" Device: {DEVICE}")
    print(divider)
    # Step 1: Download checkpoint from HuggingFace Hub
    print("\n\U0001f4e5 Step 1/3: Downloading model checkpoint...")
    download_checkpoint()
    # Step 2: Load model
    print("\n\U0001f9e0 Step 2/3: Loading dynamics model...")
    net, ckpt_step, cfg = load_model(CKPT_DIR, CONFIG_PATH, DEVICE)
    n_params = sum(p.numel() for p in net.parameters())
    print(f" Model loaded: step {ckpt_step:,}, {n_params:,} params")
    # Step 3: Load seed frames
    print("\n\U0001f3ae Step 3/3: Loading seed frames...")
    # Honor the config's context length if present; fall back to the default.
    ctx_len = cfg.get("model", {}).get("context_frames", CONTEXT_LENGTH)
    seed_pool = load_seed_frames(DATA_DIR, num_seeds=20, context_length=ctx_len)
    print(f" Loaded {len(seed_pool)} seed frame sequences")
    # Build and launch
    print("\n\U0001f680 Building Gradio interface...")
    app = create_demo(net, seed_pool, ckpt_step)
    print("\U0001f310 Launching...")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # HF Spaces provides the public URL
        show_error=True,
    )


if __name__ == "__main__":
    main()