# NOTE(review): removed scraped HuggingFace Spaces page chrome ("Spaces:" /
# "Runtime error" banners) — extraction residue, not part of the source file.
| """ | |
| MiniGenie — Interactive Gradio demo for HuggingFace Spaces. | |
| Users click actions to step through a CoinRun game frame-by-frame. | |
| The dynamics model predicts the next frame given 4 context frames + action. | |
| This is the HuggingFace Spaces entry point. It: | |
| 1. Downloads the model checkpoint from HuggingFace Hub on first launch | |
| 2. Loads seed episode frames bundled in the Space repo | |
| 3. Serves the Gradio UI on the Space's public URL | |
| Runs on CPU (free tier). Each frame takes ~30–60 seconds to generate. | |
| """ | |
import os
import random
import time
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import yaml
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# HuggingFace Hub model repository — change this to your own repo
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "BrutalCaesar/minigenie-dynamics")
# Checkpoint filename inside the Hub repo (and in the local cache dir).
HF_MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "step_0080000.pt")
# Local directory where the downloaded checkpoint is cached.
CKPT_DIR = "checkpoints/dynamics"
# Directory of bundled .npz episode files used as seed frames.
DATA_DIR = "data/coinrun/episodes"
# YAML config with model hyperparameters (optional; code falls back to defaults).
CONFIG_PATH = "configs/dynamics.yaml"
# Prefer GPU when available; the HF free tier runs on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Euler ODE integration steps per generated frame.
NUM_STEPS = 15
# Classifier-free guidance scale used at inference.
CFG_SCALE = 2.0
# Number of past frames the dynamics model conditions on.
CONTEXT_LENGTH = 4
# Maximum number of frames shown in the filmstrip history.
MAX_FILMSTRIP = 20
# CoinRun action mapping (Procgen 15-action space, only ~6 distinct in CoinRun)
COINRUN_ACTIONS: Dict[str, int] = {
    "\u2b05\ufe0f Left": 1,
    "\u27a1\ufe0f Right": 7,
    "\u2b06\ufe0f Jump": 5,
    "\u2197\ufe0f Jump Right": 8,
    "\u2196\ufe0f Jump Left": 2,
    "\u23f8\ufe0f No-op": 4,
}
# Reverse lookup: action index -> button label.
ACTION_NAMES: Dict[int, str] = {v: k for k, v in COINRUN_ACTIONS.items()}
| # --------------------------------------------------------------------------- | |
| # Model + Data loading | |
| # --------------------------------------------------------------------------- | |
def download_checkpoint() -> str:
    """Download the model checkpoint from HuggingFace Hub if not already cached.

    Resolution order:
      1. Use CKPT_DIR/HF_MODEL_FILENAME if it already exists (previous launch).
      2. Otherwise download HF_MODEL_FILENAME from HF_MODEL_REPO into CKPT_DIR.
      3. If the download fails for any reason (offline, missing repo, hub not
         installed), fall back to the newest local ``step_*.pt`` checkpoint.

    Returns:
        Path to the checkpoint file.

    Raises:
        FileNotFoundError: If the download fails and no local checkpoint exists.
    """
    os.makedirs(CKPT_DIR, exist_ok=True)
    local_path = os.path.join(CKPT_DIR, HF_MODEL_FILENAME)

    # Fast path: already downloaded on a previous launch.
    if os.path.exists(local_path):
        print(f"Checkpoint already exists: {local_path}")
        return local_path

    try:
        from huggingface_hub import hf_hub_download

        print(f"Downloading checkpoint from {HF_MODEL_REPO}/{HF_MODEL_FILENAME}...")
        # NOTE: `local_dir_use_symlinks` is deprecated and ignored by modern
        # huggingface_hub (>=0.23); with `local_dir` set the file is always
        # materialized inside CKPT_DIR, so the argument is dropped here.
        downloaded = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME,
            local_dir=CKPT_DIR,
        )
        print(f"Checkpoint downloaded: {downloaded}")
        return downloaded
    except Exception as e:  # broad on purpose: any failure falls back to local files
        import glob

        # Zero-padded step names sort lexicographically, so the newest is last.
        existing = sorted(glob.glob(os.path.join(CKPT_DIR, "step_*.pt")))
        if existing:
            print(f"Hub download failed ({e}), using existing: {existing[-1]}")
            return existing[-1]
        raise FileNotFoundError(
            f"Could not download checkpoint from {HF_MODEL_REPO} and no local "
            f"checkpoint found in {CKPT_DIR}. Error: {e}"
        ) from e
def load_model(
    ckpt_dir: str,
    config_path: str = CONFIG_PATH,
    device: str = DEVICE,
) -> Tuple[torch.nn.Module, int, dict]:
    """Load the trained dynamics model from checkpoint.

    Reads the YAML config when present (falling back to built-in defaults),
    builds the U-Net, and restores the latest checkpoint from ``ckpt_dir``.

    Returns:
        Tuple of (model in eval mode, training step, parsed config dict).

    Raises:
        FileNotFoundError: If ``ckpt_dir`` contains no checkpoint.
    """
    from src.models.unet import UNet
    from src.training.checkpoint import CheckpointManager

    # Optional config file — missing file just means "use defaults".
    config = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = yaml.safe_load(f)

    model_cfg = config.get("model", {})
    model = UNet(
        in_channels=model_cfg.get("in_channels", 15),
        out_channels=model_cfg.get("out_channels", 3),
        channel_mult=model_cfg.get("channel_mult", [64, 128, 256, 512]),
        cond_dim=model_cfg.get("cond_dim", 512),
        num_actions=model_cfg.get("num_actions", 15),
        num_groups=model_cfg.get("num_groups", 32),
        cfg_dropout=0.0,  # No dropout at inference
    ).to(device)

    # Restore the most recent checkpoint and switch to inference mode.
    state = CheckpointManager(ckpt_dir).load_latest()
    if state is None:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_dir}")
    model.load_state_dict(state["model"])
    model.eval()
    return model, state["step"], config
def load_seed_frames(
    data_dir: str,
    num_seeds: int = 20,
    context_length: Optional[int] = None,
) -> List[np.ndarray]:
    """Load seed frame sequences from bundled episodes for the Reset button.

    Args:
        data_dir: Directory of episode ``.npz`` files, each with a ``frames``
            array of shape [T, H, W, 3] (uint8).
        num_seeds: Maximum number of seed sequences to extract.
        context_length: Frames per seed sequence; defaults to the module-level
            CONTEXT_LENGTH when None.

    Returns:
        List of [context_length, H, W, 3] uint8 arrays.

    Raises:
        FileNotFoundError: If no .npz files are found in ``data_dir``.
        ValueError: If no episode is long enough to yield a seed.
    """
    from glob import glob

    if context_length is None:
        context_length = CONTEXT_LENGTH

    # Skip macOS resource-fork artifacts ("._*") that np.load cannot parse.
    npz_paths = sorted(
        p for p in glob(os.path.join(data_dir, "*.npz"))
        if not os.path.basename(p).startswith("._")
    )
    if not npz_paths:
        raise FileNotFoundError(f"No .npz files found in {data_dir}")

    rng = random.Random(42)  # fixed seed: same Space restart -> same seed pool
    seeds: List[np.ndarray] = []
    for path in rng.sample(npz_paths, min(num_seeds, len(npz_paths))):
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(path) as data:
            frames = data["frames"]  # [T, H, W, 3] uint8
            T = len(frames)
            # Need context_length frames plus one future frame to predict.
            if T < context_length + 1:
                continue
            # BUGFIX: the previous `max_start <= 0: continue` also skipped
            # episodes with exactly context_length + 1 frames, which are
            # valid (rng.randint(0, 0) == 0 is a legal start).
            max_start = T - context_length - 1
            start = rng.randint(0, max_start)
            seeds.append(frames[start : start + context_length].copy())
        if len(seeds) >= num_seeds:
            break

    if not seeds:
        raise ValueError("Could not extract any seed frames from episodes")
    return seeds
| # --------------------------------------------------------------------------- | |
| # Frame generation | |
| # --------------------------------------------------------------------------- | |
def predict_next_frame(
    model: torch.nn.Module,
    context_frames: List[np.ndarray],
    action: int,
    num_steps: int = NUM_STEPS,
    cfg_scale: float = CFG_SCALE,
    device: str = DEVICE,
) -> np.ndarray:
    """Generate the next frame given context frames and an action.

    Converts the HWC uint8/float context frames into a single channel-stacked
    tensor, runs the flow-matching sampler, and returns an HWC uint8 frame.
    """
    from src.training.train_dynamics import generate_next_frame

    def to_chw(frame: np.ndarray) -> torch.Tensor:
        # uint8 frames are rescaled to [0, 1]; float frames pass through.
        tensor = torch.from_numpy(frame.copy()).float()
        if frame.dtype == np.uint8:
            tensor = tensor / 255.0
        return tensor.permute(2, 0, 1)  # [3, h, w]

    # Stack context along the channel axis: [1, 3 * len(context), 64, 64].
    context = torch.cat([to_chw(f) for f in context_frames], dim=0)
    context = context.unsqueeze(0).to(device)
    action_tensor = torch.tensor([action], dtype=torch.long, device=device)

    pred = generate_next_frame(
        model, context, action_tensor,
        num_steps=num_steps,
        cfg_scale=cfg_scale,
    )  # [1, 3, 64, 64]

    # Back to displayable HWC uint8.
    out = pred[0].cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    return (out * 255).astype(np.uint8)
| # --------------------------------------------------------------------------- | |
| # UI Helpers | |
| # --------------------------------------------------------------------------- | |
| def _make_filmstrip(frames: List[np.ndarray]) -> np.ndarray: | |
| """Stitch frames into a horizontal strip with subtle borders.""" | |
| if not frames: | |
| return np.zeros((68, 68, 3), dtype=np.uint8) | |
| bordered = [] | |
| for i, f in enumerate(frames): | |
| h, w = f.shape[:2] | |
| # 2px border: white for real frames (first 4), subtle gray for generated | |
| border_color = 255 if i < CONTEXT_LENGTH else 140 | |
| b = np.full((h + 4, w + 4, 3), border_color, dtype=np.uint8) | |
| b[2:-2, 2:-2] = f | |
| bordered.append(b) | |
| strip = np.concatenate(bordered, axis=1) | |
| return strip | |
| def _upscale_frame(frame: np.ndarray, size: int = 320) -> np.ndarray: | |
| """Upscale a 64x64 frame to display size using nearest neighbor (pixel-art style).""" | |
| from PIL import Image | |
| img = Image.fromarray(frame) | |
| img = img.resize((size, size), Image.NEAREST) | |
| return np.array(img) | |
| # --------------------------------------------------------------------------- | |
| # Gradio app | |
| # --------------------------------------------------------------------------- | |
def create_demo(
    model: torch.nn.Module,
    seed_frames: List[np.ndarray],
    model_step: int,
) -> gr.Blocks:
    """Build the Gradio interface with loading indicators and visual polish.

    Args:
        model: Trained dynamics model (already on DEVICE, in eval mode).
        seed_frames: Pool of seed sequences; Reset samples one at random.
        model_step: Training step of the loaded checkpoint, shown in the UI.

    Returns:
        The assembled (not yet launched) gr.Blocks app.
    """
    # Custom CSS: pixel-art rendering, layout polish, hidden Gradio footer.
    custom_css = """
    /* --- Global --- */
    .gradio-container {
        max-width: 1100px !important;
        margin: 0 auto !important;
    }
    footer { display: none !important; }
    /* --- Main frame --- */
    .main-frame img {
        image-rendering: pixelated;
        border-radius: 8px;
    }
    /* --- Action panel --- */
    .action-panel {
        border-radius: 12px;
        padding: 16px;
    }
    /* --- Filmstrip: full height, no clipping --- */
    .filmstrip-wrap {
        margin-top: 4px;
    }
    .filmstrip-wrap img {
        image-rendering: pixelated;
        object-fit: contain;
    }
    /* --- Status text --- */
    .status-box textarea {
        font-weight: 600 !important;
        font-size: 0.95em !important;
        border: none !important;
        background: transparent !important;
    }
    /* --- How it works section --- */
    .how-section {
        margin-top: 16px;
        border-radius: 12px;
        padding: 18px 22px;
        line-height: 1.7;
        font-size: 0.93em;
    }
    .how-section ul {
        padding-left: 22px;
        margin-top: 6px;
    }
    .how-section li {
        margin-bottom: 4px;
    }
    /* --- Generating pulse --- */
    @keyframes pulse {
        0%, 100% { opacity: 1; }
        50% { opacity: 0.7; }
    }
    """

    def reset_state():
        """Pick a random seed and reset the frame buffer.

        Returns a 5-tuple matching ``all_outputs`` below:
        (current frame image, filmstrip image, frame buffer, step count, status).
        """
        seed = random.choice(seed_frames)
        frame_buffer = [seed[i] for i in range(CONTEXT_LENGTH)]
        current = _upscale_frame(frame_buffer[-1])
        filmstrip = _make_filmstrip(frame_buffer)
        status = "\U0001f7e2 Ready \u2014 pick an action to start generating!"
        return current, filmstrip, frame_buffer, 0, status

    def take_action(action_name: str, frame_buffer, step_count: int):
        """Generate next frame for the chosen action, with timing info."""
        # Defensive: if session state was lost/empty, re-seed instead of failing.
        if frame_buffer is None or len(frame_buffer) < CONTEXT_LENGTH:
            return reset_state()
        action_idx = COINRUN_ACTIONS[action_name]
        # The model conditions only on the most recent CONTEXT_LENGTH frames.
        context = frame_buffer[-CONTEXT_LENGTH:]
        start_time = time.time()
        next_frame = predict_next_frame(
            model, context, action_idx,
            num_steps=NUM_STEPS,
            cfg_scale=CFG_SCALE,
            device=DEVICE,
        )
        elapsed = time.time() - start_time
        # Build a new list rather than mutating Gradio state in place.
        frame_buffer = frame_buffer + [next_frame]
        step_count += 1
        current = _upscale_frame(next_frame)
        # Filmstrip shows only the most recent MAX_FILMSTRIP frames.
        filmstrip = _make_filmstrip(frame_buffer[-MAX_FILMSTRIP:])
        # Warn once autoregressive drift typically becomes visible.
        quality_note = ""
        if step_count >= 5:
            quality_note = " \u26a0\ufe0f Quality may degrade \u2014 try resetting!"
        status = (
            f"\U0001f7e2 Step {step_count} \u2014 {action_name} "
            f"({elapsed:.1f}s){quality_note}"
        )
        return current, filmstrip, frame_buffer, step_count, status

    # --- Build UI ---
    with gr.Blocks(
        title="MiniGenie \U0001f9de \u2014 World Model Demo",
        theme=gr.themes.Soft(
            primary_hue="violet",
            secondary_hue="indigo",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
        ),
        css=custom_css,
    ) as demo:
        # --- Header ---
        gr.HTML("""
        <div style="text-align: center; padding: 20px 0 10px 0;">
            <h1 style="font-size: 2.5em; margin: 0;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
                \U0001f9de MiniGenie
            </h1>
            <p style="font-size: 1.15em; color: #64748b; margin-top: 6px;">
                Interactive World Model \u2014 Play CoinRun one frame at a time
            </p>
        </div>
        """)
        # Info banner with model/runtime stats.
        device_label = "GPU 🚀" if DEVICE == "cuda" else "CPU 🐢"
        gr.HTML(f"""
        <div style="background: linear-gradient(135deg, #f0f4ff 0%, #f5f0ff 100%);
            border: 1px solid #e0d4f5; border-radius: 12px;
            padding: 14px 20px; margin-bottom: 12px;">
            <div style="display: flex; justify-content: center; gap: 32px;
                flex-wrap: wrap; font-size: 0.92em; color: #475569;">
                <span>🧠 <strong>42M-param U-Net</strong></span>
                <span>🌊 <strong>Flow Matching</strong> · 15 Euler steps</span>
                <span>🎯 <strong>PSNR 26.75 dB</strong> · SSIM 0.84</span>
                <span>📊 <strong>Trained {model_step:,} steps</strong></span>
                <span>💻 <strong>Running on {device_label}</strong></span>
            </div>
        </div>
        """)
        # --- State (per browser session): frame history + generated-step count ---
        frame_buffer_state = gr.State(value=None)
        step_count_state = gr.State(value=0)
        with gr.Row(equal_height=True):
            # === Left: Main frame display ===
            with gr.Column(scale=3):
                current_frame_display = gr.Image(
                    label="Current Frame",
                    height=384,
                    width=384,
                    show_label=False,
                    interactive=False,
                    show_download_button=False,
                    elem_classes=["main-frame"],
                )
                status_text = gr.Textbox(
                    label="",
                    interactive=False,
                    value="\u23f3 Loading \u2014 please wait...",
                    show_label=False,
                    container=False,
                    elem_classes=["status-box"],
                )
            # === Right: Action panel ===
            with gr.Column(scale=1, min_width=200, elem_classes=["action-panel"]):
                gr.HTML("""
                <div style="text-align: center; margin-bottom: 8px;">
                    <span style="font-size: 1.3em; font-weight: 700;">\U0001f3ae Actions</span>
                    <p style="font-size: 0.82em; color: #94a3b8; margin: 4px 0 0 0;">
                        Click to generate the next frame
                    </p>
                </div>
                """)
                # Button label -> Button component; labels double as the
                # COINRUN_ACTIONS lookup keys in take_action.
                action_buttons = {}
                btn_jump_left = gr.Button("\u2196\ufe0f Jump Left", variant="secondary", size="lg")
                action_buttons["\u2196\ufe0f Jump Left"] = btn_jump_left
                btn_jump = gr.Button("\u2b06\ufe0f Jump", variant="secondary", size="lg")
                action_buttons["\u2b06\ufe0f Jump"] = btn_jump
                btn_jump_right = gr.Button("\u2197\ufe0f Jump Right", variant="secondary", size="lg")
                action_buttons["\u2197\ufe0f Jump Right"] = btn_jump_right
                with gr.Row():
                    btn_left = gr.Button("\u2b05\ufe0f Left", variant="secondary", size="lg", scale=1)
                    action_buttons["\u2b05\ufe0f Left"] = btn_left
                    btn_right = gr.Button("\u27a1\ufe0f Right", variant="secondary", size="lg", scale=1)
                    action_buttons["\u27a1\ufe0f Right"] = btn_right
                btn_noop = gr.Button("\u23f8\ufe0f No-op", variant="secondary", size="lg")
                action_buttons["\u23f8\ufe0f No-op"] = btn_noop
                gr.HTML('<hr style="border-color: #e2e8f0; margin: 12px 0;">')
                reset_btn = gr.Button(
                    "\U0001f504 Reset / New Seed",
                    variant="primary",
                    size="lg",
                )
        # --- Filmstrip ---
        gr.HTML("""
        <div style="margin-top: 20px; margin-bottom: 4px;">
            <span style="font-size: 1.15em; font-weight: 700;">\U0001f4fd\ufe0f Frame History</span>
            <span style="font-size: 0.82em; color: #94a3b8; margin-left: 8px;">
                White border = real frames \u00b7 Gray border = model-generated \u00b7 Most recent on the right
            </span>
        </div>
        """)
        filmstrip_display = gr.Image(
            label="Filmstrip",
            height=120,
            show_label=False,
            interactive=False,
            show_download_button=False,
            elem_classes=["filmstrip-wrap"],
        )
        # --- How it works (open by default, always visible) ---
        gr.HTML(f"""
        <details class="how-section" open>
            <summary style="cursor: pointer; font-weight: 700; font-size: 1.05em;">
                \u2139\ufe0f How it works & tips
            </summary>
            <div style="margin-top: 10px;">
                <p>
                    <strong>Each click generates one frame</strong> using 15 steps of ODE integration
                    with classifier-free guidance (scale {CFG_SCALE}).
                    The model sees the <strong>last 4 frames</strong> as context.
                </p>
                <p>
                    The first 4 frames (white borders) are real CoinRun frames from the dataset.
                    All subsequent frames (gray borders) are entirely model-generated.
                </p>
                <ul>
                    <li>\U0001f4a1 <strong>Best experience:</strong> Try 3\u20135 actions from a reset, then reset again</li>
                    <li>\u23f1\ufe0f <strong>CPU inference:</strong> ~30\u201360 seconds per frame \u2014 be patient!</li>
                    <li>\u26a0\ufe0f Quality degrades after ~5 generated steps (autoregressive error accumulation)</li>
                    <li>\U0001f3b2 Action conditioning is still learning \u2014 different actions may look similar</li>
                </ul>
                <p style="margin-top: 8px; font-size: 0.85em; color: #94a3b8;">
                    Built entirely from scratch in PyTorch \u2014 no pretrained models or diffusion libraries.
                    Model checkpoint: step {model_step:,} | Game: CoinRun |
                    <a href="https://github.com/BrutalCaesar/minigenie" target="_blank"
                        style="color: #7c3aed;">GitHub</a>
                </p>
            </div>
        </details>
        """)
        # --- Wire events ---
        # Every handler returns this full 5-tuple so the UI stays consistent.
        all_outputs = [
            current_frame_display,
            filmstrip_display,
            frame_buffer_state,
            step_count_state,
            status_text,
        ]
        reset_btn.click(
            fn=reset_state,
            inputs=[],
            outputs=all_outputs,
        )
        # Two-stage click: first flip the status to "Generating..." immediately,
        # then run the slow model step. The per-button action name is captured
        # via a gr.State created per iteration (avoids late-binding issues).
        for name, btn in action_buttons.items():
            btn.click(
                fn=lambda: "\U0001f7e3 Generating next frame... (this takes ~30\u201360s on CPU)",
                inputs=[],
                outputs=[status_text],
            ).then(
                fn=take_action,
                inputs=[
                    gr.State(value=name),
                    frame_buffer_state,
                    step_count_state,
                ],
                outputs=all_outputs,
            )
        # Seed the UI as soon as the page loads.
        demo.load(fn=reset_state, inputs=[], outputs=all_outputs)
    return demo
| # --------------------------------------------------------------------------- | |
| # App startup | |
| # --------------------------------------------------------------------------- | |
def main():
    """Entry point for HuggingFace Spaces.

    Startup sequence: fetch the checkpoint, restore the model, gather seed
    frames, then build and launch the Gradio app on port 7860.
    """
    banner = "=" * 60
    print(banner)
    print("\U0001f9de MiniGenie \u2014 Starting HuggingFace Spaces demo")
    print(f" Device: {DEVICE}")
    print(banner)

    # Step 1: make sure a checkpoint exists locally (downloads on first launch).
    print("\n\U0001f4e5 Step 1/3: Downloading model checkpoint...")
    download_checkpoint()

    # Step 2: build the network and restore the trained weights.
    print("\n\U0001f9e0 Step 2/3: Loading dynamics model...")
    dynamics_model, ckpt_step, config = load_model(CKPT_DIR, CONFIG_PATH, DEVICE)
    num_params = sum(p.numel() for p in dynamics_model.parameters())
    print(f" Model loaded: step {ckpt_step:,}, {num_params:,} params")

    # Step 3: collect real frames used to (re)seed interactive sessions.
    print("\n\U0001f3ae Step 3/3: Loading seed frames...")
    ctx_len = config.get("model", {}).get("context_frames", CONTEXT_LENGTH)
    seeds = load_seed_frames(DATA_DIR, num_seeds=20, context_length=ctx_len)
    print(f" Loaded {len(seeds)} seed frame sequences")

    # Assemble and serve the UI.
    print("\n\U0001f680 Building Gradio interface...")
    demo = create_demo(dynamics_model, seeds, ckpt_step)
    print("\U0001f310 Launching...")
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces; Spaces proxies the port
        server_port=7860,
        share=False,  # HF Spaces provides the public URL
        show_error=True,
    )


if __name__ == "__main__":
    main()