Spaces:
Sleeping
Sleeping
Your Name committed on
Commit Β·
2ad4d00
1
Parent(s): 7af425f
Add Gradio application files
Browse files- app.py +202 -0
- requirements.txt +9 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/agent.py +87 -0
- src/config.py +41 -0
- src/model.py +65 -0
- src/tiny_engine.egg-info/PKG-INFO +79 -0
- src/tiny_engine.egg-info/SOURCES.txt +11 -0
- src/tiny_engine.egg-info/dependency_links.txt +1 -0
- src/tiny_engine.egg-info/requires.txt +16 -0
- src/tiny_engine.egg-info/top_level.txt +4 -0
- src/train.py +318 -0
app.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from collections import deque
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from src.model import GameNGen, ActionEncoder
|
| 11 |
+
from src.config import ModelConfig, PredictionConfig
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
|
| 15 |
+
# --- Configuration and Model Loading ---
|
| 16 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 17 |
+
model_config = ModelConfig()
|
| 18 |
+
pred_config = PredictionConfig()
|
| 19 |
+
|
| 20 |
+
print("Loading models...")
|
| 21 |
+
engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len).to(device)
|
| 22 |
+
cross_attention_dim = engine.unet.config.cross_attention_dim
|
| 23 |
+
action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
|
| 24 |
+
print("Models loaded.")
|
| 25 |
+
|
| 26 |
+
# --- Model Weight and Asset Downloading ---
|
| 27 |
+
output_dir = pred_config.output_dir
|
| 28 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
def download_asset(filename, repo_id, repo_type="model"):
    """Download an asset from the Hugging Face Hub, with a local fallback.

    Args:
        filename: Repo-relative path of the file (may contain subdirectories,
            e.g. "frames/frame_000000008.png").
        repo_id: Hub repository to download from.
        repo_type: Hub repo type, "model" or "dataset".

    Returns:
        A local filesystem path to the asset, or None if it could not be
        found on the Hub or in the local "gamelogs" fallback directory.
    """
    # hf_hub_download(local_dir=...) preserves the repo-relative sub-path
    # under local_dir, so the cached path must keep the full relative
    # filename (not just the basename) or nested assets are never found
    # on subsequent calls and the returned path would not exist.
    local_path = os.path.join(output_dir, filename)
    if not os.path.exists(local_path):
        print(f"Downloading {filename} from {repo_id}...")
        try:
            # Trust the path the Hub client returns instead of guessing it.
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=output_dir,
                repo_type=repo_type,
                local_dir_use_symlinks=False
            )
            print(f"Successfully downloaded {filename}.")
            return local_path
        except Exception as e:
            print(f"Error downloading {filename}: {e}")
            # Fall back to a copy shipped with the repo, if present.
            gamelogs_path = os.path.join("gamelogs", filename)
            if os.path.exists(gamelogs_path):
                print(f"Using local file from gamelogs: {gamelogs_path}")
                return gamelogs_path
            print(f"Asset {filename} not found on Hub or locally.")
            return None
    return local_path
|
| 54 |
+
|
| 55 |
+
# Load weights
|
| 56 |
+
# Load trained weights into the U-Net and action encoder at import time.
# Which U-Net asset to fetch depends on whether the model was fine-tuned
# with LoRA adapters or full weights (ModelConfig.use_lora).
print("Loading model weights...")
unet_path = download_asset("pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth", pred_config.model_repo_id)
if unet_path:
    if model_config.use_lora:
        # LoRA: load only the attention-processor adapter weights.
        state_dict = torch.load(unet_path, map_location=device)
        engine.unet.load_attn_procs(state_dict)
        print("LoRA weights loaded.")
    else:
        # Full fine-tune: replace the entire U-Net state dict.
        engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
        print("UNet weights loaded.")
else:
    # Deliberate best-effort: the app still runs with the untuned base U-Net.
    print("Warning: UNet weights not found. Using base UNet.")

action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
if action_encoder_path:
    action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
    print("Action Encoder weights loaded.")
else:
    # NOTE(review): with a randomly initialized action encoder the action
    # conditioning is meaningless — predictions will ignore player input.
    print("Warning: Action encoder weights not found.")

# Inference only: switch both modules to eval mode (disables dropout etc.).
engine.eval()
action_encoder.eval()
|
| 78 |
+
|
| 79 |
+
# --- Image Transformations & Helpers ---
|
| 80 |
+
transform = transforms.Compose([
|
| 81 |
+
transforms.Resize(model_config.image_size),
|
| 82 |
+
transforms.ToTensor(),
|
| 83 |
+
transforms.Normalize([0.5], [0.5])
|
| 84 |
+
])
|
| 85 |
+
|
| 86 |
+
action_map = pred_config.action_map
|
| 87 |
+
|
| 88 |
+
def tensor_to_pil(tensor):
    """Convert a decoded image tensor in [-1, 1] back to a PIL image."""
    image = tensor.squeeze(0).cpu()
    # Undo the Normalize([0.5], [0.5]) applied at encode time.
    image = (image * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(image)
|
| 91 |
+
|
| 92 |
+
# --- Core Logic for Gradio ---
|
| 93 |
+
@torch.inference_mode()
def start_game():
    """Initializes a new game session and returns the first frame and state.

    Returns a 3-tuple of (PIL image for display, frame-latent history deque,
    action history deque); the deques are stored in Gradio State components.
    On failure to fetch the seed frame, returns a blank image and None state.
    """
    print("Starting a new game session...")
    # Get initial frame — a known frame from the recorded-gameplay dataset
    # is used to seed the history.
    first_frame_filename = "frames/frame_000000008.png"
    first_frame_path = download_asset(first_frame_filename, pred_config.dataset_repo_id, repo_type="dataset")

    if not first_frame_path:
        # Return a black screen as a fallback; None state makes predict_step
        # a no-op until the player retries.
        print("Could not load initial frame. Returning blank image.")
        return Image.new("RGB", (320, 240)), None, None

    pil_image = Image.open(first_frame_path).convert("RGB")

    # Initialize histories: encode the seed frame to a VAE latent and
    # replicate it history_len times so the conditioning window is full.
    initial_frame_tensor = transform(pil_image).unsqueeze(0).to(device)
    initial_latent = engine.vae.encode(initial_frame_tensor).latent_dist.sample()
    # NOTE(review): the latent is not multiplied by vae.config.scaling_factor
    # here — confirm this matches the latent space used during training.

    frame_history = deque([initial_latent] * model_config.history_len, maxlen=model_config.history_len)

    # Start with "no-op" actions in the history.
    noop_action = torch.tensor(action_map["noop"], dtype=torch.float32, device=device).unsqueeze(0)
    action_history = deque([noop_action] * model_config.history_len, maxlen=model_config.history_len)

    print("Game session started.")
    return pil_image, frame_history, action_history
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@torch.inference_mode()
def predict_step(action_name, frame_history, action_history):
    """Predicts the next frame based on an action and the current state.

    Args:
        action_name: Key into pred_config.action_map ("w", "s", ..., " ").
        frame_history: deque of past frame latents (or None before start).
        action_history: deque of past action tensors (or None before start).

    Returns:
        (PIL image of the predicted frame, updated frame_history,
        updated action_history).
    """
    # Guard: game not started yet (or start failed) — return a blank frame.
    if frame_history is None or action_history is None:
        return Image.new("RGB", (320, 240)), None, None

    print(f"Received action: {action_name}")
    # NOTE(review): .get returns None for an unknown action name, which would
    # make torch.tensor raise below — confirm callers only pass mapped keys.
    action_list = action_map.get(action_name)
    action_tensor = torch.tensor(action_list, dtype=torch.float32, device=device).unsqueeze(0)

    # Inference: concatenate the history latents along the channel dim to
    # match the widened conv_in of the modified U-Net.
    history_latents = torch.cat(list(frame_history), dim=1)
    # Only the *current* action is encoded as cross-attention conditioning;
    # action_history is maintained but not consumed here.
    action_conditioning = action_encoder(action_tensor).unsqueeze(1)

    # Start from pure noise at 1/8 spatial resolution (SD VAE downsampling).
    out_channels = 4
    current_latents = torch.randn(
        (1, out_channels, model_config.image_size[0] // 8, model_config.image_size[1] // 8),
        device=device
    )

    # Standard reverse-diffusion loop: denoise step by step, re-attaching the
    # (fixed) history latents to the model input each iteration.
    for t in engine.scheduler.timesteps:
        model_input = torch.cat([current_latents, history_latents], dim=1)
        noise_pred = engine(model_input, t, action_conditioning)
        current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample

    # Rescale latents before VAE decoding.
    predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
    image_tensor = engine.vae.decode(predicted_latent_unscaled).sample

    # Update State: the predicted latent becomes the newest history entry
    # (deque maxlen drops the oldest automatically).
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)

    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history
|
| 157 |
+
|
| 158 |
+
# --- Gradio UI ---
|
| 159 |
+
# Build the Gradio UI: one display image, a start button, and one button per
# game action. Session state (history deques) lives in gr.State components.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Tiny Engine Game")
    gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")

    # State variables to hold the session history between steps
    frame_history_state = gr.State(None)
    action_history_state = gr.State(None)

    with gr.Row():
        start_button = gr.Button("Start Game", variant="primary")

    with gr.Row():
        game_display = gr.Image(label="Game View", interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Controls")
            fwd_button = gr.Button("W (Forward)")
            s_button = gr.Button("S (Backward)")
            a_button = gr.Button("A (Left)")
            d_button = gr.Button("D (Right)")
            turn_l_button = gr.Button("ArrowLeft (Turn Left)")
            turn_r_button = gr.Button("ArrowRight (Turn Right)")
            attack_button = gr.Button("Space (Attack)")

    # --- Button Click Handlers ---
    start_button.click(
        fn=start_game,
        inputs=[],
        outputs=[game_display, frame_history_state, action_history_state]
    )

    # Names must line up positionally with the buttons above; they are the
    # keys of pred_config.action_map (" " is the attack action).
    action_buttons = [fwd_button, s_button, a_button, d_button, turn_l_button, turn_r_button, attack_button]
    action_names = ["w", "s", "a", "d", "ArrowLeft", "ArrowRight", " "]

    for button, name in zip(action_buttons, action_names):
        # Each button passes its fixed action name via an inline gr.State
        # (captured per iteration, so there is no late-binding issue).
        button.click(
            fn=predict_step,
            inputs=[gr.State(name), frame_history_state, action_history_state],
            outputs=[game_display, frame_history_state, action_history_state]
        )

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
diffusers
|
| 5 |
+
transformers
|
| 6 |
+
huggingface_hub
|
| 7 |
+
Pillow
|
| 8 |
+
opencv-python-headless
|
| 9 |
+
accelerate
|
src/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
src/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (4.11 kB). View file
|
|
|
src/agent.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stable_baselines3 import PPO
|
| 2 |
+
from stable_baselines3.common.callbacks import BaseCallback
|
| 3 |
+
import os
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import logging
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
import csv
|
| 9 |
+
import gymnasium
|
| 10 |
+
from vizdoom import gymnasium_wrapper # This import is needed to register the env
|
| 11 |
+
|
| 12 |
+
DATASET_DIR = "gamelogs"
|
| 13 |
+
FRAMES_DIR = os.path.join(DATASET_DIR, "frames")
|
| 14 |
+
os.makedirs(FRAMES_DIR, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class NpEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy scalars and arrays.

    Converts np.integer -> int, np.floating -> float, and np.ndarray ->
    nested lists; everything else is deferred to the base encoder.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of *obj*."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # Not a NumPy type: let the base class raise TypeError as usual.
        return super().default(obj)
|
| 26 |
+
|
| 27 |
+
class GameNGenCallback(BaseCallback):
    """SB3 training callback that logs each observed frame and action.

    Saves every valid RGB frame to FRAMES_DIR as a PNG and appends a
    (frame_id, action) row to gamelogs/metadata.csv, building the dataset
    later consumed by NextFrameDataset.
    """

    def __init__(self, verbose: bool, save_path: str):
        """Open the metadata CSV under *save_path* and write its header."""
        super(GameNGenCallback, self).__init__(verbose)
        self.save_path = save_path
        # File handle stays open for the whole run; closed in _on_training_end.
        self.frame_log = open(os.path.join(self.save_path, "metadata.csv"), mode="w", newline="")
        self.csv_writer = csv.writer(self.frame_log)
        # CSV Header
        self.csv_writer.writerow(["frame_id", "action"])

    def _on_step(self) -> bool:
        """Persist the current frame/action; always returns True (continue)."""
        # n_calls is a monotonically increasing step counter from BaseCallback.
        frame_id = self.n_calls
        # Zero-padded key keeps lexicographic and numeric order aligned.
        key = f"{frame_id:09d}"

        try:
            obs_dict = self.locals["new_obs"]
            # The observation from the callback is in Channels-First format (C, H, W)
            # (index [0] selects the first env of the vectorized batch).
            frame_data = obs_dict['screen'][0]
            action = self.locals["actions"][0]

            # --- DEFINITIVE FIX ---
            # Check if the frame is in the expected Channels-First format (C, H, W).
            # A valid RGB image will have 3 channels in its first dimension.
            if frame_data.ndim == 3 and frame_data.shape[0] == 3:
                # Pillow's fromarray function needs the image in Channels-Last format (H, W, C).
                # We must transpose the axes from (C, H, W) to (H, W, C).
                transposed_frame = np.transpose(frame_data, (1, 2, 0))
                image = Image.fromarray(transposed_frame)
                image.save(os.path.join(FRAMES_DIR, f"frame_{key}.png"))

                # Record the action as JSON so NumPy scalars round-trip.
                json_action = json.dumps(action, cls=NpEncoder)
                self.csv_writer.writerow([key, json_action])
            else:
                # This will now correctly catch the junk frames from terminal states.
                logging.warning(f"Skipping corrupted frame {key} with invalid shape: {frame_data.shape}")

        except Exception as e:
            # This will now only catch truly unexpected errors.
            # Deliberately broad so one bad frame never aborts a long training run.
            logging.error(f"Could not process or save frame {key} due to an unexpected error: {e}")

        return True

    def _on_training_end(self) -> None:
        """Flush and close the metadata CSV when training finishes."""
        self.frame_log.close()
|
| 70 |
+
|
| 71 |
+
# --- Main script ---
# Trains a PPO agent on VizDoom while the callback records every frame and
# action into the gamelogs/ dataset directory.
logging.basicConfig(level=logging.INFO)

# Create the VizDoom environment. No wrappers are needed.
# (The vizdoom.gymnasium_wrapper import above registers this env id.)
env = gymnasium.make("VizdoomHealthGatheringSupreme-v0")

callback = GameNGenCallback(verbose=True, save_path=DATASET_DIR)

# MultiInputPolicy because the observation is a dict (includes "screen").
model = PPO(
    "MultiInputPolicy",
    env,
    verbose=1,
)

# 2M environment steps -> up to 2M logged frames for the diffusion dataset.
model.learn(total_timesteps=2_000_000, callback=callback)

env.close()
|
src/config.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
@dataclass
class ModelConfig:
    """Parameters defining the model architecture and basic properties."""
    # Base Stable Diffusion checkpoint supplying the VAE/U-Net/scheduler.
    model_id: str = "CompVis/stable-diffusion-v1-4"
    # Frame size as (height, width); app.py derives latent size as size // 8.
    image_size: Tuple[int, int] = (240, 320)
    # Number of denoising steps configured on the scheduler.
    num_timesteps: int = 100
    # How many past frames/actions condition the next-frame prediction.
    history_len: int = 4
    # Length of the one-hot action vector (see PredictionConfig.action_map).
    num_actions: int = 7
    # Load/train LoRA adapter weights instead of full U-Net weights.
    use_lora: bool = True
|
| 13 |
+
|
| 14 |
+
@dataclass
class TrainingConfig:
    """Parameters specific to the training process."""
    repo_id: str = "RevanthGundala/tiny_engine" # Dataset repository
    learning_rate: float = 1e-4
    # Fraction of the recorded frames actually used (0.01 == 1%).
    subset_percentage: float = 0.01
    batch_size: int = 16
    num_epochs: int = 1
    lora_rank: int = 4 # Only used if ModelConfig.use_lora is True
    lora_alpha: int = 4 # Only used if ModelConfig.use_lora is True
|
| 24 |
+
|
| 25 |
+
@dataclass
class PredictionConfig:
    """Parameters for the prediction server (app.py)."""
    model_repo_id: str = "RevanthGundala/tiny_engine" # For model weights
    dataset_repo_id: str = "RevanthGundala/tiny_engine" # For starting frame video
    prediction_epoch: int = 99
    output_dir: str = "output" # To load weights if not using MLflow
    # Maps a UI key name to its one-hot action vector; vector length must
    # equal ModelConfig.num_actions. field(default_factory=...) avoids the
    # shared-mutable-default pitfall for dataclass dict fields.
    action_map: Dict[str, List[int]] = field(default_factory=lambda: {
        "w": [1, 0, 0, 0, 0, 0, 0],  # MOVE_FORWARD
        "s": [0, 1, 0, 0, 0, 0, 0],  # MOVE_BACKWARD
        "d": [0, 0, 1, 0, 0, 0, 0],  # MOVE_RIGHT
        "a": [0, 0, 0, 1, 0, 0, 0],  # MOVE_LEFT
        "ArrowLeft": [0, 0, 0, 0, 1, 0, 0],  # TURN_LEFT
        "ArrowRight": [0, 0, 0, 0, 0, 1, 0],  # TURN_RIGHT
        " ": [0, 0, 0, 0, 0, 0, 1],  # ATTACK
        "noop": [0, 0, 0, 0, 0, 0, 0],  # No operation
    })
|
src/model.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
|
| 5 |
+
|
| 6 |
+
class GameNGen(nn.Module):
    """Next-frame world model built from Stable Diffusion components.

    Loads a pretrained VAE, U-Net and DDPM scheduler, then widens the
    U-Net's first convolution so it can take the noisy latent concatenated
    (channel-wise) with `history_len` past frame latents.
    """

    def __init__(self, model_id: str, timesteps: int, history_len: int):
        """Load pretrained parts from *model_id* and adapt conv_in for history."""
        super().__init__()
        self.model_id = model_id
        self.history_len = history_len
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        self.scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.scheduler.set_timesteps(timesteps)

        # Modify the U-Net to accept history
        original_in_channels = self.unet.config.in_channels # Should be 4
        # One slot for the noisy latent plus one per history frame.
        new_in_channels = original_in_channels * (1 + self.history_len)

        original_conv_in = self.unet.conv_in

        # Replace conv_in with a wider copy (same kernel/stride/padding).
        self.unet.conv_in = nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=original_conv_in.out_channels,
            kernel_size=original_conv_in.kernel_size,
            stride=original_conv_in.stride,
            padding=original_conv_in.padding,
        )

        # Initialize the new weights
        with torch.no_grad():
            # Copy original weights for the main noisy latent
            self.unet.conv_in.weight[:, :original_in_channels, :, :] = original_conv_in.weight
            # Zero-initialize weights for the history latents
            # (so the adapted model initially behaves like the pretrained one).
            self.unet.conv_in.weight[:, original_in_channels:, :, :].zero_()
            # Copy bias
            self.unet.conv_in.bias = original_conv_in.bias

        # Update the model's config
        self.unet.config.in_channels = new_in_channels

        # not training so freeze
        self.vae.requires_grad_(False)

    def forward(self, noisy_latents: torch.Tensor, timesteps: int, conditioning: torch.Tensor) -> torch.Tensor:
        """Predict noise for *noisy_latents* (already concatenated with history).

        NOTE(review): `timesteps` is annotated int but app.py passes scheduler
        timestep entries (0-dim tensors); the U-Net accepts both.
        """
        noise_pred = self.unet(
            sample=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=conditioning
        ).sample

        return noise_pred
|
| 54 |
+
|
| 55 |
+
class ActionEncoder(nn.Module):
    """Embed a one-hot action vector into the U-Net cross-attention space."""

    def __init__(self, num_actions: int, cross_attention_dim: int):
        """Build a two-layer MLP mapping num_actions -> cross_attention_dim."""
        super().__init__()
        layers = [
            nn.Linear(num_actions, cross_attention_dim),
            nn.SiLU(inplace=True),
            nn.Linear(cross_attention_dim, cross_attention_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the conditioning embedding for an action batch *x*."""
        return self.encoder(x)
|
src/tiny_engine.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: tiny-engine
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Author-email: Revanth Gundala <revanth.gundala@gmail.com>
|
| 5 |
+
Requires-Python: <3.13,>=3.10
|
| 6 |
+
Description-Content-Type: text/markdown
|
| 7 |
+
Requires-Dist: torch@ https://download.pytorch.org/whl/cpu/torch-2.3.1%2Bcpu-cp312-cp312-linux_x86_64.whl
|
| 8 |
+
Requires-Dist: torchvision@ https://download.pytorch.org/whl/cpu/torchvision-0.18.1%2Bcpu-cp312-cp312-linux_x86_64.whl
|
| 9 |
+
Requires-Dist: vizdoom<2.0.0,>=1.2.3
|
| 10 |
+
Requires-Dist: pandas<3.0.0,>=2.2.0
|
| 11 |
+
Requires-Dist: opencv-python>=4.8.0
|
| 12 |
+
Requires-Dist: pillow<11.0.0,>=10.3.0
|
| 13 |
+
Requires-Dist: diffusers<0.28.0,>=0.27.2
|
| 14 |
+
Requires-Dist: stable-baselines3[extra]<3.0.0,>=2.3.0
|
| 15 |
+
Requires-Dist: transformers<5.0.0,>=4.40.0
|
| 16 |
+
Requires-Dist: accelerate<0.30.0,>=0.29.0
|
| 17 |
+
Requires-Dist: tqdm<5.0.0,>=4.66.0
|
| 18 |
+
Requires-Dist: peft<0.11.0,>=0.10.0
|
| 19 |
+
Requires-Dist: huggingface-hub<0.23.0,>=0.22.0
|
| 20 |
+
Requires-Dist: fastapi>=0.111.0
|
| 21 |
+
Requires-Dist: uvicorn[standard]>=0.29.0
|
| 22 |
+
Requires-Dist: python-multipart>=0.0.9
|
| 23 |
+
|
| 24 |
+
# Tiny Engine
|
| 25 |
+
|
| 26 |
+
This project uses a generative model to predict the next frame of a game based on the current frame and a player's action. It's served via a FastAPI backend and includes an interactive Next.js frontend.
|
| 27 |
+
|
| 28 |
+
## Setup and Installation
|
| 29 |
+
|
| 30 |
+
This project uses `uv` for Python package management.
|
| 31 |
+
|
| 32 |
+
1. **Install `uv`**:
|
| 33 |
+
If you don't have `uv` installed, follow the official installation instructions:
|
| 34 |
+
```bash
|
| 35 |
+
# For macOS and Linux:
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
2. **Create a Virtual Environment**:
|
| 40 |
+
```bash
|
| 41 |
+
uv venv
|
| 42 |
+
```
|
| 43 |
+
This will create a `.venv` directory in your project folder.
|
| 44 |
+
|
| 45 |
+
3. **Activate the Virtual Environment**:
|
| 46 |
+
```bash
|
| 47 |
+
# For macOS and Linux:
|
| 48 |
+
source .venv/bin/activate
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
4. **Install Python Dependencies**:
|
| 52 |
+
Install the required packages, including PyTorch from its specific download source.
|
| 53 |
+
```bash
|
| 54 |
+
uv pip install --find-links https://download.pytorch.org/whl/cpu -e .
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Running the Application
|
| 58 |
+
|
| 59 |
+
You need to run the backend and frontend servers in two separate terminals.
|
| 60 |
+
|
| 61 |
+
**1. Start the Backend Server**:
|
| 62 |
+
|
| 63 |
+
Make sure your virtual environment is activated.
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
uv run python app.py
|
| 67 |
+
```
|
| 68 |
+
The backend will be available at `http://localhost:8000`.
|
| 69 |
+
|
| 70 |
+
**2. Start the Frontend Server**:
|
| 71 |
+
|
| 72 |
+
In a new terminal, navigate to the `frontend` directory.
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
cd frontend
|
| 76 |
+
npm install
|
| 77 |
+
npm run dev
|
| 78 |
+
```
|
| 79 |
+
The frontend will be available at `http://localhost:3000`. You can now open this URL in your browser to play the game.
|
src/tiny_engine.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
src/agent.py
|
| 4 |
+
src/config.py
|
| 5 |
+
src/model.py
|
| 6 |
+
src/train.py
|
| 7 |
+
src/tiny_engine.egg-info/PKG-INFO
|
| 8 |
+
src/tiny_engine.egg-info/SOURCES.txt
|
| 9 |
+
src/tiny_engine.egg-info/dependency_links.txt
|
| 10 |
+
src/tiny_engine.egg-info/requires.txt
|
| 11 |
+
src/tiny_engine.egg-info/top_level.txt
|
src/tiny_engine.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/tiny_engine.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch@ https://download.pytorch.org/whl/cpu/torch-2.3.1%2Bcpu-cp312-cp312-linux_x86_64.whl
|
| 2 |
+
torchvision@ https://download.pytorch.org/whl/cpu/torchvision-0.18.1%2Bcpu-cp312-cp312-linux_x86_64.whl
|
| 3 |
+
vizdoom<2.0.0,>=1.2.3
|
| 4 |
+
pandas<3.0.0,>=2.2.0
|
| 5 |
+
opencv-python>=4.8.0
|
| 6 |
+
pillow<11.0.0,>=10.3.0
|
| 7 |
+
diffusers<0.28.0,>=0.27.2
|
| 8 |
+
stable-baselines3[extra]<3.0.0,>=2.3.0
|
| 9 |
+
transformers<5.0.0,>=4.40.0
|
| 10 |
+
accelerate<0.30.0,>=0.29.0
|
| 11 |
+
tqdm<5.0.0,>=4.66.0
|
| 12 |
+
peft<0.11.0,>=0.10.0
|
| 13 |
+
huggingface-hub<0.23.0,>=0.22.0
|
| 14 |
+
fastapi>=0.111.0
|
| 15 |
+
uvicorn[standard]>=0.29.0
|
| 16 |
+
python-multipart>=0.0.9
|
src/tiny_engine.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent
|
| 2 |
+
config
|
| 3 |
+
model
|
| 4 |
+
train
|
src/train.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import tqdm
|
| 2 |
+
from model import GameNGen, ActionEncoder
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import DataLoader, Dataset
|
| 6 |
+
from config import ModelConfig, TrainingConfig
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import os
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from diffusers.optimization import get_cosine_schedule_with_warmup
|
| 15 |
+
from accelerate import Accelerator
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
from peft import LoraConfig
|
| 18 |
+
import mlflow
|
| 19 |
+
import argparse
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class NextFrameDataset(Dataset):
    """Dataset of (frame history, action, next frame) triples for training.

    Reads PNG frames plus the metadata.csv produced by GameNGenCallback and
    yields, for each target frame, the `history_len` preceding frames, the
    one-hot action that led to it, and the target frame itself.
    """

    def __init__(self, num_actions: int, metadata_path: str, frames_dir: str, image_size: tuple, history_len: int, subset_percentage: float):
        """Index the frame files and optionally keep only a leading subset.

        Args:
            num_actions: Length of the one-hot action vector.
            metadata_path: CSV with per-frame "action" column (JSON-encoded).
            frames_dir: Directory of frame_XXXXXXXXX.png files.
            image_size: (height, width) frames are resized to.
            history_len: Number of context frames per sample.
            subset_percentage: Fraction (0-1] of frames to use.
        """
        self.metadata = pd.read_csv(metadata_path)
        self.frames_dir = frames_dir
        # List files and filter out non-image files if necessary
        # (sorted numerically by the frame id embedded in the filename).
        self.frame_files = sorted(
            [f for f in os.listdir(frames_dir) if f.endswith('.png')],
            key=lambda x: int(x.split('_')[1].split('.')[0])
        )
        # Calculate the number of frames to use based on the percentage
        num_to_use = int(len(self.frame_files) * subset_percentage)
        self.frame_files = self.frame_files[:num_to_use]
        # Keep metadata aligned with the truncated frame list.
        self.metadata = self.metadata.iloc[:num_to_use]
        print(f"Using a {subset_percentage*100}% subset of the data: {len(self.frame_files)} frames.")
        self.num_actions = num_actions
        self.total_frames = len(self.frame_files)
        self.history_len = history_len

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]) # Normalize VAE to [-1, 1]
        ])

    def __len__(self) -> int:
        # We can't use the first `history_len` frames as they don't have enough history
        return min(len(self.metadata), self.total_frames) - self.history_len - 1

    def __getitem__(self, idx: int) -> dict:
        """Return dict with keys "frame_history", "action", "next_frame".

        Raises:
            IndexError: if any required frame file is missing on disk.
        """
        # We are getting the item at `idx` in our shortened dataset.
        # The actual index in the video/metadata is `idx + self.history_len`.
        actual_idx = idx + self.history_len

        # Load the history_len frames immediately preceding the target.
        history_frames = []
        for i in range(self.history_len):
            frame_idx = actual_idx - self.history_len + i
            # Use the sorted file list to get the correct frame
            img_path = os.path.join(self.frames_dir, self.frame_files[frame_idx])
            try:
                pil_image = Image.open(img_path).convert("RGB")
            except FileNotFoundError:
                raise IndexError(f"Could not read history frame {frame_idx} from {img_path}.")
            history_frames.append(self.transform(pil_image))

        # Shape: (history_len, C, H, W).
        history_tensor = torch.stack(history_frames)

        # Get the target frame (next_frame)
        next_frame_img_path = os.path.join(self.frames_dir, self.frame_files[actual_idx])
        try:
            next_pil_image = Image.open(next_frame_img_path).convert("RGB")
        except FileNotFoundError:
            raise IndexError(f"Could not read frame {actual_idx} from {next_frame_img_path}.")
        next_image = self.transform(next_pil_image)

        # Get the action that led to the `next_frame`
        # (stored as a JSON value; may be a scalar or single-element list).
        action_row = self.metadata.iloc[actual_idx]
        action_data = json.loads(str(action_row['action']))
        action_int = int(action_data[0] if isinstance(action_data, list) else action_data)
        # One-hot encode the discrete action id.
        curr_action = torch.zeros(self.num_actions)
        curr_action[action_int] = 1.0

        return {
            "frame_history": history_tensor,
            "action": curr_action,
            "next_frame": next_image
        }
|
| 88 |
+
|
| 89 |
+
def train():
    """Fine-tune the GameNGen diffusion model on next-frame prediction.

    Reads frame/metadata paths from the CLI, sets up MLflow tracking
    (Azure ML or local), builds the dataset/model/optimizer, runs the
    epoch loop with per-epoch checkpointing, and finally logs the trained
    weights to MLflow (falling back to local files on failure).
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    parser = argparse.ArgumentParser(description="GameNGen Finetuning")
    parser.add_argument("--metadata_input", type=str, required=True, help="Path to the metadata CSV file")
    parser.add_argument("--frames_input", type=str, required=True, help="Path to the frames directory")
    parser.add_argument("--experiment_name", type=str, default="GameNGen Finetuning", help="Name of the MLflow experiment.")
    args = parser.parse_args()

    # --- MLflow Integration ---
    # Check for Azure ML environment.
    # The v1 SDK may set AZUREML_MLFLOW_URI, while v2 sets MLFLOW_TRACKING_URI.
    is_azureml_env = "AZUREML_MLFLOW_URI" in os.environ or \
        ("MLFLOW_TRACKING_URI" in os.environ and "azureml" in os.environ["MLFLOW_TRACKING_URI"])

    if is_azureml_env:
        # In Azure ML, MLflow is configured automatically by environment variables.
        # We don't need to set the tracking URI or experiment name.
        logging.info("β MLflow using Azure ML environment configuration.")
    else:
        # For local runs, explicitly set up a local tracking URI and experiment.
        # This will save runs to a local 'mlruns' directory.
        mlflow.set_tracking_uri("file:./mlruns")
        mlflow.set_experiment(args.experiment_name)
        logging.info(f"β οΈ Using local MLflow tracking (./mlruns) for experiment '{args.experiment_name}'.")

    # --- Setup ---
    # fp16 mixed precision; accumulation of 1 means every batch is an
    # optimizer step.
    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1
    )
    model_config = ModelConfig()
    train_config = TrainingConfig()

    # Define file paths using the config
    metadata_path = args.metadata_input
    frames_dir = args.frames_input

    engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len)

    # --- Memory Saving Optimizations ---
    # Gradient checkpointing trades compute for activation memory.
    engine.unet.enable_gradient_checkpointing()
    # try:
    #     engine.unet.enable_xformers_memory_efficient_attention()
    #     logging.info("xformers memory-efficient attention enabled.")
    # except ImportError:
    #     logging.warning("xformers is not installed. For better memory efficiency, run: pip install xformers")

    dataset = NextFrameDataset(model_config.num_actions, metadata_path, frames_dir, model_config.image_size, history_len=model_config.history_len, subset_percentage=train_config.subset_percentage)
    # num_workers=0: frames are decoded in the main process.
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=train_config.batch_size,
        shuffle=True,
        num_workers=0
    )

    # The action embedding must match the UNet's cross-attention width so it
    # can be fed as encoder_hidden_states-style conditioning.
    cross_attention_dim = engine.unet.config.cross_attention_dim
    action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim)

    if model_config.use_lora:
        # Freeze the base UNet; only LoRA adapters (attention projections)
        # and the action encoder are trained.
        engine.unet.requires_grad_(False)
        lora_config = LoraConfig(
            r=train_config.lora_rank,
            lora_alpha=train_config.lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            lora_dropout=0.1,
            bias="lora_only",
        )
        engine.unet.add_adapter(lora_config)
        lora_layers = filter(lambda p: p.requires_grad, engine.unet.parameters())
        params_to_train = list(lora_layers) + list(action_encoder.parameters())
    else:
        params_to_train = list(engine.unet.parameters()) + list(action_encoder.parameters())

    optim = torch.optim.AdamW(params=params_to_train, lr=train_config.learning_rate)

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optim, num_warmup_steps=500, num_training_steps=len(dataloader) * train_config.num_epochs
    )
    engine, action_encoder, optim, dataloader, lr_scheduler = accelerator.prepare(
        engine, action_encoder, optim, dataloader, lr_scheduler
    )

    mlflow.autolog(log_models=False)

    # --- Add an output directory for checkpoints ---
    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)

    logging.info("Starting training loop...")

    mlflow.log_params({
        "learning_rate": train_config.learning_rate,
        "batch_size": train_config.batch_size,
        "num_epochs": train_config.num_epochs,
        "use_lora": model_config.use_lora,
        "lora_rank": train_config.lora_rank if model_config.use_lora else None,
        "subset_percentage": train_config.subset_percentage
    })

    global_step = 0
    for epoch in range(train_config.num_epochs):
        progress_bar = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for batch in dataloader:
            optim.zero_grad()
            next_frames, actions, frame_history = batch["next_frame"], batch["action"], batch["frame_history"]

            # Encode into latent space. VAE encoding is frozen, so no grads.
            with torch.no_grad():
                vae = accelerator.unwrap_model(engine).vae
                latent_dist = vae.encode(next_frames).latent_dist
                clean_latents = latent_dist.sample() * vae.config.scaling_factor

                # Encode history frames: flatten (batch, history) into one
                # VAE batch, then fold the history back into the channel dim.
                # NOTE(review): history latents are NOT multiplied by
                # vae.config.scaling_factor, unlike clean_latents — confirm
                # this asymmetry is intended.
                bs, hist_len, C, H, W = frame_history.shape
                frame_history = frame_history.view(bs * hist_len, C, H, W)
                history_latents = vae.encode(frame_history).latent_dist.sample()
                _, latent_C, latent_H, latent_W = history_latents.shape
                history_latents = history_latents.reshape(bs, hist_len * latent_C, latent_H, latent_W)

                # Add noise to history latents to prevent drift (noise augmentation)
                noise_level = 0.1  # Start with a small, fixed amount of noise
                history_noise = torch.randn_like(history_latents) * noise_level
                corrupted_history_latents = history_latents + history_noise

            # Conditioning is now only the action (computed with grad so the
            # action encoder trains).
            action_conditioning = action_encoder(actions)
            conditioning_batch = action_conditioning.unsqueeze(1)

            # create random noise
            noise = torch.randn_like(clean_latents)

            # pick random timestep. High timstep means more noise
            timesteps = torch.randint(0, engine.scheduler.config.num_train_timesteps, (clean_latents.shape[0], ), device=clean_latents.device).long()

            noisy_latents = engine.scheduler.add_noise(clean_latents, noise, timesteps)

            # Concatenate history latents with noisy latents along channels;
            # the UNet predicts the noise added to the target latents.
            model_input = torch.cat([noisy_latents, corrupted_history_latents], dim=1)

            with accelerator.accumulate(engine):
                noise_pred = engine(model_input, timesteps, conditioning_batch)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # NOTE(review): only UNet params are clipped here; the
                # action_encoder's gradients are left unclipped — confirm
                # whether clipping params_to_train was intended instead.
                accelerator.clip_grad_norm_(engine.unet.parameters(), 1.0)
                optim.step()
                lr_scheduler.step()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}

            # Log metrics to MLflow
            if global_step % 10 == 0:  # Log every 10 steps to avoid too much overhead
                mlflow.log_metric("loss", logs["loss"], step=global_step)
                mlflow.log_metric("learning_rate", logs["lr"], step=global_step)

            progress_bar.set_postfix(**logs)
            global_step += 1

        progress_bar.close()

        # Only rank 0 writes checkpoints to avoid clobbering in multi-GPU runs.
        if accelerator.is_main_process:
            logging.info(f"Epoch {epoch} complete. Saving checkpoint...")

            # Define a unique directory for this epoch's checkpoint
            checkpoint_dir = os.path.join(output_dir, f"checkpoint_epoch_{epoch}")

            # Use accelerator.save_state to save everything
            accelerator.save_state(checkpoint_dir)

            logging.info(f"Checkpoint saved to {checkpoint_dir}")

    # Save models at the end of training
    if accelerator.is_main_process:
        unwrapped_unet = accelerator.unwrap_model(engine).unet
        unwrapped_action_encoder = accelerator.unwrap_model(action_encoder)

        try:
            # Log the action encoder
            mlflow.pytorch.log_model(unwrapped_action_encoder, "action_encoder")
            logging.info("β Action encoder logged to MLflow")

            # Log the UNet (or its LoRA weights)
            if model_config.use_lora:
                from peft import get_peft_model_state_dict
                import json

                lora_save_path = "unet_lora_weights"
                os.makedirs(lora_save_path, exist_ok=True)

                # Save LoRA weights using PEFT method
                lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
                torch.save(lora_state_dict, os.path.join(lora_save_path, "pytorch_lora_weights.bin"))

                # Save adapter config (default=str stringifies non-JSON
                # values in the PEFT config).
                adapter_config = unwrapped_unet.peft_config
                with open(os.path.join(lora_save_path, "adapter_config.json"), "w") as f:
                    json.dump(adapter_config, f, indent=2, default=str)

                mlflow.log_artifacts(lora_save_path, artifact_path="unet_lora")
                logging.info("β LoRA weights logged to MLflow")
            else:
                mlflow.pytorch.log_model(unwrapped_unet, "unet")
                logging.info("β UNet logged to MLflow")

            logging.info(f"β Training completed. MLflow Run ID: {mlflow.active_run().info.run_id}")

        except Exception as e:
            # Best-effort fallback: if MLflow logging fails (e.g. offline),
            # persist everything to local disk instead of crashing.
            logging.error(f"β Error logging models to MLflow: {e}")
            # Save models locally as fallback
            torch.save(unwrapped_action_encoder.state_dict(), os.path.join(output_dir, "action_encoder.pth"))
            if model_config.use_lora:
                try:
                    from peft import get_peft_model_state_dict
                    lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
                    torch.save(lora_state_dict, os.path.join(output_dir, "lora_weights.bin"))
                    logging.info("π LoRA weights saved locally")
                except Exception as lora_e:
                    logging.error(f"β Error saving LoRA weights: {lora_e}")
                    torch.save(unwrapped_unet.state_dict(), os.path.join(output_dir, "unet_full.pth"))
                    logging.info("π Full UNet saved locally as fallback")
            else:
                torch.save(unwrapped_unet.state_dict(), os.path.join(output_dir, "unet.pth"))
                logging.info("π Models saved locally as fallback")
|
| 316 |
+
|
| 317 |
+
# Standard script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    train()
|