File size: 7,667 Bytes
2ad4d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import gradio as gr
from PIL import Image
import numpy as np
import torch
from collections import deque
import base64
import io
import os

from src.model import GameNGen, ActionEncoder
from src.config import ModelConfig, PredictionConfig
from huggingface_hub import hf_hub_download
from torchvision import transforms

# --- Configuration and Model Loading ---
# Run on GPU when available; every model and tensor below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config = ModelConfig()
pred_config = PredictionConfig()

print("Loading models...")
# GameNGen bundles the diffusion UNet / VAE / scheduler; it is conditioned on
# a window of `history_len` past latent frames plus an encoded action.
engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len).to(device)
# The action embedding width must match the UNet's cross-attention dimension,
# since it is fed to the UNet as the cross-attention conditioning.
cross_attention_dim = engine.unet.config.cross_attention_dim
action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
print("Models loaded.")

# --- Model Weight and Asset Downloading ---
# Local cache directory for downloaded weights and assets.
output_dir = pred_config.output_dir
os.makedirs(output_dir, exist_ok=True)

def download_asset(filename, repo_id, repo_type="model"):
    """Download ``filename`` from a Hugging Face Hub repo into ``output_dir``.

    Falls back to a local copy under ``gamelogs/<filename>`` when the Hub
    download fails.

    Args:
        filename: Repo-relative path of the asset (may contain subdirectories,
            e.g. ``"frames/frame_000000008.png"``).
        repo_id: Hub repository id to download from.
        repo_type: Hub repo type, ``"model"`` or ``"dataset"``.

    Returns:
        Local filesystem path to the asset, or ``None`` if it could not be
        found on the Hub or locally.
    """
    # hf_hub_download(local_dir=...) preserves the repo-relative path, so the
    # file lands at output_dir/<filename> — NOT output_dir/<basename>. Using
    # basename here broke cache hits and returns for nested assets like
    # frames/*.png.
    local_path = os.path.join(output_dir, filename)
    if not os.path.exists(local_path):
        print(f"Downloading {filename} from {repo_id}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=output_dir,
                repo_type=repo_type,
                local_dir_use_symlinks=False
            )
            print(f"Successfully downloaded {filename}.")
            return local_path
        except Exception as e:
            # Best-effort fallback: some assets ship with the repo checkout.
            print(f"Error downloading {filename}: {e}")
            gamelogs_path = os.path.join("gamelogs", filename)
            if os.path.exists(gamelogs_path):
                print(f"Using local file from gamelogs: {gamelogs_path}")
                return gamelogs_path
            print(f"Asset {filename} not found on Hub or locally.")
            return None
    return local_path

# Load weights
print("Loading model weights...")
# LoRA runs store only attention-processor weights; full runs store the whole UNet.
unet_path = download_asset("pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth", pred_config.model_repo_id)
if unet_path:
    if model_config.use_lora:
        # Attach the LoRA attention processors on top of the base UNet.
        state_dict = torch.load(unet_path, map_location=device)
        engine.unet.load_attn_procs(state_dict)
        print("LoRA weights loaded.")
    else:
        engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("UNet weights loaded.")
else:
    # Non-fatal: the app still runs, but predictions use untrained base weights.
    print("Warning: UNet weights not found. Using base UNet.")

action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
if action_encoder_path:
    action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
    print("Action Encoder weights loaded.")
else:
    print("Warning: Action encoder weights not found.")

# Inference only: disable dropout/batch-norm training behavior.
engine.eval()
action_encoder.eval()

# --- Image Transformations & Helpers ---
# Preprocess PIL frames for the VAE: resize, to-tensor, then map [0,1] -> [-1,1]
# (Normalize with mean 0.5 / std 0.5 on every channel).
transform = transforms.Compose([
    transforms.Resize(model_config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Maps action names ("w", "noop", ...) to their one-hot/multi-hot vectors.
action_map = pred_config.action_map

def tensor_to_pil(tensor):
    """Convert a (1, C, H, W) tensor normalized to [-1, 1] into a PIL image."""
    # Undo the Normalize([0.5], [0.5]) step, then clamp into the valid [0, 1] range.
    image = tensor.squeeze(0).cpu()
    image = (image * 0.5 + 0.5).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(image)

# --- Core Logic for Gradio ---
@torch.inference_mode()
def start_game():
    """Begin a fresh session: return the first frame and newly seeded histories."""
    print("Starting a new game session...")

    # Fetch the seed frame from the dataset repo (with local fallback).
    first_frame_filename = "frames/frame_000000008.png"
    first_frame_path = download_asset(first_frame_filename, pred_config.dataset_repo_id, repo_type="dataset")

    if not first_frame_path:
        # No seed frame available — hand back a blank display and empty state.
        print("Could not load initial frame. Returning blank image.")
        return Image.new("RGB", (320, 240)), None, None

    pil_image = Image.open(first_frame_path).convert("RGB")

    # Encode the seed frame into VAE latent space, then replicate it to fill
    # the whole conditioning window; seed the action window with no-ops.
    seed_tensor = transform(pil_image).unsqueeze(0).to(device)
    seed_latent = engine.vae.encode(seed_tensor).latent_dist.sample()
    frame_history = deque([seed_latent] * model_config.history_len, maxlen=model_config.history_len)

    noop_vector = torch.tensor(action_map["noop"], dtype=torch.float32, device=device).unsqueeze(0)
    action_history = deque([noop_vector] * model_config.history_len, maxlen=model_config.history_len)

    print("Game session started.")
    return pil_image, frame_history, action_history


@torch.inference_mode()
def predict_step(action_name, frame_history, action_history):
    """Denoise one new frame conditioned on the chosen action and latent history.

    Args:
        action_name: Key into ``pred_config.action_map`` (e.g. ``"w"``, ``"noop"``).
        frame_history: deque of past frame latents, or ``None`` before a session starts.
        action_history: deque of past action tensors, or ``None`` before a session starts.

    Returns:
        Tuple of (PIL image of the predicted frame, updated frame_history,
        updated action_history).
    """
    if frame_history is None or action_history is None:
        # Session not started yet — show a blank frame and keep state empty.
        return Image.new("RGB", (320, 240)), None, None

    print(f"Received action: {action_name}")
    action_list = action_map.get(action_name)
    if action_list is None:
        # Bug fix: an unmapped action name used to crash in torch.tensor(None).
        # Degrade gracefully to a no-op action instead.
        print(f"Unknown action '{action_name}', falling back to noop.")
        action_list = action_map["noop"]
    action_tensor = torch.tensor(action_list, dtype=torch.float32, device=device).unsqueeze(0)

    # Conditioning: stack history latents along channels; embed the action to
    # the UNet's cross-attention width (with a singleton sequence dimension).
    history_latents = torch.cat(list(frame_history), dim=1)
    action_conditioning = action_encoder(action_tensor).unsqueeze(1)

    # Start from pure noise in latent space. VAE latents have 4 channels and
    # an 8x spatial downscale relative to pixel space.
    out_channels = 4
    current_latents = torch.randn(
        (1, out_channels, model_config.image_size[0] // 8, model_config.image_size[1] // 8),
        device=device
    )

    # Reverse-diffusion loop over the scheduler's timesteps: at each step the
    # UNet sees [noisy latent | history latents] concatenated on channels.
    for t in engine.scheduler.timesteps:
        model_input = torch.cat([current_latents, history_latents], dim=1)
        noise_pred = engine(model_input, t, action_conditioning)
        current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample

    # Rescale latents per the VAE's scaling convention before decoding.
    predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
    image_tensor = engine.vae.decode(predicted_latent_unscaled).sample

    # Roll the conditioning windows forward with the newly generated step.
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)

    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history

# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Tiny Engine Game")
    gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")

    # State variables to hold the session history between steps
    # (per-session, so concurrent users do not share histories).
    frame_history_state = gr.State(None)
    action_history_state = gr.State(None)
    
    with gr.Row():
        start_button = gr.Button("Start Game", variant="primary")
    
    with gr.Row():
        game_display = gr.Image(label="Game View", interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Controls")
            fwd_button = gr.Button("W (Forward)")
            s_button = gr.Button("S (Backward)")
            a_button = gr.Button("A (Left)")
            d_button = gr.Button("D (Right)")
            turn_l_button = gr.Button("ArrowLeft (Turn Left)")
            turn_r_button = gr.Button("ArrowRight (Turn Right)")
            attack_button = gr.Button("Space (Attack)")

    # --- Button Click Handlers ---
    # Start resets the display and both history states in one call.
    start_button.click(
        fn=start_game,
        inputs=[],
        outputs=[game_display, frame_history_state, action_history_state]
    )
    
    # Keep these two lists index-aligned: button i sends action_names[i].
    # The names must be keys of pred_config.action_map.
    action_buttons = [fwd_button, s_button, a_button, d_button, turn_l_button, turn_r_button, attack_button]
    action_names = ["w", "s", "a", "d", "ArrowLeft", "ArrowRight", " "]

    for button, name in zip(action_buttons, action_names):
        # gr.State(name) pins the action string per button, avoiding the
        # late-binding closure pitfall a lambda would have here.
        button.click(
            fn=predict_step,
            inputs=[gr.State(name), frame_history_state, action_history_state],
            outputs=[game_display, frame_history_state, action_history_state]
        )

if __name__ == "__main__":
    demo.launch()