Spaces:
Running
on
T4
Running
on
T4
| #!/usr/bin/env -S uv run --script | |
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "requests<3", | |
| # "pillow", | |
| # "opencv-python", | |
| # "pyboy", | |
| # "huggingface-hub", | |
| # "gradio", | |
| # "numpy", | |
| # "nitrogen @ git+https://github.com/MineDojo/NitroGen.git@main", | |
| # ] | |
| # [tool.uv] | |
| # exclude-newer = "2025-12-22T00:00:00Z" | |
| # /// | |
| """ | |
| Unified Gradio app for NitroGen Pokemon Red player with real-time streaming | |
| Combines model inference and PyBoy gameplay in a single interface | |
| """ | |
| import gradio as gr | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from pyboy import PyBoy | |
| from pyboy.utils import WindowEvent | |
| import time | |
| import tempfile | |
| import requests | |
| from huggingface_hub import HfFileSystem | |
| from nitrogen.inference_session import InferenceSession | |
| from nitrogen.shared import PATH_REPO, BUTTON_ACTION_TOKENS | |
# Pokemon Red ROM; fetched on first run and cached in the system temp directory.
ROM_URL = "https://github.com/hxh-robb/pokemon-roms/raw/refs/heads/master/ROM/Pokemon%20-%20Red%20Version%20(USA,%20Europe).gb"
# Optional PyBoy save state, loaded at startup only if it exists (relative to the CWD).
STATE_PATH = "./init.state"

# Game Boy button mapping: button name -> PyBoy event that presses it.
# The names are the strings produced by gamepad_to_gameboy_buttons.
GB_BUTTONS = dict(
    A=WindowEvent.PRESS_BUTTON_A,
    B=WindowEvent.PRESS_BUTTON_B,
    START=WindowEvent.PRESS_BUTTON_START,
    SELECT=WindowEvent.PRESS_BUTTON_SELECT,
    UP=WindowEvent.PRESS_ARROW_UP,
    DOWN=WindowEvent.PRESS_ARROW_DOWN,
    LEFT=WindowEvent.PRESS_ARROW_LEFT,
    RIGHT=WindowEvent.PRESS_ARROW_RIGHT,
)

# Same button names -> PyBoy event that releases the button again.
GB_BUTTONS_RELEASE = dict(
    A=WindowEvent.RELEASE_BUTTON_A,
    B=WindowEvent.RELEASE_BUTTON_B,
    START=WindowEvent.RELEASE_BUTTON_START,
    SELECT=WindowEvent.RELEASE_BUTTON_SELECT,
    UP=WindowEvent.RELEASE_ARROW_UP,
    DOWN=WindowEvent.RELEASE_ARROW_DOWN,
    LEFT=WindowEvent.RELEASE_ARROW_LEFT,
    RIGHT=WindowEvent.RELEASE_ARROW_RIGHT,
)
def preprocess_img(frame, size=256):
    """Convert a Game Boy frame to a ``size`` x ``size`` RGB PIL Image for model input.

    Args:
        frame: A PIL Image in any mode, or a numpy array shaped (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4). 3-channel arrays are passed
            through unchanged, i.e. assumed to already be RGB.
        size: Side length of the square output in pixels (default 256).

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image``.
    """
    if isinstance(frame, Image.Image):
        # Let PIL resolve every mode (RGBA, L, LA, P, ...). Going through
        # np.array first would turn a palette ("P") image into raw palette
        # indices, which GRAY2RGB would then render with the wrong colors.
        frame = np.array(frame.convert("RGB"))
    else:
        frame = np.asarray(frame)
        if frame.ndim == 3 and frame.shape[2] == 1:
            # (H, W, 1) would otherwise survive as single-channel output
            frame = frame[:, :, 0]
        if frame.ndim == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    # INTER_AREA averages source pixels, the usual choice when shrinking
    frame_resized = cv2.resize(frame, (size, size), interpolation=cv2.INTER_AREA)
    return Image.fromarray(frame_resized)
# (index into the per-button score vector, Game Boy button it drives).
# NOTE(review): the index -> gamepad-button labels below follow the original
# comments; confirm them against nitrogen.shared.BUTTON_ACTION_TOKENS.
# D-Pad (indices 1-4)
_DPAD_BUTTON_MAP = ((1, "DOWN"), (2, "LEFT"), (3, "RIGHT"), (4, "UP"))
# Action buttons: SOUTH -> A, EAST -> B, START -> START, BACK -> SELECT
_ACTION_BUTTON_MAP = ((18, "A"), (5, "B"), (19, "START"), (0, "SELECT"))
# Alternative mappings, used only when the target button is not yet pressed:
# NORTH -> A, WEST -> B, LEFT_SHOULDER -> A, RIGHT_SHOULDER -> B
_ALT_BUTTON_MAP = ((10, "A"), (20, "B"), (7, "A"), (14, "B"))


def _joystick_to_dpad(j_left, joystick_threshold):
    """Return the dominant direction of the first left-stick sample, or None.

    The axis with the larger magnitude wins and must also exceed
    *joystick_threshold*. Positive x maps to RIGHT, positive y to DOWN.
    """
    if len(j_left) == 0:
        return None
    xl, yl = j_left[0]
    if abs(xl) > abs(yl):
        if xl > joystick_threshold:
            return "RIGHT"
        if xl < -joystick_threshold:
            return "LEFT"
    elif yl > joystick_threshold:
        return "DOWN"
    elif yl < -joystick_threshold:
        return "UP"
    return None


def gamepad_to_gameboy_buttons(pred, button_threshold=0.5, joystick_threshold=0.3):
    """Convert the model's gamepad prediction to Game Boy button presses.

    Args:
        pred: Model prediction dict. ``pred["buttons"][0]`` is read as the
            per-button score vector (NOTE(review): presumably the first
            predicted step) and ``pred["j_left"][0]`` as an ``(x, y)`` stick pair.
        button_threshold: Score above which a button counts as pressed.
        joystick_threshold: Left-stick dead zone for the D-pad fallback.

    Returns:
        A list of GB_BUTTONS names: D-pad direction(s) first, then action
        buttons. Empty when there is no button prediction or the score vector
        is shorter than BUTTON_ACTION_TOKENS.
    """
    buttons = pred["buttons"]
    if len(buttons) == 0:
        return []
    button_vals = buttons[0]
    # Protects the hard-coded indices below (they go up to 20, presumably
    # within BUTTON_ACTION_TOKENS' length)
    if len(button_vals) < len(BUTTON_ACTION_TOKENS):
        return []

    pressed_buttons = [
        name for idx, name in _DPAD_BUTTON_MAP if button_vals[idx] > button_threshold
    ]

    # Joystick fallback if no D-pad direction was pressed
    if not pressed_buttons:
        direction = _joystick_to_dpad(pred["j_left"], joystick_threshold)
        if direction is not None:
            pressed_buttons.append(direction)

    pressed_buttons.extend(
        name for idx, name in _ACTION_BUTTON_MAP if button_vals[idx] > button_threshold
    )
    for idx, name in _ALT_BUTTON_MAP:
        if button_vals[idx] > button_threshold and name not in pressed_buttons:
            pressed_buttons.append(name)
    return pressed_buttons
# Frames each predicted button is held before release. A short tap instead of
# holding for the whole frame_skip window prevents repeated movement: with
# frame_skip=16, DOWN is held 4 frames, released, then 12 more frames run,
# so the character moves 1 tile down, not 16 tiles.
_BUTTON_HOLD_FRAMES = 4
# (width, height) of the streamed frame: 4x the Game Boy's 160x144 screen.
_DISPLAY_SIZE = (640, 576)
# HfFileSystem path of the NitroGen checkpoint.
_CKPT_REPO_PATH = "nvidia/NitroGen/ng.pt"


def _stream_url_to(url, path, timeout=60.0):
    """Stream the body of ``GET url`` into the file at *path*.

    Raises ``requests.RequestException`` on network errors or a non-2xx
    status (via ``raise_for_status``), ``OSError`` on write failures.
    """
    # The context manager closes the streamed connection even if writing
    # fails; the timeout stops a stalled server from hanging the stream.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _atomic_download(dest, write_to):
    """Call ``write_to(tmp_path)`` and move the resulting file to *dest*.

    Writing into a sibling ``.part`` file first means an interrupted or failed
    download never leaves a truncated file at *dest* that later runs would
    treat as a valid cached copy.

    Returns:
        True if ``write_to`` produced a file (now renamed to *dest*), False if
        it produced nothing. Exceptions from ``write_to`` propagate after the
        partial file has been removed.
    """
    tmp_path = dest.with_name(dest.name + ".part")
    try:
        write_to(tmp_path)
        if not tmp_path.exists():
            return False
        tmp_path.replace(dest)
        return True
    finally:
        tmp_path.unlink(missing_ok=True)


def _advance_emulator(pyboy, buttons, frame_skip, hold_frames=_BUTTON_HOLD_FRAMES):
    """Tap *buttons* and advance the emulator, rendering only the last frame.

    Buttons are held for at most *hold_frames* frames and always released for
    at least one frame, so the same button predicted on consecutive steps is
    seen as separate presses. Advances ``max(frame_skip, 2)`` frames in total.
    """
    # Shrink the hold when frame_skip is small so the step spans frame_skip
    # frames instead of always running hold_frames + 1 frames.
    hold = max(1, min(hold_frames, frame_skip - 1))
    remaining = max(1, frame_skip - hold)
    for btn in buttons:
        pyboy.send_input(GB_BUTTONS[btn])
    # Hold buttons for a few frames (so the action registers)
    pyboy.tick(hold, render=False)
    for btn in buttons:
        pyboy.send_input(GB_BUTTONS_RELEASE[btn])
    if remaining > 1:
        pyboy.tick(remaining - 1, render=False)
    pyboy.tick()  # Final tick with render, so pyboy.screen shows this step


def _render_display_frame(pyboy, width, height):
    """Return the current screen, alpha dropped, upscaled to (width, height)."""
    screen_np = pyboy.screen.ndarray
    if screen_np.shape[2] == 4:
        screen_np = screen_np[:, :, :3]
    # Nearest-neighbour keeps the pixel-art edges sharp
    return cv2.resize(screen_np, (width, height), interpolation=cv2.INTER_NEAREST)


def _format_action_info(step_count, max_steps, buttons, frame_skip):
    """Markdown summary of the current step, pressed buttons and progress."""
    pressed = ", ".join(buttons) if buttons else "None"
    return (
        f"**Step {step_count}/{max_steps}**\n\n"
        f"🎮 **Buttons:** {pressed}\n\n"
        f"⚡ **Speed:** {frame_skip}x frame skip\n\n"
        f"📊 **Progress:** {step_count / max_steps * 100:.1f}%"
    )


def _format_stats_info(pred, button_threshold):
    """Markdown listing up to five button tokens scoring above the threshold."""
    stats_info = "**Inference Details**\n\n"
    predicted = pred.get("buttons", [])
    if len(predicted) == 0:
        return stats_info
    # zip stops at the shorter of the two, matching min(len(...), len(...))
    active_buttons = [
        f"{token}: {value:.2f}"
        for token, value in zip(BUTTON_ACTION_TOKENS, predicted[0])
        if value > button_threshold
    ]
    if not active_buttons:
        return stats_info + "No buttons above threshold"
    return (
        stats_info
        + "**Active Predictions:**\n"
        + "\n".join(f"- {btn}" for btn in active_buttons[:5])
    )


def play_pokemon(
    cfg_scale: float,
    context_length: int,
    max_steps: int,
    frame_skip: int,
    button_threshold: float,
    display_every: int,
    update_delay: float
):
    """Generator that yields frames while playing Pokemon Red.

    Downloads (or reuses cached copies of) the ROM and the NitroGen checkpoint,
    starts a headless PyBoy emulator, then repeatedly feeds the screen to the
    model and taps the predicted buttons for up to *max_steps* steps.

    Yields:
        ``(frame, action_markdown, stats_markdown)`` for the three Gradio
        outputs. During setup ``frame`` and ``stats_markdown`` are ``None`` and
        ``action_markdown`` is a status message. While playing, every
        *display_every* steps, ``frame`` is the screen resized to 640x576.
    """
    # Gradio sliders may deliver floats: PyBoy.tick counts and the modulo
    # below need ints, and display_every == 0 would divide by zero.
    context_length = int(context_length)
    max_steps = int(max_steps)
    frame_skip = max(1, int(frame_skip))
    display_every = max(1, int(display_every))

    temp_dir = Path(tempfile.gettempdir())
    rom_path = temp_dir / "PokemonRed.gb"
    ckpt_path = temp_dir / "ng.pt"

    # Download ROM from URL (cached in the temp directory across runs)
    yield None, "⏳ Downloading ROM file...", None
    if rom_path.exists():
        yield None, "✅ Using cached ROM", None
        time.sleep(0.3)
    else:
        try:
            _atomic_download(rom_path, lambda p: _stream_url_to(ROM_URL, p))
        except Exception as e:  # UI boundary: report instead of crashing the stream
            yield None, f"❌ Error downloading ROM: {str(e)}", None
            return
        yield None, "✅ ROM downloaded successfully", None
        time.sleep(0.5)

    # Download checkpoint from HuggingFace using HfFileSystem
    yield None, "⏳ Downloading checkpoint from nvidia/NitroGen...", None
    if ckpt_path.exists():
        yield None, "✅ Using cached checkpoint", None
        time.sleep(0.3)
    else:
        try:
            downloaded = _atomic_download(
                ckpt_path,
                lambda p: HfFileSystem().get_file(_CKPT_REPO_PATH, str(p)),
            )
        except Exception as e:
            yield None, f"❌ Error downloading checkpoint: {str(e)}", None
            return
        if not downloaded:
            yield None, "❌ Failed to download checkpoint from HuggingFace", None
            return
        yield None, "✅ Checkpoint downloaded successfully", None
        time.sleep(0.5)

    # Initialize inference session
    yield None, "⏳ Initializing inference session...", None
    try:
        session = InferenceSession.from_ckpt(
            str(ckpt_path),
            cfg_scale=cfg_scale,
            context_length=context_length
        )
        session.reset()
    except Exception as e:
        yield None, f"❌ Error initializing inference session: {str(e)}", None
        return

    pyboy = PyBoy(str(rom_path), window="null")
    # Everything after construction runs under try/finally. If the consumer
    # closes this generator at any yield (e.g. a cancelled Gradio stream), or
    # load_state raises, the emulator is still stopped.
    try:
        pyboy.set_emulation_speed(0)  # Unlimited speed
        state_path = Path(STATE_PATH)
        if state_path.exists():
            with open(state_path, "rb") as f:
                pyboy.load_state(f)
            yield None, f"✅ Loaded save state: {STATE_PATH}", None
        else:
            yield None, f"⚠️ Save state not found: {STATE_PATH} (starting fresh)", None
        time.sleep(0.3)

        width, height = _DISPLAY_SIZE
        for step_count in range(max_steps):
            # Predict from the current screen, then tap the chosen buttons
            obs_processed = preprocess_img(pyboy.screen.image)
            pred = session.predict(obs_processed)
            buttons_to_press = gamepad_to_gameboy_buttons(pred, button_threshold)
            _advance_emulator(pyboy, buttons_to_press, frame_skip)

            if step_count % display_every == 0:
                yield (
                    _render_display_frame(pyboy, width, height),
                    _format_action_info(step_count, max_steps, buttons_to_press, frame_skip),
                    _format_stats_info(pred, button_threshold),
                )
                # Delay to allow Gradio to load images properly
                time.sleep(update_delay)
    finally:
        pyboy.stop()
# Instructions panel. Kept flush-left so the Markdown does not depend on
# indentation handling. The blank lines matter: without them CommonMark
# parses "**Automatic Setup:**" and "**Tips:**" as lazy continuations of the
# preceding list item instead of separate paragraphs.
_INSTRUCTIONS_MD = f"""
### 📝 Instructions

1. Adjust playback settings as needed
2. Click "Start Playing" to begin streaming
3. Game frames update in real-time with actions

**Automatic Setup:**

- **Model**: nvidia/NitroGen checkpoint (ng.pt) from HuggingFace Hub
- **ROM**: Downloaded from configured URL
- **Save State**: Loaded from `{STATE_PATH}` if available
- Model and ROM are cached in temp directory for faster subsequent runs

**Tips:**

- **Display Every N Steps**: 1 = update every step, higher = faster but less frequent
- **Update Delay**: 1s default gives images time to load, reduce for faster updates
- **Frame Skip**: 16 = game runs 16 frames per inference (faster gameplay)
"""

# Create Gradio interface
with gr.Blocks(title="NitroGen Pokemon Red Player") as app:
    gr.Markdown("# 🎮 NitroGen Pokemon Red Player")
    gr.Markdown("Stream Pokemon Red gameplay powered by NitroGen AI model")

    with gr.Row():
        # Left column: model / playback settings and the start button
        with gr.Column(scale=1):
            gr.Markdown("### 🤖 Model Settings")
            gr.Markdown("**Model:** nvidia/NitroGen (ng.pt) - automatically downloaded from HuggingFace Hub")
            gr.Markdown("**ROM:** Automatically downloaded from configured URL")
            gr.Markdown(f"**Save State:** {STATE_PATH}")
            cfg_input = gr.Slider(
                label="CFG Scale",
                minimum=0.0,
                maximum=3.0,
                value=1.0,
                step=0.1,
                info="Classifier-free guidance scale"
            )
            ctx_input = gr.Slider(
                label="Context Length",
                minimum=1,
                maximum=32,
                value=1,
                step=1,
                info="Number of past frames to use"
            )

            gr.Markdown("### ⚙️ Playback Settings")
            max_steps_input = gr.Slider(
                label="Max Steps",
                minimum=100,
                maximum=10000,
                value=1000,
                step=100,
                info="Maximum inference steps"
            )
            frame_skip_input = gr.Slider(
                label="Frame Skip",
                minimum=1,
                maximum=64,
                value=16,
                step=1,
                info="Emulator frames per inference"
            )
            button_threshold_input = gr.Slider(
                label="Button Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                info="Threshold for button activation"
            )
            display_every_input = gr.Slider(
                label="Display Every N Steps",
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                info="Update display frequency (1=every step, higher=faster but less frequent)"
            )
            update_delay_input = gr.Slider(
                label="Update Delay (seconds)",
                minimum=0.1,
                maximum=3.0,
                value=1.0,
                step=0.1,
                info="Wait time after each display update (higher=more time for image to load)"
            )
            start_btn = gr.Button("🚀 Start Playing", variant="primary", size="lg")

        # Right column: live game stream plus per-step action / stats panels
        with gr.Column(scale=2):
            image_output = gr.Image(
                label="Game Stream",
                height=600,
                interactive=False
            )
            with gr.Row():
                with gr.Column():
                    action_output = gr.Markdown(
                        label="Actions",
                        value="**Waiting to start...**"
                    )
                with gr.Column():
                    stats_output = gr.Markdown(
                        label="Statistics",
                        value="**No data yet**"
                    )

    gr.Markdown(_INSTRUCTIONS_MD)

    # Connect the button to the play function. play_pokemon is a generator,
    # so each yielded (frame, actions, stats) tuple updates the three outputs.
    start_btn.click(
        fn=play_pokemon,
        inputs=[
            cfg_input,
            ctx_input,
            max_steps_input,
            frame_skip_input,
            button_threshold_input,
            display_every_input,
            update_delay_input
        ],
        outputs=[image_output, action_output, stats_output]
    )
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860; no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app.launch(**launch_options)