"""Gradio demo for the GameNGen neural game engine.

Loads the diffusion engine and action encoder, fetches pretrained
weights/assets from the Hugging Face Hub (with a local fallback), and
serves an interactive UI where each button press predicts the next game
frame from the action + latent-frame history.
"""

import os
from collections import deque

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from src.config import ModelConfig, PredictionConfig
from src.model import ActionEncoder, GameNGen

# --- Configuration and Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = ModelConfig()
pred_config = PredictionConfig()

print("Loading models...")
engine = GameNGen(
    model_config.model_id,
    model_config.num_timesteps,
    history_len=model_config.history_len,
).to(device)
cross_attention_dim = engine.unet.config.cross_attention_dim
action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
print("Models loaded.")

# --- Model Weight and Asset Downloading ---
output_dir = pred_config.output_dir
os.makedirs(output_dir, exist_ok=True)


def download_asset(filename, repo_id, repo_type="model"):
    """Download ``filename`` from ``repo_id`` on the HF Hub, with a local fallback.

    Returns the local path to the asset, or ``None`` if it is unavailable
    both on the Hub and in the local ``gamelogs`` directory.
    """
    # Keep any sub-directory in `filename` (e.g. "frames/...") so the
    # cache-hit check matches where hf_hub_download(local_dir=...) actually
    # places the file; os.path.basename() here would never match nested assets.
    local_path = os.path.join(output_dir, filename)
    if os.path.exists(local_path):
        return local_path

    print(f"Downloading {filename} from {repo_id}...")
    try:
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=output_dir,
            repo_type=repo_type,
            local_dir_use_symlinks=False,
        )
        print(f"Successfully downloaded {filename}.")
        # Return the path reported by the Hub client — it accounts for
        # sub-directories inside `filename`.
        return downloaded_path
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Fallback: look for a local copy shipped under "gamelogs/".
        gamelogs_path = os.path.join("gamelogs", filename)
        if os.path.exists(gamelogs_path):
            print(f"Using local file from gamelogs: {gamelogs_path}")
            return gamelogs_path
        print(f"Asset {filename} not found on Hub or locally.")
        return None


# Load weights
print("Loading model weights...")
unet_path = download_asset(
    "pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth",
    pred_config.model_repo_id,
)
if unet_path:
    if model_config.use_lora:
        state_dict = torch.load(unet_path, map_location=device)
        engine.unet.load_attn_procs(state_dict)
        print("LoRA weights loaded.")
    else:
        engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("UNet weights loaded.")
else:
    print("Warning: UNet weights not found. Using base UNet.")

action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
if action_encoder_path:
    action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
    print("Action Encoder weights loaded.")
else:
    print("Warning: Action encoder weights not found.")

engine.eval()
action_encoder.eval()

# --- Image Transformations & Helpers ---
transform = transforms.Compose([
    transforms.Resize(model_config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
action_map = pred_config.action_map


def tensor_to_pil(tensor):
    """Convert a [-1, 1]-normalized image tensor of shape (1, C, H, W) to PIL."""
    tensor = (tensor.squeeze(0).cpu() / 2 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(tensor)


# --- Core Logic for Gradio ---
@torch.inference_mode()
def start_game():
    """Initializes a new game session and returns the first frame and state."""
    print("Starting a new game session...")
    # Get initial frame
    first_frame_filename = "frames/frame_000000008.png"
    first_frame_path = download_asset(
        first_frame_filename, pred_config.dataset_repo_id, repo_type="dataset"
    )
    if not first_frame_path:
        # Return a black screen as a fallback
        print("Could not load initial frame. Returning blank image.")
        return Image.new("RGB", (320, 240)), None, None

    pil_image = Image.open(first_frame_path).convert("RGB")

    # Initialize histories with the first frame's latent and "noop" actions.
    initial_frame_tensor = transform(pil_image).unsqueeze(0).to(device)
    initial_latent = engine.vae.encode(initial_frame_tensor).latent_dist.sample()
    # NOTE(review): the latent is stored without multiplying by
    # vae.config.scaling_factor, while predict_step appends latents that were
    # divided by scaling_factor — confirm this asymmetry matches training.
    frame_history = deque(
        [initial_latent] * model_config.history_len,
        maxlen=model_config.history_len,
    )
    noop_action = torch.tensor(
        action_map["noop"], dtype=torch.float32, device=device
    ).unsqueeze(0)
    action_history = deque(
        [noop_action] * model_config.history_len,
        maxlen=model_config.history_len,
    )
    print("Game session started.")
    return pil_image, frame_history, action_history


@torch.inference_mode()
def predict_step(action_name, frame_history, action_history):
    """Predicts the next frame based on an action and the current state."""
    if frame_history is None or action_history is None:
        # Session was never started — return a blank frame and empty state.
        return Image.new("RGB", (320, 240)), None, None

    print(f"Received action: {action_name}")
    action_list = action_map.get(action_name)
    if action_list is None:
        # Unknown action names previously crashed torch.tensor(None, ...);
        # degrade gracefully to a "noop" step instead.
        print(f"Unknown action '{action_name}', falling back to noop.")
        action_list = action_map["noop"]
    action_tensor = torch.tensor(
        action_list, dtype=torch.float32, device=device
    ).unsqueeze(0)

    # Condition on the concatenated latent history plus the encoded action.
    history_latents = torch.cat(list(frame_history), dim=1)
    action_conditioning = action_encoder(action_tensor).unsqueeze(1)

    out_channels = 4  # latent channel count matching the VAE's latent space
    current_latents = torch.randn(
        (
            1,
            out_channels,
            model_config.image_size[0] // 8,  # VAE downsamples by 8x
            model_config.image_size[1] // 8,
        ),
        device=device,
    )
    # Iterative denoising: each step conditions on the full latent history.
    for t in engine.scheduler.timesteps:
        model_input = torch.cat([current_latents, history_latents], dim=1)
        noise_pred = engine(model_input, t, action_conditioning)
        current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample

    predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
    image_tensor = engine.vae.decode(predicted_latent_unscaled).sample

    # Update state with the newly predicted latent and the action taken.
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)

    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Tiny Engine Game")
    gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")

    # State variables to hold the session history between steps
    frame_history_state = gr.State(None)
    action_history_state = gr.State(None)

    with gr.Row():
        start_button = gr.Button("Start Game", variant="primary")
    with gr.Row():
        game_display = gr.Image(label="Game View", interactive=False)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Controls")
            fwd_button = gr.Button("W (Forward)")
            s_button = gr.Button("S (Backward)")
            a_button = gr.Button("A (Left)")
            d_button = gr.Button("D (Right)")
            turn_l_button = gr.Button("ArrowLeft (Turn Left)")
            turn_r_button = gr.Button("ArrowRight (Turn Right)")
            attack_button = gr.Button("Space (Attack)")

    # --- Button Click Handlers ---
    start_button.click(
        fn=start_game,
        inputs=[],
        outputs=[game_display, frame_history_state, action_history_state],
    )

    action_buttons = [
        fwd_button, s_button, a_button, d_button,
        turn_l_button, turn_r_button, attack_button,
    ]
    action_names = ["w", "s", "a", "d", "ArrowLeft", "ArrowRight", " "]
    for button, name in zip(action_buttons, action_names):
        button.click(
            fn=predict_step,
            inputs=[gr.State(name), frame_history_state, action_history_state],
            outputs=[game_display, frame_history_state, action_history_state],
        )

if __name__ == "__main__":
    demo.launch()