Your Name committed on
Commit
2ad4d00
Β·
1 Parent(s): 7af425f

Add Gradio application files

Browse files
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ from collections import deque
6
+ import base64
7
+ import io
8
+ import os
9
+
10
+ from src.model import GameNGen, ActionEncoder
11
+ from src.config import ModelConfig, PredictionConfig
12
+ from huggingface_hub import hf_hub_download
13
+ from torchvision import transforms
14
+
15
+ # --- Configuration and Model Loading ---
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model_config = ModelConfig()
18
+ pred_config = PredictionConfig()
19
+
20
+ print("Loading models...")
21
+ engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len).to(device)
22
+ cross_attention_dim = engine.unet.config.cross_attention_dim
23
+ action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim).to(device)
24
+ print("Models loaded.")
25
+
26
+ # --- Model Weight and Asset Downloading ---
27
+ output_dir = pred_config.output_dir
28
+ os.makedirs(output_dir, exist_ok=True)
29
+
30
def download_asset(filename, repo_id, repo_type="model"):
    """Download an asset from the Hugging Face Hub, with a local fallback.

    Args:
        filename: Path of the file inside the repo (may contain subdirs,
            e.g. "frames/frame_000000008.png").
        repo_id: Hub repository id to download from.
        repo_type: "model" or "dataset".

    Returns:
        A local filesystem path to the asset, or None if it could not be
        found on the Hub or in the local ``gamelogs`` fallback directory.
    """
    # hf_hub_download(local_dir=...) preserves the repo-relative path, so the
    # cache check must keep the full relative filename (not just basename).
    local_path = os.path.join(output_dir, filename)
    if os.path.exists(local_path):
        return local_path
    print(f"Downloading {filename} from {repo_id}...")
    try:
        # Use the path hf_hub_download reports rather than reconstructing it.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=output_dir,
            repo_type=repo_type,
            local_dir_use_symlinks=False
        )
        print(f"Successfully downloaded {filename}.")
        return downloaded_path
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Fall back to a locally recorded copy under gamelogs/.
        gamelogs_path = os.path.join("gamelogs", filename)
        if os.path.exists(gamelogs_path):
            print(f"Using local file from gamelogs: {gamelogs_path}")
            return gamelogs_path
        print(f"Asset {filename} not found on Hub or locally.")
        return None
54
+
55
+ # Load weights
56
+ print("Loading model weights...")
57
+ unet_path = download_asset("pytorch_lora_weights.bin" if model_config.use_lora else "unet.pth", pred_config.model_repo_id)
58
+ if unet_path:
59
+ if model_config.use_lora:
60
+ state_dict = torch.load(unet_path, map_location=device)
61
+ engine.unet.load_attn_procs(state_dict)
62
+ print("LoRA weights loaded.")
63
+ else:
64
+ engine.unet.load_state_dict(torch.load(unet_path, map_location=device))
65
+ print("UNet weights loaded.")
66
+ else:
67
+ print("Warning: UNet weights not found. Using base UNet.")
68
+
69
+ action_encoder_path = download_asset("action_encoder.pth", pred_config.model_repo_id)
70
+ if action_encoder_path:
71
+ action_encoder.load_state_dict(torch.load(action_encoder_path, map_location=device))
72
+ print("Action Encoder weights loaded.")
73
+ else:
74
+ print("Warning: Action encoder weights not found.")
75
+
76
+ engine.eval()
77
+ action_encoder.eval()
78
+
79
+ # --- Image Transformations & Helpers ---
80
+ transform = transforms.Compose([
81
+ transforms.Resize(model_config.image_size),
82
+ transforms.ToTensor(),
83
+ transforms.Normalize([0.5], [0.5])
84
+ ])
85
+
86
+ action_map = pred_config.action_map
87
+
88
def tensor_to_pil(tensor):
    """Convert a [-1, 1] image tensor of shape (1, C, H, W) to a PIL image."""
    # Undo Normalize([0.5], [0.5]) — i.e. x/2 + 0.5 — then clamp into [0, 1].
    denormalized = (tensor.squeeze(0).cpu() / 2 + 0.5).clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(denormalized)
91
+
92
+ # --- Core Logic for Gradio ---
93
+ @torch.inference_mode()
94
+ def start_game():
95
+ """Initializes a new game session and returns the first frame and state."""
96
+ print("Starting a new game session...")
97
+ # Get initial frame
98
+ first_frame_filename = "frames/frame_000000008.png"
99
+ first_frame_path = download_asset(first_frame_filename, pred_config.dataset_repo_id, repo_type="dataset")
100
+
101
+ if not first_frame_path:
102
+ # Return a black screen as a fallback
103
+ print("Could not load initial frame. Returning blank image.")
104
+ return Image.new("RGB", (320, 240)), None, None
105
+
106
+ pil_image = Image.open(first_frame_path).convert("RGB")
107
+
108
+ # Initialize histories
109
+ initial_frame_tensor = transform(pil_image).unsqueeze(0).to(device)
110
+ initial_latent = engine.vae.encode(initial_frame_tensor).latent_dist.sample()
111
+
112
+ frame_history = deque([initial_latent] * model_config.history_len, maxlen=model_config.history_len)
113
+
114
+ noop_action = torch.tensor(action_map["noop"], dtype=torch.float32, device=device).unsqueeze(0)
115
+ action_history = deque([noop_action] * model_config.history_len, maxlen=model_config.history_len)
116
+
117
+ print("Game session started.")
118
+ return pil_image, frame_history, action_history
119
+
120
+
121
@torch.inference_mode()
def predict_step(action_name, frame_history, action_history):
    """Predicts the next frame based on an action and the current state.

    Args:
        action_name: Key into ``action_map`` ("w", "s", ...). Unknown names
            now fall back to the "noop" action instead of crashing.
        frame_history: deque of past-frame latent tensors, or None before
            the game has been started.
        action_history: deque of past action tensors, or None.

    Returns:
        Tuple of (pil_image, frame_history, action_history); a blank image
        and None state if the session has not been started yet.
    """
    if frame_history is None or action_history is None:
        # "Start Game" has not been pressed yet — nothing to condition on.
        return Image.new("RGB", (320, 240)), None, None

    print(f"Received action: {action_name}")
    # Bug fix: a bare .get() returns None for unknown keys, which makes
    # torch.tensor() raise. Default to the no-op action vector instead.
    action_list = action_map.get(action_name, action_map["noop"])
    action_tensor = torch.tensor(action_list, dtype=torch.float32, device=device).unsqueeze(0)

    # Inference: concatenate the history latents channel-wise and embed the action.
    history_latents = torch.cat(list(frame_history), dim=1)
    action_conditioning = action_encoder(action_tensor).unsqueeze(1)

    # Start from pure noise in latent space (the VAE downsamples by 8x).
    out_channels = 4
    current_latents = torch.randn(
        (1, out_channels, model_config.image_size[0] // 8, model_config.image_size[1] // 8),
        device=device
    )

    # Standard DDPM denoising loop over the scheduler's configured timesteps.
    for t in engine.scheduler.timesteps:
        model_input = torch.cat([current_latents, history_latents], dim=1)
        noise_pred = engine(model_input, t, action_conditioning)
        current_latents = engine.scheduler.step(noise_pred, t, current_latents).prev_sample

    predicted_latent_unscaled = current_latents / engine.vae.config.scaling_factor
    image_tensor = engine.vae.decode(predicted_latent_unscaled).sample

    # Update state; the deques have maxlen=history_len, so the oldest entry drops.
    frame_history.append(predicted_latent_unscaled)
    action_history.append(action_tensor)

    # Convert to PIL for display
    pil_image = tensor_to_pil(image_tensor)
    print("Prediction complete.")
    return pil_image, frame_history, action_history
157
+
158
+ # --- Gradio UI ---
159
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
160
+ gr.Markdown("# Tiny Engine Game")
161
+ gr.Markdown("Press 'Start Game' and then use the controls to generate the next frame.")
162
+
163
+ # State variables to hold the session history between steps
164
+ frame_history_state = gr.State(None)
165
+ action_history_state = gr.State(None)
166
+
167
+ with gr.Row():
168
+ start_button = gr.Button("Start Game", variant="primary")
169
+
170
+ with gr.Row():
171
+ game_display = gr.Image(label="Game View", interactive=False)
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ gr.Markdown("### Controls")
176
+ fwd_button = gr.Button("W (Forward)")
177
+ s_button = gr.Button("S (Backward)")
178
+ a_button = gr.Button("A (Left)")
179
+ d_button = gr.Button("D (Right)")
180
+ turn_l_button = gr.Button("ArrowLeft (Turn Left)")
181
+ turn_r_button = gr.Button("ArrowRight (Turn Right)")
182
+ attack_button = gr.Button("Space (Attack)")
183
+
184
+ # --- Button Click Handlers ---
185
+ start_button.click(
186
+ fn=start_game,
187
+ inputs=[],
188
+ outputs=[game_display, frame_history_state, action_history_state]
189
+ )
190
+
191
+ action_buttons = [fwd_button, s_button, a_button, d_button, turn_l_button, turn_r_button, attack_button]
192
+ action_names = ["w", "s", "a", "d", "ArrowLeft", "ArrowRight", " "]
193
+
194
+ for button, name in zip(action_buttons, action_names):
195
+ button.click(
196
+ fn=predict_step,
197
+ inputs=[gr.State(name), frame_history_state, action_history_state],
198
+ outputs=[game_display, frame_history_state, action_history_state]
199
+ )
200
+
201
+ if __name__ == "__main__":
202
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ diffusers
5
+ transformers
6
+ huggingface_hub
7
+ Pillow
8
+ opencv-python-headless
9
+ accelerate
src/__pycache__/config.cpython-312.pyc ADDED
Binary file (2.67 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (4.11 kB). View file
 
src/agent.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stable_baselines3 import PPO
2
+ from stable_baselines3.common.callbacks import BaseCallback
3
+ import os
4
+ from PIL import Image
5
+ import logging
6
+ import json
7
+ import numpy as np
8
+ import csv
9
+ import gymnasium
10
+ from vizdoom import gymnasium_wrapper # This import is needed to register the env
11
+
12
# Output locations for the captured gameplay dataset.
DATASET_DIR = "gamelogs"
FRAMES_DIR = os.path.join(DATASET_DIR, "frames")
os.makedirs(FRAMES_DIR, exist_ok=True)
15
+
16
+
17
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native types."""

    def default(self, obj):
        # Map each NumPy type onto its built-in equivalent before encoding.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # Anything else falls through to the base class (raises TypeError).
        return super().default(obj)
26
+
27
class GameNGenCallback(BaseCallback):
    """Stable-Baselines3 callback that logs every (frame, action) pair.

    Frames are saved as PNGs under FRAMES_DIR and each step's action is
    appended to gamelogs/metadata.csv, keyed by the zero-padded step index.
    """

    def __init__(self, verbose: bool, save_path: str):
        super(GameNGenCallback, self).__init__(verbose)
        self.save_path = save_path
        # Kept open for the whole training run; closed in _on_training_end.
        self.frame_log = open(os.path.join(self.save_path, "metadata.csv"), mode="w", newline="")
        self.csv_writer = csv.writer(self.frame_log)
        # CSV Header
        self.csv_writer.writerow(["frame_id", "action"])

    def _on_step(self) -> bool:
        # Zero-padded step counter doubles as the frame id / filename key.
        frame_id = self.n_calls
        key = f"{frame_id:09d}"

        try:
            obs_dict = self.locals["new_obs"]
            # The observation from the callback is in Channels-First format (C, H, W)
            frame_data = obs_dict['screen'][0]
            action = self.locals["actions"][0]

            # --- DEFINITIVE FIX ---
            # Check if the frame is in the expected Channels-First format (C, H, W).
            # A valid RGB image will have 3 channels in its first dimension.
            if frame_data.ndim == 3 and frame_data.shape[0] == 3:
                # Pillow's fromarray function needs the image in Channels-Last format (H, W, C).
                # We must transpose the axes from (C, H, W) to (H, W, C).
                transposed_frame = np.transpose(frame_data, (1, 2, 0))
                image = Image.fromarray(transposed_frame)
                image.save(os.path.join(FRAMES_DIR, f"frame_{key}.png"))

                # NpEncoder serializes the numpy action array/scalar as JSON.
                json_action = json.dumps(action, cls=NpEncoder)
                self.csv_writer.writerow([key, json_action])
            else:
                # This will now correctly catch the junk frames from terminal states.
                logging.warning(f"Skipping corrupted frame {key} with invalid shape: {frame_data.shape}")

        except Exception as e:
            # This will now only catch truly unexpected errors.
            logging.error(f"Could not process or save frame {key} due to an unexpected error: {e}")

        # Returning True tells SB3 to continue training.
        return True

    def _on_training_end(self) -> None:
        # Flush and release the metadata CSV handle.
        self.frame_log.close()
70
+
71
+ # --- Main script ---
72
+ logging.basicConfig(level=logging.INFO)
73
+
74
+ # Create the VizDoom environment. No wrappers are needed.
75
+ env = gymnasium.make("VizdoomHealthGatheringSupreme-v0")
76
+
77
+ callback = GameNGenCallback(verbose=True, save_path=DATASET_DIR)
78
+
79
+ model = PPO(
80
+ "MultiInputPolicy",
81
+ env,
82
+ verbose=1,
83
+ )
84
+
85
+ model.learn(total_timesteps=2_000_000, callback=callback)
86
+
87
+ env.close()
src/config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Tuple
3
+
4
@dataclass
class ModelConfig:
    """Parameters defining the model architecture and basic properties."""
    model_id: str = "CompVis/stable-diffusion-v1-4"  # Base diffusion checkpoint on the Hub
    image_size: Tuple[int, int] = (240, 320)  # (height, width) of game frames
    num_timesteps: int = 100  # Denoising steps used at inference
    history_len: int = 4  # Number of past frames fed as conditioning
    num_actions: int = 7  # Length of the one-hot action vector
    use_lora: bool = True  # Train/load LoRA adapters instead of the full UNet
13
+
14
@dataclass
class TrainingConfig:
    """Parameters specific to the training process."""
    repo_id: str = "RevanthGundala/tiny_engine"  # Dataset repository
    learning_rate: float = 1e-4  # AdamW peak LR (cosine schedule with warmup)
    subset_percentage: float = 0.01  # Fraction of recorded frames to train on
    batch_size: int = 16
    num_epochs: int = 1
    lora_rank: int = 4  # Only used if ModelConfig.use_lora is True
    lora_alpha: int = 4  # Only used if ModelConfig.use_lora is True
24
+
25
@dataclass
class PredictionConfig:
    """Parameters for the prediction server (app.py)."""
    model_repo_id: str = "RevanthGundala/tiny_engine"  # For model weights
    dataset_repo_id: str = "RevanthGundala/tiny_engine"  # For starting frame video
    prediction_epoch: int = 99  # Checkpoint epoch to load, where applicable
    output_dir: str = "output"  # To load weights if not using MLflow
    # One-hot action vectors keyed by UI key name; the index layout must match
    # the action encoding used during training (length == ModelConfig.num_actions).
    action_map: Dict[str, List[int]] = field(default_factory=lambda: {
        "w": [1, 0, 0, 0, 0, 0, 0],  # MOVE_FORWARD
        "s": [0, 1, 0, 0, 0, 0, 0],  # MOVE_BACKWARD
        "d": [0, 0, 1, 0, 0, 0, 0],  # MOVE_RIGHT
        "a": [0, 0, 0, 1, 0, 0, 0],  # MOVE_LEFT
        "ArrowLeft": [0, 0, 0, 0, 1, 0, 0],  # TURN_LEFT
        "ArrowRight": [0, 0, 0, 0, 0, 1, 0],  # TURN_RIGHT
        " ": [0, 0, 0, 0, 0, 0, 1],  # ATTACK
        "noop": [0, 0, 0, 0, 0, 0, 0],  # No operation
    })
src/model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
5
+
6
class GameNGen(nn.Module):
    """Latent-diffusion next-frame model conditioned on past frames.

    Wraps a pretrained Stable Diffusion VAE/UNet/scheduler and widens the
    UNet's first convolution so its input is the noisy latent concatenated
    channel-wise with ``history_len`` past-frame latents.
    """

    def __init__(self, model_id: str, timesteps: int, history_len: int):
        super().__init__()
        self.model_id = model_id
        self.history_len = history_len
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        self.scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.scheduler.set_timesteps(timesteps)

        # Modify the U-Net to accept history
        original_in_channels = self.unet.config.in_channels  # Should be 4
        new_in_channels = original_in_channels * (1 + self.history_len)

        original_conv_in = self.unet.conv_in

        # Replace conv_in with a wider one; geometry (kernel/stride/padding)
        # is copied from the original layer so spatial shapes are unchanged.
        self.unet.conv_in = nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=original_conv_in.out_channels,
            kernel_size=original_conv_in.kernel_size,
            stride=original_conv_in.stride,
            padding=original_conv_in.padding,
        )

        # Initialize the new weights
        with torch.no_grad():
            # Copy original weights for the main noisy latent
            self.unet.conv_in.weight[:, :original_in_channels, :, :] = original_conv_in.weight
            # Zero-initialize weights for the history latents, so the model
            # initially behaves exactly like the pretrained UNet.
            self.unet.conv_in.weight[:, original_in_channels:, :, :].zero_()
            # Copy bias
            self.unet.conv_in.bias = original_conv_in.bias

        # Update the model's config
        self.unet.config.in_channels = new_in_channels

        # not training so freeze
        self.vae.requires_grad_(False)

    def forward(self, noisy_latents: torch.Tensor, timesteps: int, conditioning: torch.Tensor) -> torch.Tensor:
        """Predict the noise residual for ``noisy_latents`` at ``timesteps``.

        ``timesteps`` may be an int or a per-sample tensor (both are accepted
        by the diffusers UNet); ``conditioning`` is routed through the UNet's
        cross-attention layers (here: the encoded action embedding).
        """
        noise_pred = self.unet(
            sample=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=conditioning
        ).sample

        return noise_pred
54
+
55
class ActionEncoder(nn.Module):
    """Projects a one-hot action vector into the UNet cross-attention space."""

    def __init__(self, num_actions: int, cross_attention_dim: int):
        super().__init__()
        # Keep the exact Sequential layout (keys "encoder.0" / "encoder.2")
        # so previously saved state_dicts load without remapping.
        self.encoder = nn.Sequential(
            nn.Linear(num_actions, cross_attention_dim),
            nn.SiLU(inplace=True),
            nn.Linear(cross_attention_dim, cross_attention_dim),
        )

    def forward(self, x):
        """Return the conditioning embedding for action tensor ``x``."""
        embedding = self.encoder(x)
        return embedding
src/tiny_engine.egg-info/PKG-INFO ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: tiny-engine
3
+ Version: 0.1.0
4
+ Author-email: Revanth Gundala <revanth.gundala@gmail.com>
5
+ Requires-Python: <3.13,>=3.10
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: torch@ https://download.pytorch.org/whl/cpu/torch-2.3.1%2Bcpu-cp312-cp312-linux_x86_64.whl
8
+ Requires-Dist: torchvision@ https://download.pytorch.org/whl/cpu/torchvision-0.18.1%2Bcpu-cp312-cp312-linux_x86_64.whl
9
+ Requires-Dist: vizdoom<2.0.0,>=1.2.3
10
+ Requires-Dist: pandas<3.0.0,>=2.2.0
11
+ Requires-Dist: opencv-python>=4.8.0
12
+ Requires-Dist: pillow<11.0.0,>=10.3.0
13
+ Requires-Dist: diffusers<0.28.0,>=0.27.2
14
+ Requires-Dist: stable-baselines3[extra]<3.0.0,>=2.3.0
15
+ Requires-Dist: transformers<5.0.0,>=4.40.0
16
+ Requires-Dist: accelerate<0.30.0,>=0.29.0
17
+ Requires-Dist: tqdm<5.0.0,>=4.66.0
18
+ Requires-Dist: peft<0.11.0,>=0.10.0
19
+ Requires-Dist: huggingface-hub<0.23.0,>=0.22.0
20
+ Requires-Dist: fastapi>=0.111.0
21
+ Requires-Dist: uvicorn[standard]>=0.29.0
22
+ Requires-Dist: python-multipart>=0.0.9
23
+
24
+ # Tiny Engine
25
+
26
+ This project uses a generative model to predict the next frame of a game based on the current frame and a player's action. It's served via a FastAPI backend and includes an interactive Next.js frontend.
27
+
28
+ ## Setup and Installation
29
+
30
+ This project uses `uv` for Python package management.
31
+
32
+ 1. **Install `uv`**:
33
+ If you don't have `uv` installed, follow the official installation instructions:
34
+ ```bash
35
+ # For macOS and Linux:
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh
37
+ ```
38
+
39
+ 2. **Create a Virtual Environment**:
40
+ ```bash
41
+ uv venv
42
+ ```
43
+ This will create a `.venv` directory in your project folder.
44
+
45
+ 3. **Activate the Virtual Environment**:
46
+ ```bash
47
+ # For macOS and Linux:
48
+ source .venv/bin/activate
49
+ ```
50
+
51
+ 4. **Install Python Dependencies**:
52
+ Install the required packages, including PyTorch from its specific download source.
53
+ ```bash
54
+ uv pip install --find-links https://download.pytorch.org/whl/cpu -e .
55
+ ```
56
+
57
+ ## Running the Application
58
+
59
+ You need to run the backend and frontend servers in two separate terminals.
60
+
61
+ **1. Start the Backend Server**:
62
+
63
+ Make sure your virtual environment is activated.
64
+
65
+ ```bash
66
+ uv run python app.py
67
+ ```
68
+ The backend will be available at `http://localhost:8000`.
69
+
70
+ **2. Start the Frontend Server**:
71
+
72
+ In a new terminal, navigate to the `frontend` directory.
73
+
74
+ ```bash
75
+ cd frontend
76
+ npm install
77
+ npm run dev
78
+ ```
79
+ The frontend will be available at `http://localhost:3000`. You can now open this URL in your browser to play the game.
src/tiny_engine.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ src/agent.py
4
+ src/config.py
5
+ src/model.py
6
+ src/train.py
7
+ src/tiny_engine.egg-info/PKG-INFO
8
+ src/tiny_engine.egg-info/SOURCES.txt
9
+ src/tiny_engine.egg-info/dependency_links.txt
10
+ src/tiny_engine.egg-info/requires.txt
11
+ src/tiny_engine.egg-info/top_level.txt
src/tiny_engine.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/tiny_engine.egg-info/requires.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch@ https://download.pytorch.org/whl/cpu/torch-2.3.1%2Bcpu-cp312-cp312-linux_x86_64.whl
2
+ torchvision@ https://download.pytorch.org/whl/cpu/torchvision-0.18.1%2Bcpu-cp312-cp312-linux_x86_64.whl
3
+ vizdoom<2.0.0,>=1.2.3
4
+ pandas<3.0.0,>=2.2.0
5
+ opencv-python>=4.8.0
6
+ pillow<11.0.0,>=10.3.0
7
+ diffusers<0.28.0,>=0.27.2
8
+ stable-baselines3[extra]<3.0.0,>=2.3.0
9
+ transformers<5.0.0,>=4.40.0
10
+ accelerate<0.30.0,>=0.29.0
11
+ tqdm<5.0.0,>=4.66.0
12
+ peft<0.11.0,>=0.10.0
13
+ huggingface-hub<0.23.0,>=0.22.0
14
+ fastapi>=0.111.0
15
+ uvicorn[standard]>=0.29.0
16
+ python-multipart>=0.0.9
src/tiny_engine.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ agent
2
+ config
3
+ model
4
+ train
src/train.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from model import GameNGen, ActionEncoder
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from config import ModelConfig, TrainingConfig
7
+ import pandas as pd
8
+ from torchvision import transforms
9
+ import os
10
+ from PIL import Image
11
+ import json
12
+ import logging
13
+ import torch.nn.functional as F
14
+ from diffusers.optimization import get_cosine_schedule_with_warmup
15
+ from accelerate import Accelerator
16
+ from huggingface_hub import hf_hub_download
17
+ from peft import LoraConfig
18
+ import mlflow
19
+ import argparse
20
+
21
+
22
class NextFrameDataset(Dataset):
    """(frame history, action, next frame) samples from recorded gameplay.

    Frames are PNGs named ``frame_<idx>.png`` in ``frames_dir``; the action
    taken at each step lives in the metadata CSV's ``action`` column
    (JSON-encoded, row index aligned with the sorted frame files).
    """

    def __init__(self, num_actions: int, metadata_path: str, frames_dir: str, image_size: tuple, history_len: int, subset_percentage: float):
        self.metadata = pd.read_csv(metadata_path)
        self.frames_dir = frames_dir
        # List files and filter out non-image files if necessary;
        # sort numerically by the index embedded in "frame_<idx>.png".
        self.frame_files = sorted(
            [f for f in os.listdir(frames_dir) if f.endswith('.png')],
            key=lambda x: int(x.split('_')[1].split('.')[0])
        )
        # Calculate the number of frames to use based on the percentage
        num_to_use = int(len(self.frame_files) * subset_percentage)
        self.frame_files = self.frame_files[:num_to_use]
        self.metadata = self.metadata.iloc[:num_to_use]
        print(f"Using a {subset_percentage*100}% subset of the data: {len(self.frame_files)} frames.")
        self.num_actions = num_actions
        self.total_frames = len(self.frame_files)
        self.history_len = history_len

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # Normalize VAE to [-1, 1]
        ])

    def __len__(self) -> int:
        # We can't use the first `history_len` frames as they don't have enough history
        return min(len(self.metadata), self.total_frames) - self.history_len - 1

    def __getitem__(self, idx: int) -> dict:
        """Return a dict with keys "frame_history", "action", "next_frame"."""
        # We are getting the item at `idx` in our shortened dataset.
        # The actual index in the video/metadata is `idx + self.history_len`.
        actual_idx = idx + self.history_len

        # Load the `history_len` frames immediately preceding the target.
        history_frames = []
        for i in range(self.history_len):
            frame_idx = actual_idx - self.history_len + i
            # Use the sorted file list to get the correct frame
            img_path = os.path.join(self.frames_dir, self.frame_files[frame_idx])
            try:
                pil_image = Image.open(img_path).convert("RGB")
            except FileNotFoundError:
                raise IndexError(f"Could not read history frame {frame_idx} from {img_path}.")
            history_frames.append(self.transform(pil_image))

        # Shape: (history_len, C, H, W)
        history_tensor = torch.stack(history_frames)

        # Get the target frame (next_frame)
        next_frame_img_path = os.path.join(self.frames_dir, self.frame_files[actual_idx])
        try:
            next_pil_image = Image.open(next_frame_img_path).convert("RGB")
        except FileNotFoundError:
            raise IndexError(f"Could not read frame {actual_idx} from {next_frame_img_path}.")
        next_image = self.transform(next_pil_image)

        # Get the action that led to the `next_frame`
        action_row = self.metadata.iloc[actual_idx]
        action_data = json.loads(str(action_row['action']))
        # The CSV may store either a bare int or a one-element list.
        action_int = int(action_data[0] if isinstance(action_data, list) else action_data)
        # One-hot encode the discrete action id.
        curr_action = torch.zeros(self.num_actions)
        curr_action[action_int] = 1.0

        return {
            "frame_history": history_tensor,
            "action": curr_action,
            "next_frame": next_image
        }
88
+
89
def train():
    """Finetune the GameNGen UNet (optionally via LoRA) plus the action encoder.

    Reads frame/metadata paths from the CLI, trains with Accelerate (fp16),
    logs metrics to MLflow (Azure ML or local ./mlruns), checkpoints each
    epoch, and logs/saves the final weights.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    parser = argparse.ArgumentParser(description="GameNGen Finetuning")
    parser.add_argument("--metadata_input", type=str, required=True, help="Path to the metadata CSV file")
    parser.add_argument("--frames_input", type=str, required=True, help="Path to the frames directory")
    parser.add_argument("--experiment_name", type=str, default="GameNGen Finetuning", help="Name of the MLflow experiment.")
    args = parser.parse_args()

    # --- MLflow Integration ---
    # Check for Azure ML environment.
    # The v1 SDK may set AZUREML_MLFLOW_URI, while v2 sets MLFLOW_TRACKING_URI.
    is_azureml_env = "AZUREML_MLFLOW_URI" in os.environ or \
        ("MLFLOW_TRACKING_URI" in os.environ and "azureml" in os.environ["MLFLOW_TRACKING_URI"])

    if is_azureml_env:
        # In Azure ML, MLflow is configured automatically by environment variables.
        # We don't need to set the tracking URI or experiment name.
        logging.info("βœ… MLflow using Azure ML environment configuration.")
    else:
        # For local runs, explicitly set up a local tracking URI and experiment.
        # This will save runs to a local 'mlruns' directory.
        mlflow.set_tracking_uri("file:./mlruns")
        mlflow.set_experiment(args.experiment_name)
        logging.info(f"⚠️ Using local MLflow tracking (./mlruns) for experiment '{args.experiment_name}'.")

    # --- Setup ---
    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1
    )
    model_config = ModelConfig()
    train_config = TrainingConfig()

    # Define file paths using the config
    metadata_path = args.metadata_input
    frames_dir = args.frames_input

    engine = GameNGen(model_config.model_id, model_config.num_timesteps, history_len=model_config.history_len)

    # --- Memory Saving Optimizations ---
    engine.unet.enable_gradient_checkpointing()
    # try:
    #     engine.unet.enable_xformers_memory_efficient_attention()
    #     logging.info("xformers memory-efficient attention enabled.")
    # except ImportError:
    #     logging.warning("xformers is not installed. For better memory efficiency, run: pip install xformers")

    dataset = NextFrameDataset(model_config.num_actions, metadata_path, frames_dir, model_config.image_size, history_len=model_config.history_len, subset_percentage=train_config.subset_percentage)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=train_config.batch_size,
        shuffle=True,
        num_workers=0
    )

    # Action embedding width must match the UNet's cross-attention width.
    cross_attention_dim = engine.unet.config.cross_attention_dim
    action_encoder = ActionEncoder(model_config.num_actions, cross_attention_dim)

    if model_config.use_lora:
        # Freeze the UNet and train only the injected LoRA adapters
        # (plus the action encoder).
        engine.unet.requires_grad_(False)
        lora_config = LoraConfig(
            r=train_config.lora_rank,
            lora_alpha=train_config.lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            lora_dropout=0.1,
            bias="lora_only",
        )
        engine.unet.add_adapter(lora_config)
        lora_layers = filter(lambda p: p.requires_grad, engine.unet.parameters())
        params_to_train = list(lora_layers) + list(action_encoder.parameters())
    else:
        params_to_train = list(engine.unet.parameters()) + list(action_encoder.parameters())

    optim = torch.optim.AdamW(params=params_to_train, lr=train_config.learning_rate)

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optim, num_warmup_steps=500, num_training_steps=len(dataloader) * train_config.num_epochs
    )
    engine, action_encoder, optim, dataloader, lr_scheduler = accelerator.prepare(
        engine, action_encoder, optim, dataloader, lr_scheduler
    )

    mlflow.autolog(log_models=False)

    # --- Add an output directory for checkpoints ---
    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)

    logging.info("Starting training loop...")

    mlflow.log_params({
        "learning_rate": train_config.learning_rate,
        "batch_size": train_config.batch_size,
        "num_epochs": train_config.num_epochs,
        "use_lora": model_config.use_lora,
        "lora_rank": train_config.lora_rank if model_config.use_lora else None,
        "subset_percentage": train_config.subset_percentage
    })

    global_step = 0
    for epoch in range(train_config.num_epochs):
        progress_bar = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for batch in dataloader:
            optim.zero_grad()
            next_frames, actions, frame_history = batch["next_frame"], batch["action"], batch["frame_history"]

            # Encode into latent space
            with torch.no_grad():
                vae = accelerator.unwrap_model(engine).vae
                latent_dist = vae.encode(next_frames).latent_dist
                clean_latents = latent_dist.sample() * vae.config.scaling_factor

                # Encode history frames: fold (batch, history) into one batch
                # dim for the VAE, then stack history along channels.
                bs, hist_len, C, H, W = frame_history.shape
                frame_history = frame_history.view(bs * hist_len, C, H, W)
                history_latents = vae.encode(frame_history).latent_dist.sample()
                _, latent_C, latent_H, latent_W = history_latents.shape
                history_latents = history_latents.reshape(bs, hist_len * latent_C, latent_H, latent_W)

            # Add noise to history latents to prevent drift (noise augmentation)
            noise_level = 0.1  # Start with a small, fixed amount of noise
            history_noise = torch.randn_like(history_latents) * noise_level
            corrupted_history_latents = history_latents + history_noise

            # Conditioning is now only the action
            action_conditioning = action_encoder(actions)
            conditioning_batch = action_conditioning.unsqueeze(1)

            # create random noise
            noise = torch.randn_like(clean_latents)

            # pick random timestep. High timestep means more noise
            timesteps = torch.randint(0, engine.scheduler.config.num_train_timesteps, (clean_latents.shape[0], ), device=clean_latents.device).long()

            noisy_latents = engine.scheduler.add_noise(clean_latents, noise, timesteps)

            # Concatenate history latents with noisy latents
            model_input = torch.cat([noisy_latents, corrupted_history_latents], dim=1)

            with accelerator.accumulate(engine):
                # Standard epsilon-prediction objective: regress the added noise.
                noise_pred = engine(model_input, timesteps, conditioning_batch)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(engine.unet.parameters(), 1.0)
                optim.step()
                lr_scheduler.step()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}

            # Log metrics to MLflow
            if global_step % 10 == 0:  # Log every 10 steps to avoid too much overhead
                mlflow.log_metric("loss", logs["loss"], step=global_step)
                mlflow.log_metric("learning_rate", logs["lr"], step=global_step)

            progress_bar.set_postfix(**logs)
            global_step += 1

        progress_bar.close()

        if accelerator.is_main_process:
            logging.info(f"Epoch {epoch} complete. Saving checkpoint...")

            # Define a unique directory for this epoch's checkpoint
            checkpoint_dir = os.path.join(output_dir, f"checkpoint_epoch_{epoch}")

            # Use accelerator.save_state to save everything
            accelerator.save_state(checkpoint_dir)

            logging.info(f"Checkpoint saved to {checkpoint_dir}")

    # Save models at the end of training
    if accelerator.is_main_process:
        unwrapped_unet = accelerator.unwrap_model(engine).unet
        unwrapped_action_encoder = accelerator.unwrap_model(action_encoder)

        try:
            # Log the action encoder
            mlflow.pytorch.log_model(unwrapped_action_encoder, "action_encoder")
            logging.info("βœ… Action encoder logged to MLflow")

            # Log the UNet (or its LoRA weights)
            if model_config.use_lora:
                from peft import get_peft_model_state_dict
                import json

                lora_save_path = "unet_lora_weights"
                os.makedirs(lora_save_path, exist_ok=True)

                # Save LoRA weights using PEFT method
                lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
                torch.save(lora_state_dict, os.path.join(lora_save_path, "pytorch_lora_weights.bin"))

                # Save adapter config
                adapter_config = unwrapped_unet.peft_config
                with open(os.path.join(lora_save_path, "adapter_config.json"), "w") as f:
                    json.dump(adapter_config, f, indent=2, default=str)

                mlflow.log_artifacts(lora_save_path, artifact_path="unet_lora")
                logging.info("βœ… LoRA weights logged to MLflow")
            else:
                mlflow.pytorch.log_model(unwrapped_unet, "unet")
                logging.info("βœ… UNet logged to MLflow")

            logging.info(f"βœ… Training completed. MLflow Run ID: {mlflow.active_run().info.run_id}")

        except Exception as e:
            logging.error(f"❌ Error logging models to MLflow: {e}")
            # Save models locally as fallback
            torch.save(unwrapped_action_encoder.state_dict(), os.path.join(output_dir, "action_encoder.pth"))
            if model_config.use_lora:
                try:
                    from peft import get_peft_model_state_dict
                    lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
                    torch.save(lora_state_dict, os.path.join(output_dir, "lora_weights.bin"))
                    logging.info("πŸ“ LoRA weights saved locally")
                except Exception as lora_e:
                    logging.error(f"❌ Error saving LoRA weights: {lora_e}")
                    torch.save(unwrapped_unet.state_dict(), os.path.join(output_dir, "unet_full.pth"))
                    logging.info("πŸ“ Full UNet saved locally as fallback")
            else:
                torch.save(unwrapped_unet.state_dict(), os.path.join(output_dir, "unet.pth"))
            logging.info("πŸ“ Models saved locally as fallback")

if __name__ == "__main__":
    train()