Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from collections import deque | |
| import base64 | |
| import io | |
| import os | |
| from src.model import GameNGen, ActionEncoder | |
| from src.config import ModelConfig, PredictionConfig | |
| from huggingface_hub import hf_hub_download | |
| from torchvision import transforms | |
# --- Configuration and Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = ModelConfig()
pred_config = PredictionConfig()

print("Loading models...")
# GameNGen engine: bundles a UNet, VAE and scheduler (all accessed below),
# conditioned on the last `history_len` frames.
engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len).to(device)
# The action embedding width must match the UNet's cross-attention dimension,
# since the encoded action is fed in as the cross-attention conditioning.
cross_attention_dim = engine.unet.config.cross_attention_dim
action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
print("Models loaded.")

# --- Model Weight and Asset Downloading ---
# Downloaded weights/assets are cached under this directory.
output_dir = pred_config.output_dir
os.makedirs(output_dir, exist_ok=True)
def download_asset(filename, repo_id, repo_type="model"):
    """Download an asset from the Hugging Face Hub, with a local fallback.

    Args:
        filename: Repo-relative path of the file (may contain subdirectories,
            e.g. "frames/frame_000000008.png").
        repo_id: Hub repository to download from.
        repo_type: Hub repository type ("model" or "dataset").

    Returns:
        Local filesystem path to the asset, or None if it was found neither
        on the Hub nor in the local "gamelogs" fallback directory.
    """
    # hf_hub_download(local_dir=...) mirrors the repo-relative path, so the
    # cached file lives at output_dir/<filename> (with subdirs), not at
    # output_dir/<basename>.  Build the cache path accordingly.
    local_path = os.path.join(output_dir, filename)
    if os.path.exists(local_path):
        return local_path
    print(f"Downloading {filename} from {repo_id}...")
    try:
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=output_dir,
            repo_type=repo_type,
            local_dir_use_symlinks=False  # no-op on recent huggingface_hub; kept for older versions
        )
        print(f"Successfully downloaded {filename}.")
        # Return the path reported by the client rather than reconstructing it.
        return downloaded_path
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Fallback: look for a copy shipped alongside the app in ./gamelogs.
        gamelogs_path = os.path.join("gamelogs", filename)
        if os.path.exists(gamelogs_path):
            print(f"Using local file from gamelogs: {gamelogs_path}")
            return gamelogs_path
    print(f"Asset {filename} not found on Hub or locally.")
    return None
# Load weights
print("Loading model weights...")
# LoRA checkpoints and full UNet checkpoints are published under different filenames.
unet_path = download_asset("pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth", pred_config.model_repo_id)
if unet_path:
    if model_config.use_lora:
        # LoRA: only attention-processor weights are loaded on top of the base UNet.
        state_dict = torch.load(unet_path, map_location=device)
        engine.unet.load_attn_procs(state_dict)
        print("LoRA weights loaded.")
    else:
        engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("UNet weights loaded.")
else:
    # The demo still runs (with the un-finetuned base UNet) if the download failed.
    print("Warning: UNet weights not found. Using base UNet.")

action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
if action_encoder_path:
    action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
    print("Action Encoder weights loaded.")
else:
    print("Warning: Action encoder weights not found.")

# Switch both networks to inference mode (disables dropout etc.).
engine.eval()
action_encoder.eval()
# --- Image Transformations & Helpers ---
# Preprocessing applied to input frames before VAE encoding
# (presumably mirrors the training pipeline — confirm against the trainer).
transform = transforms.Compose([
    transforms.Resize(model_config.image_size),
    transforms.ToTensor(),                # PIL -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5])    # [0, 1] -> [-1, 1]
])

# Maps action names ("noop", "w", ...) to the numeric vectors fed to the encoder.
action_map = pred_config.action_map
def tensor_to_pil(tensor):
    """Convert a batched model-space image tensor (values expected in [-1, 1])
    into a displayable PIL image."""
    # Undo the Normalize([0.5], [0.5]) step: x/2 + 0.5, then clamp into [0, 1].
    denormalized = tensor.squeeze(0).cpu().div(2).add(0.5).clamp(0.0, 1.0)
    to_pil = transforms.ToPILImage()
    return to_pil(denormalized)
| # --- Core Logic for Gradio --- | |
def start_game():
    """Initialize a new game session.

    Returns:
        A 3-tuple of (first frame as a PIL image,
        deque of frame latents seeded with the first frame,
        deque of action tensors seeded with "noop").
        On failure to fetch the first frame, returns a blank image and
        (None, None) so predict_step can detect the unstarted session.
    """
    print("Starting a new game session...")
    # Get initial frame
    first_frame_filename = "frames/frame_000000008.png"
    first_frame_path = download_asset(first_frame_filename, pred_config.dataset_repo_id, repo_type="dataset")
    if not first_frame_path:
        # Return a black screen as a fallback
        print("Could not load initial frame. Returning blank image.")
        return Image.new("RGB", (320, 240)), None, None
    pil_image = Image.open(first_frame_path).convert("RGB")

    # Encode the first frame once and replicate it to fill the history window.
    initial_frame_tensor = transform(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only — don't build an autograd graph
        initial_latent = engine.vae.encode(initial_frame_tensor).latent_dist.sample()
    frame_history = deque([initial_latent] * model_config.history_len, maxlen=model_config.history_len)
    noop_action = torch.tensor(action_map["noop"], dtype=torch.float32, device=device).unsqueeze(0)
    action_history = deque([noop_action] * model_config.history_len, maxlen=model_config.history_len)
    print("Game session started.")
    return pil_image, frame_history, action_history
def predict_step(action_name, frame_history, action_history):
    """Predict the next game frame from an action and the session history.

    Args:
        action_name: Key into pred_config.action_map (e.g. "w", "ArrowLeft", " ").
        frame_history: deque of latent tensors for the most recent frames.
        action_history: deque of action tensors for the most recent steps.

    Returns:
        (PIL image of the predicted frame, updated frame_history,
        updated action_history). Returns a blank image and (None, None)
        if the session was never started.
    """
    if frame_history is None or action_history is None:
        # "Start Game" has not been pressed yet.
        return Image.new("RGB", (320, 240)), None, None
    print(f"Received action: {action_name}")
    action_list = action_map.get(action_name)
    if action_list is None:
        # Unknown action name: degrade to a no-op instead of crashing the session.
        print(f"Unknown action: {action_name}. Falling back to noop.")
        action_list = action_map["noop"]
    action_tensor = torch.tensor(action_list, dtype=torch.float32, device=device).unsqueeze(0)

    # Inference — wrapped in no_grad so each step doesn't accumulate autograd
    # graphs (the models are eval-only here).
    with torch.no_grad():
        # History latents are stacked along the channel dimension.
        history_latents = torch.cat(list(frame_history), dim=1)
        action_conditioning = action_encoder(action_tensor).unsqueeze(1)
        out_channels = 4  # latent channel count; assumed to match the VAE — TODO confirm
        current_latents = torch.randn(
            (1, out_channels, model_config.image_size[0] // 8, model_config.image_size[1] // 8),
            device=device
        )
        # Reverse-diffusion loop conditioned on frame history + action.
        for t in engine.scheduler.timesteps:
            model_input = torch.cat([current_latents, history_latents], dim=1)
            noise_pred = engine(model_input, t, action_conditioning)
            current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample
        predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
        image_tensor = engine.vae.decode(predicted_latent_unscaled).sample

    # Update state (deques are bounded by history_len, so old entries drop off).
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)

    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Tiny Engine Game")
    gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")

    # State variables to hold the session history between steps
    frame_history_state = gr.State(None)
    action_history_state = gr.State(None)

    with gr.Row():
        start_button = gr.Button("Start Game", variant="primary")
    with gr.Row():
        game_display = gr.Image(label="Game View", interactive=False)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Controls")
            fwd_button = gr.Button("W (Forward)")
            s_button = gr.Button("S (Backward)")
            a_button = gr.Button("A (Left)")
            d_button = gr.Button("D (Right)")
            turn_l_button = gr.Button("ArrowLeft (Turn Left)")
            turn_r_button = gr.Button("ArrowRight (Turn Right)")
            attack_button = gr.Button("Space (Attack)")

    # --- Button Click Handlers ---
    start_button.click(
        fn=start_game,
        inputs=[],
        outputs=[game_display, frame_history_state, action_history_state]
    )

    # Each control button invokes predict_step with its fixed action name
    # (supplied via a constant gr.State) plus the shared session state.
    action_buttons = [fwd_button, s_button, a_button, d_button, turn_l_button, turn_r_button, attack_button]
    action_names = ["w", "s", "a", "d", "ArrowLeft", "ArrowRight", " "]
    for button, name in zip(action_buttons, action_names):
        button.click(
            fn=predict_step,
            inputs=[gr.State(name), frame_history_state, action_history_state],
            outputs=[game_display, frame_history_state, action_history_state]
        )

if __name__ == "__main__":
    demo.launch()