# Commit metadata accidentally pasted into the source; kept as a comment so the
# file stays valid Python: "Your Name — Add Gradio application files (2ad4d00)"
import gradio as gr
from PIL import Image
import numpy as np
import torch
from collections import deque
import base64
import io
import os
from src.model import GameNGen, ActionEncoder
from src.config import ModelConfig, PredictionConfig
from huggingface_hub import hf_hub_download
from torchvision import transforms
# --- Configuration and Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = ModelConfig()
pred_config = PredictionConfig()
print("Loading models...")
# Diffusion world model; history_len controls how many past frames condition it.
engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len).to(device)
# The action embedding width must match the UNet's cross-attention dimension.
cross_attention_dim = engine.unet.config.cross_attention_dim
action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
print("Models loaded.")
# --- Model Weight and Asset Downloading ---
# Downloaded weights/assets are cached under this directory.
output_dir = pred_config.output_dir
os.makedirs(output_dir, exist_ok=True)
def download_asset(filename, repo_id, repo_type="model"):
    """Download an asset from the HF Hub, with a local fallback.

    Args:
        filename: Path of the file within the repo (may include subdirectories,
            e.g. "frames/frame_000000008.png").
        repo_id: Hub repository to download from.
        repo_type: Hub repo type, e.g. "model" or "dataset".

    Returns:
        Local filesystem path to the asset, or None if it was found neither on
        the Hub nor in the local "gamelogs" fallback directory.
    """
    # hf_hub_download(local_dir=...) preserves the repo's directory structure,
    # so the cache path must keep the full relative filename. (The previous
    # basename-only path missed files that live in subdirectories.)
    local_path = os.path.join(output_dir, filename)
    if os.path.exists(local_path):
        return local_path
    print(f"Downloading {filename} from {repo_id}...")
    try:
        # Return the path reported by the Hub client instead of re-deriving it.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=output_dir,
            repo_type=repo_type,
            local_dir_use_symlinks=False
        )
        print(f"Successfully downloaded {filename}.")
        return downloaded_path
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Best-effort fallback: a copy shipped with the repository, if present.
        gamelogs_path = os.path.join("gamelogs", filename)
        if os.path.exists(gamelogs_path):
            print(f"Using local file from gamelogs: {gamelogs_path}")
            return gamelogs_path
        print(f"Asset {filename} not found on Hub or locally.")
        return None
# Load weights
print("Loading model weights...")
# LoRA runs ship only adapter weights; full runs ship the whole UNet state dict.
unet_path = download_asset("pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth", pred_config.model_repo_id)
if unet_path:
    if model_config.use_lora:
        state_dict = torch.load(unet_path, map_location=device)
        engine.unet.load_attn_procs(state_dict)
        print("LoRA weights loaded.")
    else:
        engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("UNet weights loaded.")
else:
    # Best-effort: continue with the (untrained) base UNet instead of crashing.
    print("Warning: UNet weights not found. Using base UNet.")
action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
if action_encoder_path:
    action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
    print("Action Encoder weights loaded.")
else:
    print("Warning: Action encoder weights not found.")
# Inference-only: disable dropout/batch-norm training behavior.
engine.eval()
action_encoder.eval()
# --- Image Transformations & Helpers ---
# Maps PIL frames to normalized [-1, 1] tensors at the model's input size.
transform = transforms.Compose([
    transforms.Resize(model_config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Maps action names (e.g. "w", "noop") to their vector encodings.
action_map = pred_config.action_map
def tensor_to_pil(tensor):
    """Convert a [-1, 1] image tensor (with a leading batch dim) to a PIL image."""
    image = tensor.squeeze(0).cpu()
    image = (image * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(image)
# --- Core Logic for Gradio ---
@torch.inference_mode()
def start_game():
    """Begin a fresh game session.

    Returns:
        (first frame as a PIL image, latent frame-history deque,
        action-history deque); on failure a blank image with None state.
    """
    print("Starting a new game session...")
    # Fetch the seed frame that bootstraps the session.
    first_frame_path = download_asset(
        "frames/frame_000000008.png", pred_config.dataset_repo_id, repo_type="dataset"
    )
    if not first_frame_path:
        # Fall back to a black screen when the seed frame is unavailable.
        print("Could not load initial frame. Returning blank image.")
        return Image.new("RGB", (320, 240)), None, None
    pil_image = Image.open(first_frame_path).convert("RGB")
    # Encode the seed frame once and replicate it to fill the latent history.
    frame_tensor = transform(pil_image).unsqueeze(0).to(device)
    seed_latent = engine.vae.encode(frame_tensor).latent_dist.sample()
    frame_history = deque(
        [seed_latent] * model_config.history_len, maxlen=model_config.history_len
    )
    # The action history starts out as all no-ops.
    noop = torch.tensor(
        action_map["noop"], dtype=torch.float32, device=device
    ).unsqueeze(0)
    action_history = deque(
        [noop] * model_config.history_len, maxlen=model_config.history_len
    )
    print("Game session started.")
    return pil_image, frame_history, action_history
@torch.inference_mode()
def predict_step(action_name, frame_history, action_history):
    """Predicts the next frame based on an action and the current state.

    Args:
        action_name: Key into ``action_map`` (e.g. "w", "s", " ").
        frame_history: deque of latent tensors for past frames, or None.
        action_history: deque of past action tensors, or None.

    Returns:
        (PIL image of the predicted frame, updated frame_history, updated
        action_history); a blank image with None state if no session exists.
    """
    # No active session yet (user has not pressed Start) -> blank frame.
    if frame_history is None or action_history is None:
        return Image.new("RGB", (320, 240)), None, None
    print(f"Received action: {action_name}")
    action_list = action_map.get(action_name)
    action_tensor = torch.tensor(action_list, dtype=torch.float32, device=device).unsqueeze(0)
    # Inference
    # Concatenate history latents along the channel dim to condition the UNet.
    history_latents = torch.cat(list(frame_history), dim=1)
    action_conditioning = action_encoder(action_tensor).unsqueeze(1)
    # 4-channel latent at 1/8 spatial resolution (SD-style VAE latent space).
    out_channels = 4
    # Start from pure noise; denoise over the scheduler's full timestep range.
    current_latents = torch.randn(
        (1, out_channels, model_config.image_size[0] // 8, model_config.image_size[1] // 8),
        device=device
    )
    for t in engine.scheduler.timesteps:
        model_input = torch.cat([current_latents, history_latents], dim=1)
        noise_pred = engine(model_input, t, action_conditioning)
        current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample
    # NOTE(review): the latent is divided by the VAE scaling factor before both
    # decoding and history storage — this assumes start_game also stores
    # unscaled latents; confirm the scaling convention is consistent.
    predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
    image_tensor = engine.vae.decode(predicted_latent_unscaled).sample
    # Update State
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)
    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Tiny Engine Game")
    gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")

    # Per-session histories, carried between predict_step calls.
    frame_history_state = gr.State(None)
    action_history_state = gr.State(None)

    with gr.Row():
        start_button = gr.Button("Start Game", variant="primary")
    with gr.Row():
        game_display = gr.Image(label="Game View", interactive=False)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Controls")
            # (button label, action name forwarded to predict_step)
            control_specs = [
                ("W (Forward)", "w"),
                ("S (Backward)", "s"),
                ("A (Left)", "a"),
                ("D (Right)", "d"),
                ("ArrowLeft (Turn Left)", "ArrowLeft"),
                ("ArrowRight (Turn Right)", "ArrowRight"),
                ("Space (Attack)", " "),
            ]
            control_buttons = [gr.Button(label) for label, _ in control_specs]

    # --- Button Click Handlers ---
    start_button.click(
        fn=start_game,
        inputs=[],
        outputs=[game_display, frame_history_state, action_history_state]
    )
    # gr.State(action_name) pins each button's action at wiring time.
    for button, (_, action_name) in zip(control_buttons, control_specs):
        button.click(
            fn=predict_step,
            inputs=[gr.State(action_name), frame_history_state, action_history_state],
            outputs=[game_display, frame_history_state, action_history_state]
        )

if __name__ == "__main__":
    demo.launch()