|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers import LongT5ForConditionalGeneration, T5ForConditionalGeneration, T5Tokenizer |
|
|
from accelerate import Accelerator |
|
|
from accelerate.utils import set_seed |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
import yaml |
|
|
from tqdm import tqdm |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
import argparse |
|
|
import os |
|
|
import re |
|
|
import warnings |
|
|
from collections import defaultdict |
|
|
import time |
|
|
from datetime import datetime |
|
|
import sys |
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
# Directory containing this script; its parent is treated as the project root
# and is put at the front of ``sys.path`` (once) so the project-local imports
# that follow can be resolved.
SCRIPT_DIR = Path(__file__).resolve().parent
WAVEGEN_ROOT = SCRIPT_DIR.parent
if not any(entry == str(WAVEGEN_ROOT) for entry in sys.path):
    sys.path.insert(0, str(WAVEGEN_ROOT))

# Suppress warnings whose message starts with the text below.
warnings.filterwarnings(
    "ignore",
    message="Passing a tuple of `past_key_values` is deprecated",
)
|
|
|
|
|
from data.movi_dataset import create_dataloader |
|
|
from utils.save_generation_results import save_generation_results |
|
|
|
|
|
|
|
|
class Text2WaveModel(nn.Module):
    """Text to Superquadric Wave Parameters Model.

    Wraps a pretrained T5 / LongT5 model. The text is encoded once per call;
    the T5 decoder is then run one step per target frame on a single input
    embedding (time embedding + history embedding + optional noise) and three
    linear heads map the resulting hidden state to:

    * ``objects``: ``(batch, max_objects, object_param_dim)``
    * ``world``:   ``(batch, 8)``
    * ``physics``: ``(batch, max_objects, 3)``
    """

    def __init__(
        self,
        model_name: str = "google/long-t5-tglobal-base",
        max_objects: int = 10,
        num_frames: int = 24,
        max_history_frames: int = 3,
        random_history_sampling: bool = True,
        decoder_noise_std: float = 0.0,
    ):
        """Load the pretrained backbone and create the projection heads.

        Args:
            model_name: Hugging Face checkpoint name; a LongT5 class is used
                when the name contains ``"long-t5"``, a T5 class otherwise.
            max_objects: Number of object slots predicted per frame.
            num_frames: Stored on the instance; not otherwise used in this
                class.
            max_history_frames: Maximum number of past frames fed back as
                conditioning.
            random_history_sampling: When true, ``forward`` draws the number
                of history frames to use at random for every target frame.
            decoder_noise_std: Standard deviation used by
                ``sample_decoder_noise``; ``<= 0`` disables noise.
        """
        super().__init__()

        self.max_objects = max_objects
        self.num_frames = num_frames
        self.max_history_frames = max_history_frames
        self.random_history_sampling = random_history_sampling
        self.decoder_noise_std = float(decoder_noise_std)

        # Number of features predicted for every object slot.
        self.object_param_dim = 15

        self.model_name = model_name
        self.is_longt5 = "long-t5" in model_name.lower()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        if self.is_longt5:
            self.t5_model = LongT5ForConditionalGeneration.from_pretrained(model_name)
        else:
            self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)

        # Keep the embedding table in sync with the tokenizer's reported size.
        if self.tokenizer.vocab_size != self.t5_model.config.vocab_size:
            self.t5_model.resize_token_embeddings(self.tokenizer.vocab_size)

        self.hidden_size = self.t5_model.config.d_model

        # Output heads (creation order is kept stable for checkpoints / RNG).
        self.object_proj = nn.Linear(self.hidden_size, max_objects * self.object_param_dim)
        self.world_proj = nn.Linear(self.hidden_size, 8)
        self.physics_proj = nn.Linear(self.hidden_size, max_objects * 3)

        # Scalar relative time -> decoder input embedding.
        self.time_embed = nn.Linear(1, self.hidden_size)

        # Flattened history (objects + world per frame) plus the physics state.
        history_feature_dim = max_history_frames * (max_objects * self.object_param_dim + 8) + max_objects * 3
        self.history_feature_dim = history_feature_dim
        self.history_proj = nn.Linear(history_feature_dim, self.hidden_size)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stability (N(0, 0.02) weights, zero biases)."""
        for module in (
            self.object_proj,
            self.world_proj,
            self.physics_proj,
            self.time_embed,
            self.history_proj,
        ):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            nn.init.zeros_(module.bias)

    def _trim_history(self, history_buffer: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
        """Return at most the ``max_history_frames`` most recent entries."""
        keep = self.max_history_frames
        # ``buffer[-0:]`` is the whole list, so a zero budget needs its own branch.
        return history_buffer[-keep:] if keep > 0 else []

    def _initialize_history_state(
        self,
        history_frames: Optional[Dict[str, torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Tuple[List[Dict[str, torch.Tensor]], torch.Tensor]:
        """Prepare history buffer and physics state for autoregressive decoding.

        Args:
            history_frames: Optional dict with ``'objects'`` (indexed as
                ``[batch, frame, object, feature]``), ``'world'`` (indexed as
                ``[batch, frame, feature]``) and ``'physics'``. Missing keys
                are tolerated.
            batch_size: Batch size used for the zero-filled defaults.
            device: Device on which all returned tensors live.

        Returns:
            ``(history_buffer, physics_state)``. The buffer holds per-frame
            dicts with ``'objects'`` and ``'world'`` (oldest first, at most
            ``max_history_frames`` entries); the physics state is
            ``history_frames['physics']`` as float32, or zeros of shape
            ``(batch_size, max_objects, 3)``.
        """
        history_buffer: List[Dict[str, torch.Tensor]] = []
        physics_state = torch.zeros(
            batch_size,
            self.max_objects,
            3,
            device=device,
            dtype=torch.float32,
        )

        if history_frames is not None:
            objects_hist = history_frames.get('objects')
            world_hist = history_frames.get('world')
            physics_hist = history_frames.get('physics')

            if physics_hist is not None:
                physics_state = physics_hist.to(device=device, dtype=torch.float32)

            if objects_hist is not None and world_hist is not None:
                objects_hist = objects_hist[..., :self.object_param_dim].to(device=device, dtype=torch.float32)
                world_hist = world_hist[..., :8].to(device=device, dtype=torch.float32)
                # Zero-pad narrower object features so the flattened layout
                # expected by ``_build_history_embedding`` always holds.
                missing = self.object_param_dim - objects_hist.shape[-1]
                if missing > 0:
                    objects_hist = F.pad(objects_hist, (0, missing))
                for idx in range(objects_hist.shape[1]):
                    history_buffer.append({
                        'objects': objects_hist[:, idx],
                        'world': world_hist[:, idx],
                    })

        if not history_buffer:
            # No usable history: start from a single all-zero frame.
            history_buffer.append({
                'objects': torch.zeros(batch_size, self.max_objects, self.object_param_dim, device=device),
                'world': torch.zeros(batch_size, 8, device=device),
            })

        return self._trim_history(history_buffer), physics_state

    def sample_decoder_noise(self, batch_size: int, device: torch.device) -> Optional[torch.Tensor]:
        """Sample decoder noise embedding when noise std > 0.

        Returns ``None`` when ``decoder_noise_std <= 0``; otherwise a
        ``(batch_size, hidden_size)`` Gaussian sample scaled by the std.
        """
        if self.decoder_noise_std <= 0:
            return None
        noise = torch.randn(batch_size, self.hidden_size, device=device)
        return noise * self.decoder_noise_std

    def _build_history_embedding(
        self,
        history_buffer: List[Dict[str, torch.Tensor]],
        physics_state: torch.Tensor,
        use_frames: int,
    ) -> torch.Tensor:
        """Convert most recent history frames into conditioning embedding.

        The last ``use_frames`` buffer entries (capped at
        ``max_history_frames``) are written oldest-first into consecutive
        slots of a zero tensor; unused slots stay zero. The flattened physics
        state is appended and the result is passed through ``history_proj``.
        """
        batch_size = physics_state.shape[0]
        device = physics_state.device

        frame_dim = self.max_objects * self.object_param_dim + 8
        history_tensor = torch.zeros(
            batch_size,
            self.max_history_frames * frame_dim,
            device=device,
        )

        use_frames = min(use_frames, self.max_history_frames)
        recent_frames = history_buffer[-use_frames:] if use_frames > 0 else []
        for slot, frame in enumerate(recent_frames):
            offset = slot * frame_dim
            obj_flat = frame['objects'].reshape(batch_size, -1)
            world_feat = frame['world']
            history_tensor[:, offset:offset + obj_flat.shape[1]] = obj_flat
            history_tensor[:, offset + obj_flat.shape[1]:offset + frame_dim] = world_feat

        physics_flat = physics_state.reshape(batch_size, -1)
        history_features = torch.cat([history_tensor, physics_flat], dim=-1)
        return self.history_proj(history_features)

    def forward(
        self,
        input_text: List[str],
        target_frames: torch.Tensor,
        history_frames: Optional[Dict[str, torch.Tensor]] = None,
        relative_times: Optional[torch.Tensor] = None,
        static_object_params: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
    ):
        """Forward pass for text to wave parameter generation.

        Args:
            input_text: List of text descriptions (its length is the batch
                size).
            target_frames: Target frame indices to predict. Only its second
                dimension (number of frames to generate) and its device are
                read here.
            history_frames: Optional history frames for conditioning; see
                ``_initialize_history_state``.
            relative_times: Relative time positions [-1, 1] for each target
                frame, indexed as ``[batch, frame]``. A zero time embedding is
                used when omitted.
            static_object_params: Optional per-object values; its first six
                features overwrite the first six predicted object features of
                every generated frame.
            noise: Optional tensor added to the decoder input embedding of
                every frame.

        Returns:
            A list with one dict per target frame holding ``'objects'``,
            ``'world'`` and ``'physics'`` predictions.
        """
        batch_size = len(input_text)
        num_target_frames = target_frames.shape[1]
        device = target_frames.device

        formatted_text = [f"translate to wave: {text}" for text in input_text]
        text_inputs = self.tokenizer(
            formatted_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(device)

        try:
            # Only the encoder states are needed. Calling the encoder directly
            # yields the same tensor as ``encoder_last_hidden_state`` of the
            # full model without running a throw-away decoder step and LM head.
            encoder_outputs = self.t5_model.encoder(
                input_ids=text_inputs.input_ids,
                attention_mask=text_inputs.attention_mask,
                return_dict=True,
            ).last_hidden_state
        except Exception as e:
            # Use the module-level logger when the surrounding script has one.
            if 'log_message' in globals():
                log_message(f"ERROR in encoder: {e}")
            else:
                print(f"ERROR in encoder: {e}")
            raise

        history_buffer, physics_state = self._initialize_history_state(
            history_frames,
            batch_size,
            device,
        )

        # The static slice does not depend on the frame, so build it once.
        static_slice = None
        if static_object_params is not None:
            static_object_params = static_object_params.to(
                device=device,
                dtype=torch.float32,
            )
            static_slice = static_object_params[:, :, :6]
            if static_slice.shape[-1] < 6:
                pad_width = 6 - static_slice.shape[-1]
                pad = static_slice.new_zeros(*static_slice.shape[:-1], pad_width)
                static_slice = torch.cat([static_slice, pad], dim=-1)

        if noise is not None:
            noise = noise.to(device=encoder_outputs.device, dtype=encoder_outputs.dtype)

        outputs = []

        for f in range(num_target_frames):
            # NOTE(review): random history sampling is not gated on
            # ``self.training``, so it is also active in eval mode - confirm
            # that this is intended.
            max_available = min(len(history_buffer), self.max_history_frames)
            if self.random_history_sampling and max_available > 0:
                use_history = int(torch.randint(
                    low=0,
                    high=max_available + 1,
                    size=(1,),
                    device=encoder_outputs.device,
                ).item())
            elif self.random_history_sampling:
                use_history = 0
            else:
                use_history = max_available

            if relative_times is not None:
                time_input = relative_times[:, f:f+1].unsqueeze(-1)
                time_embed = self.time_embed(time_input).squeeze(1)
            else:
                time_embed = torch.zeros(
                    batch_size,
                    self.hidden_size,
                    device=encoder_outputs.device,
                )

            history_embed = self._build_history_embedding(history_buffer, physics_state, use_history)
            decoder_embed = time_embed + history_embed
            if noise is not None:
                decoder_embed = decoder_embed + noise

            # One-token decoder pass that cross-attends to the encoded text.
            decoder_output = self.t5_model.decoder(
                inputs_embeds=decoder_embed.unsqueeze(1),
                encoder_hidden_states=encoder_outputs,
                encoder_attention_mask=text_inputs.attention_mask,
            )
            hidden = decoder_output.last_hidden_state[:, 0]

            object_params = self.object_proj(hidden).view(batch_size, self.max_objects, self.object_param_dim)
            if static_slice is not None:
                # Clone before writing so the projection output is not
                # modified in place.
                object_params = object_params.clone()
                object_params[:, :, :6] = static_slice
            world_params = self.world_proj(hidden)
            physics_params = self.physics_proj(hidden).view(batch_size, self.max_objects, 3)

            outputs.append({
                'objects': object_params,
                'world': world_params,
                'physics': physics_params,
            })

            # Feed the new frame back as history for the next step.
            history_buffer.append({
                'objects': object_params,
                'world': world_params,
            })
            history_buffer = self._trim_history(history_buffer)

            physics_state = physics_params

        return outputs
|
|
|
|
|
|
|
|
class BidirectionalTrainer: |
|
|
"""Trainer for bidirectional prediction from middle frame""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model: Text2WaveModel, |
|
|
config: Dict, |
|
|
accelerator: Accelerator, |
|
|
): |
|
|
self.model = model |
|
|
self.config = config |
|
|
self.accelerator = accelerator |
|
|
base_model = accelerator.unwrap_model(model) if hasattr(accelerator, "unwrap_model") else model |
|
|
self.object_param_dim = getattr(base_model, "object_param_dim", 12) |
|
|
self.freeze_static_params = bool(config['training'].get('freeze_static_from_anchor', True)) |
|
|
self.base_model = base_model |
|
|
self.sample_attempts = int(config['training'].get('multi_sample_attempts', 1)) |
|
|
self.sample_attempts = max(1, self.sample_attempts) |
|
|
|
|
|
|
|
|
self.world_loss_fn = nn.MSELoss() |
|
|
self.physics_loss_fn = nn.MSELoss() |
|
|
|
|
|
|
|
|
loss_weights_config = config.get('loss', {}).get('weights', {}) |
|
|
self.loss_weights = { |
|
|
'wave_loss(superquadric)': loss_weights_config.get('wave_loss', 1.0), |
|
|
'wave_contrastive_loss': loss_weights_config.get('wave_contrastive_loss', 2.0), |
|
|
'world_info_loss(camera,scale,time)': loss_weights_config.get('world_info_loss', 0.5), |
|
|
'controllable_info_loss(mass,friction,restitution)': loss_weights_config.get('controllable_info_loss', 0.1), |
|
|
'pla_loss': loss_weights_config.get('pla_loss', 3.0), |
|
|
} |
|
|
|
|
|
physics_cfg = config.get('physics', {}) |
|
|
self.gravity = float(physics_cfg.get('gravity', 9.81)) |
|
|
self.collision_buffer = float(physics_cfg.get('collision_buffer', 1.05)) |
|
|
|
|
|
|
|
|
self.frame_rate = float(config['training'].get('frame_rate', 8.0)) |
|
|
self.frame_rate = max(self.frame_rate, 1e-6) |
|
|
|
|
|
presence_cfg = config.get('loss', {}).get('wave_presence', {}) |
|
|
self.wave_count_weight = float(presence_cfg.get('count_weight', 0.2)) |
|
|
self.wave_presence_threshold = float(presence_cfg.get('scale_threshold', 0.1)) |
|
|
self.wave_presence_temperature = float(presence_cfg.get('temperature', 0.1)) |
|
|
contrastive_cfg = config.get('loss', {}).get('wave_contrastive', {}) |
|
|
self.wave_contrastive_temperature = float(contrastive_cfg.get('temperature', 0.2)) |
|
|
|
|
|
|
|
|
self.velocity_slice = slice(max(self.object_param_dim - 3, 0), self.object_param_dim) |
|
|
|
|
|
def compute_loss( |
|
|
self, |
|
|
predictions: List[Dict], |
|
|
targets: Dict[str, torch.Tensor], |
|
|
frame_indices: List[int], |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
"""Compute losses for predicted frames""" |
|
|
losses = { |
|
|
'wave_loss(superquadric)': 0.0, |
|
|
'wave_contrastive_loss': 0.0, |
|
|
'world_info_loss(camera,scale,time)': 0.0, |
|
|
'controllable_info_loss(mass,friction,restitution)': 0.0, |
|
|
'pla_loss': 0.0, |
|
|
'wave_count_mse': 0.0, |
|
|
'total': 0.0, |
|
|
} |
|
|
|
|
|
pla_entries = [] |
|
|
pred_summaries: List[torch.Tensor] = [] |
|
|
target_summaries: List[torch.Tensor] = [] |
|
|
|
|
|
for i, (pred, frame_idx) in enumerate(zip(predictions, frame_indices)): |
|
|
|
|
|
target_objects = targets['objects'][:, frame_idx] |
|
|
if target_objects.shape[-1] < self.object_param_dim: |
|
|
pad_width = self.object_param_dim - target_objects.shape[-1] |
|
|
pad = target_objects.new_zeros(*target_objects.shape[:-1], pad_width) |
|
|
target_objects = torch.cat([target_objects, pad], dim=-1) |
|
|
pred_objects = pred['objects'] |
|
|
|
|
|
|
|
|
exists_mask = target_objects[:, :, 0] > 0.5 |
|
|
|
|
|
target_core = target_objects[:, :, :self.object_param_dim] |
|
|
|
|
|
|
|
|
object_loss = self._wave_reconstruction_loss(pred_objects, target_core, exists_mask) |
|
|
losses['wave_loss(superquadric)'] += object_loss |
|
|
|
|
|
|
|
|
target_presence = target_objects[:, :, 0].float() |
|
|
pred_scale_norm = torch.linalg.norm(pred_objects[:, :, 3:6], dim=-1) |
|
|
presence_input = (pred_scale_norm - self.wave_presence_threshold) / max(self.wave_presence_temperature, 1e-6) |
|
|
pred_presence = torch.sigmoid(presence_input) |
|
|
pred_count = pred_presence.sum(dim=-1) |
|
|
target_count = target_presence.sum(dim=-1) |
|
|
count_mse = F.mse_loss(pred_count, target_count) |
|
|
losses['wave_count_mse'] += count_mse |
|
|
losses['wave_loss(superquadric)'] += self.wave_count_weight * count_mse |
|
|
|
|
|
pla_entries.append({ |
|
|
'frame_idx': frame_idx, |
|
|
'pred_objects': pred_objects, |
|
|
'exists_mask': exists_mask, |
|
|
}) |
|
|
|
|
|
|
|
|
mask = exists_mask.float().unsqueeze(-1) |
|
|
|
|
|
denom = mask.sum(dim=1).clamp_min(1.0) |
|
|
pred_summary = (pred_objects * mask).sum(dim=1) / denom |
|
|
target_summary = (target_core * mask).sum(dim=1) / denom |
|
|
pred_summaries.append(pred_summary) |
|
|
target_summaries.append(target_summary) |
|
|
|
|
|
|
|
|
target_world = targets['world'][:, frame_idx] |
|
|
pred_world = pred['world'] |
|
|
|
|
|
|
|
|
world_loss = self.world_loss_fn( |
|
|
pred_world, |
|
|
target_world[:, :8] |
|
|
) |
|
|
losses['world_info_loss(camera,scale,time)'] += world_loss |
|
|
|
|
|
|
|
|
if i == 0: |
|
|
target_physics = targets['physics'] |
|
|
pred_physics = pred['physics'] |
|
|
|
|
|
physics_loss = self.physics_loss_fn( |
|
|
pred_physics[exists_mask], |
|
|
target_physics[exists_mask] |
|
|
) |
|
|
losses['controllable_info_loss(mass,friction,restitution)'] = physics_loss |
|
|
|
|
|
|
|
|
num_frames = len(predictions) |
|
|
losses['wave_loss(superquadric)'] /= num_frames |
|
|
losses['world_info_loss(camera,scale,time)'] /= num_frames |
|
|
losses['wave_count_mse'] /= num_frames |
|
|
|
|
|
|
|
|
total_frames = targets['objects'].shape[1] |
|
|
middle_idx = total_frames // 2 |
|
|
anchor_objects = targets['objects'][:, middle_idx] |
|
|
anchor_exists = anchor_objects[:, :, 0] > 0.5 |
|
|
pla_entries.append({ |
|
|
'frame_idx': middle_idx, |
|
|
'pred_objects': anchor_objects[:, :, :self.object_param_dim].detach(), |
|
|
'exists_mask': anchor_exists, |
|
|
}) |
|
|
|
|
|
|
|
|
pla_loss = self._compute_pla_regularizer(pla_entries) |
|
|
losses['pla_loss'] = pla_loss |
|
|
|
|
|
|
|
|
if pred_summaries: |
|
|
pred_stack = torch.stack(pred_summaries, dim=0).mean(dim=0) |
|
|
target_stack = torch.stack(target_summaries, dim=0).mean(dim=0) |
|
|
losses['wave_contrastive_loss'] = self._contrastive_clip_loss(pred_stack, target_stack) |
|
|
else: |
|
|
device = targets['objects'].device |
|
|
losses['wave_contrastive_loss'] = torch.zeros((), device=device) |
|
|
|
|
|
|
|
|
for key, weight in self.loss_weights.items(): |
|
|
if key in losses: |
|
|
losses['total'] += weight * losses[key] |
|
|
|
|
|
return losses |
|
|
|
|
|
def _wave_reconstruction_loss( |
|
|
self, |
|
|
pred_objects: torch.Tensor, |
|
|
target_objects: torch.Tensor, |
|
|
exists_mask: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
"""Velocity-aware reconstruction loss combining position L1 and velocity L1.""" |
|
|
device = pred_objects.device |
|
|
dtype = pred_objects.dtype |
|
|
if not exists_mask.any(): |
|
|
return torch.zeros((), device=device, dtype=dtype) |
|
|
|
|
|
pred_active = pred_objects[exists_mask] |
|
|
target_active = target_objects[exists_mask] |
|
|
|
|
|
base_l1 = F.l1_loss(pred_active, target_active, reduction='mean') |
|
|
|
|
|
if self.velocity_slice.start >= self.velocity_slice.stop: |
|
|
velocity_l1 = torch.zeros((), device=device, dtype=dtype) |
|
|
else: |
|
|
pred_velocity = pred_active[..., self.velocity_slice] |
|
|
target_velocity = target_active[..., self.velocity_slice] |
|
|
velocity_l1 = F.l1_loss(pred_velocity, target_velocity, reduction='mean') |
|
|
|
|
|
return 0.5 * base_l1 + 0.5 * velocity_l1 |
|
|
|
|
|
def _contrastive_clip_loss( |
|
|
self, |
|
|
pred_summary: torch.Tensor, |
|
|
target_summary: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
"""InfoNCE-style contrastive loss between predicted and target clip summaries.""" |
|
|
device = pred_summary.device |
|
|
dtype = pred_summary.dtype |
|
|
batch = pred_summary.size(0) |
|
|
if batch <= 1: |
|
|
return torch.zeros((), device=device, dtype=dtype) |
|
|
|
|
|
dim = min(pred_summary.size(-1), target_summary.size(-1)) |
|
|
if dim == 0: |
|
|
return torch.zeros((), device=device, dtype=dtype) |
|
|
if pred_summary.size(-1) != dim: |
|
|
pred_summary = pred_summary[..., :dim] |
|
|
if target_summary.size(-1) != dim: |
|
|
target_summary = target_summary[..., :dim] |
|
|
|
|
|
temperature = max(self.wave_contrastive_temperature, 1e-6) |
|
|
pred_norm = F.normalize(pred_summary, dim=-1) |
|
|
target_norm = F.normalize(target_summary, dim=-1) |
|
|
dim_post = min(pred_norm.size(-1), target_norm.size(-1)) |
|
|
if dim_post == 0: |
|
|
return torch.zeros((), device=device, dtype=dtype) |
|
|
if pred_norm.size(-1) != dim_post: |
|
|
pred_norm = pred_norm[..., :dim_post] |
|
|
if target_norm.size(-1) != dim_post: |
|
|
target_norm = target_norm[..., :dim_post] |
|
|
logits = pred_norm @ target_norm.transpose(0, 1) |
|
|
logits = logits / temperature |
|
|
|
|
|
labels = torch.arange(batch, device=device) |
|
|
loss_forward = F.cross_entropy(logits, labels) |
|
|
loss_backward = F.cross_entropy(logits.transpose(0, 1), labels) |
|
|
|
|
|
return 0.5 * (loss_forward + loss_backward) |
|
|
|
|
|
def _compute_pla_regularizer(self, entries: List[Dict[str, torch.Tensor]]) -> torch.Tensor: |
|
|
"""Encourage rigid-body consistency, free-fall dynamics, and collision plausibility.""" |
|
|
model_device = next(self.model.parameters()).device |
|
|
if not entries: |
|
|
return torch.tensor(0.0, device=model_device) |
|
|
|
|
|
|
|
|
sorted_entries = sorted(entries, key=lambda x: x['frame_idx']) |
|
|
|
|
|
device = sorted_entries[0]['pred_objects'].device |
|
|
dtype = sorted_entries[0]['pred_objects'].dtype |
|
|
|
|
|
preds = torch.stack([item['pred_objects'] for item in sorted_entries], dim=0) |
|
|
exists = torch.stack([item['exists_mask'].float() for item in sorted_entries], dim=0) |
|
|
|
|
|
frame_count, batch_size, max_objects, _ = preds.shape |
|
|
|
|
|
if frame_count <= 1: |
|
|
return torch.tensor(0.0, device=device, dtype=dtype) |
|
|
|
|
|
exists_expanded = exists.unsqueeze(-1) |
|
|
exists_total = exists_expanded.sum() |
|
|
if exists_total.item() == 0: |
|
|
return torch.tensor(0.0, device=device, dtype=dtype) |
|
|
|
|
|
|
|
|
shape_params = preds[..., 1:3] |
|
|
scale_params = preds[..., 3:6] |
|
|
|
|
|
shape_mean = (shape_params * exists_expanded).sum(dim=0) / exists_expanded.sum(dim=0).clamp_min(1.0) |
|
|
scale_mean = (scale_params * exists_expanded).sum(dim=0) / exists_expanded.sum(dim=0).clamp_min(1.0) |
|
|
|
|
|
shape_loss = ((shape_params - shape_mean) ** 2 * exists_expanded).sum() / exists_expanded.sum().clamp_min(1.0) |
|
|
scale_loss = ((scale_params - scale_mean) ** 2 * exists_expanded).sum() / exists_expanded.sum().clamp_min(1.0) |
|
|
|
|
|
|
|
|
freefall_loss = torch.tensor(0.0, device=device, dtype=dtype) |
|
|
rotation_loss = torch.tensor(0.0, device=device, dtype=dtype) |
|
|
collision_penalty = torch.tensor(0.0, device=device, dtype=dtype) |
|
|
velocity_loss = torch.tensor(0.0, device=device, dtype=dtype) |
|
|
|
|
|
positions = preds[..., 6:9] |
|
|
|
|
|
if frame_count >= 3: |
|
|
radii = torch.linalg.norm(preds[..., 3:6], dim=-1) |
|
|
|
|
|
accel = positions[2:] - 2 * positions[1:-1] + positions[:-2] |
|
|
|
|
|
exists_triplet = exists[1:-1] * exists[:-2] * exists[2:] |
|
|
exists_triplet_expanded = exists_triplet.unsqueeze(-1) |
|
|
|
|
|
|
|
|
center_positions = positions[1:-1].reshape(-1, max_objects, 3) |
|
|
center_exists = exists[1:-1].reshape(-1, max_objects) |
|
|
center_radii = radii[1:-1].reshape(-1, max_objects) |
|
|
|
|
|
if center_positions.numel() > 0: |
|
|
dist = torch.cdist(center_positions, center_positions, p=2) |
|
|
radius_sum = (center_radii.unsqueeze(-1) + center_radii.unsqueeze(-2)) * self.collision_buffer |
|
|
exists_pair = center_exists.unsqueeze(-1) * center_exists.unsqueeze(-2) |
|
|
|
|
|
eye = torch.eye(max_objects, device=device).unsqueeze(0) |
|
|
non_diag = (1 - eye) |
|
|
|
|
|
penetration = torch.relu((radius_sum - dist) * non_diag) * exists_pair |
|
|
collision_penalty = penetration.pow(2).sum() / (non_diag * exists_pair).sum().clamp_min(1.0) |
|
|
|
|
|
contact_any = (penetration > 0).any(dim=-1).view(frame_count - 2, batch_size, max_objects) |
|
|
else: |
|
|
contact_any = torch.zeros(frame_count - 2, batch_size, max_objects, device=device, dtype=torch.bool) |
|
|
|
|
|
contact_mask = contact_any.float() |
|
|
|
|
|
gravity_vec = torch.tensor([0.0, 0.0, -self.gravity], device=device, dtype=dtype).view(1, 1, 1, 3) |
|
|
residual = accel + gravity_vec |
|
|
|
|
|
freefall_mask = exists_triplet_expanded * (1.0 - contact_mask.unsqueeze(-1)) |
|
|
valid_count = freefall_mask.sum().clamp_min(1.0) |
|
|
freefall_loss = (residual.pow(2) * freefall_mask).sum() / valid_count |
|
|
|
|
|
rotations = preds[..., 9:12] |
|
|
rot_sin = torch.sin(rotations) |
|
|
rot_cos = torch.cos(rotations) |
|
|
rot_features = torch.cat([rot_sin, rot_cos], dim=-1) |
|
|
rot_acc = rot_features[2:] - 2 * rot_features[1:-1] + rot_features[:-2] |
|
|
|
|
|
rot_mask = exists_triplet_expanded * (1.0 - contact_mask.unsqueeze(-1)) |
|
|
rot_valid = rot_mask.sum().clamp_min(1.0) |
|
|
rotation_loss = (rot_acc.pow(2) * rot_mask).sum() / rot_valid |
|
|
|
|
|
if frame_count >= 2: |
|
|
velocities = preds[..., 12:15] |
|
|
diff = (positions[1:] - positions[:-1]) * self.frame_rate |
|
|
exists_pair = exists[1:] * exists[:-1] |
|
|
diff_expanded = exists_pair.unsqueeze(-1) |
|
|
|
|
|
velocity_residual = (velocities[1:] - diff).pow(2) * diff_expanded |
|
|
valid_velocity = diff_expanded.sum() |
|
|
velocity_loss = velocity_residual.sum() |
|
|
|
|
|
first_pair = (exists[0] * exists[1]).unsqueeze(-1) |
|
|
velocity_loss += ((velocities[0] - diff[0]) ** 2 * first_pair).sum() |
|
|
valid_velocity += first_pair.sum() |
|
|
|
|
|
velocity_loss = velocity_loss / valid_velocity.clamp_min(1.0) |
|
|
|
|
|
pla_loss = ( |
|
|
shape_loss |
|
|
+ scale_loss |
|
|
+ freefall_loss |
|
|
+ rotation_loss |
|
|
+ collision_penalty |
|
|
+ velocity_loss |
|
|
) |
|
|
return pla_loss |
|
|
|
|
|
def _select_anchor_frame(self, num_frames: int) -> int: |
|
|
"""Determine which frame should serve as the initial anchor.""" |
|
|
cfg = self.config['training'].get('initial_frame', {}) |
|
|
strategy = cfg.get('strategy', 'middle') |
|
|
|
|
|
if strategy == 'random': |
|
|
base_idx = int(torch.randint(low=0, high=num_frames, size=(1,), device=torch.device('cpu')).item()) |
|
|
elif strategy == 'fixed': |
|
|
base_idx = int(cfg.get('index', num_frames // 2)) |
|
|
else: |
|
|
base_idx = num_frames // 2 |
|
|
|
|
|
offset = int(cfg.get('offset', 0)) |
|
|
anchor_idx = base_idx + offset |
|
|
anchor_idx = max(0, min(num_frames - 1, anchor_idx)) |
|
|
return anchor_idx |
|
|
|
|
|
def _generate_full_sequence( |
|
|
self, |
|
|
text: List[str], |
|
|
objects: torch.Tensor, |
|
|
world: torch.Tensor, |
|
|
physics: torch.Tensor, |
|
|
teacher_prob: float, |
|
|
anchor_idx: Optional[int] = None, |
|
|
use_noise: bool = False, |
|
|
) -> Tuple[List[Dict[str, torch.Tensor]], List[int], float]: |
|
|
"""Generate a full sequence of predictions given an anchor frame.""" |
|
|
batch_size, num_frames = objects.shape[:2] |
|
|
if anchor_idx is None: |
|
|
anchor_idx = self._select_anchor_frame(num_frames) |
|
|
|
|
|
static_object_params = None |
|
|
if self.freeze_static_params: |
|
|
anchor_static = objects[:, anchor_idx, :, :6] |
|
|
static_object_params = anchor_static |
|
|
|
|
|
if teacher_prob > 0.0: |
|
|
teacher_mask = (torch.rand(batch_size, device=objects.device) < teacher_prob).float() |
|
|
else: |
|
|
teacher_mask = torch.zeros(batch_size, device=objects.device, dtype=torch.float32) |
|
|
|
|
|
def sample_noise(): |
|
|
return self.base_model.sample_decoder_noise(batch_size, objects.device) if use_noise else None |
|
|
|
|
|
half_span = max(num_frames - 1, 1) / 2.0 |
|
|
inference_time = 0.0 |
|
|
predictions_by_idx: Dict[int, Dict[str, torch.Tensor]] = {} |
|
|
|
|
|
anchor_rel_times = torch.zeros( |
|
|
(batch_size, 1), dtype=torch.float32, device=objects.device |
|
|
) |
|
|
anchor_targets = torch.full( |
|
|
(batch_size, 1), anchor_idx, dtype=torch.long, device=objects.device |
|
|
) |
|
|
|
|
|
start = time.time() |
|
|
anchor_preds = self.model( |
|
|
input_text=text, |
|
|
target_frames=anchor_targets, |
|
|
history_frames=None, |
|
|
relative_times=anchor_rel_times, |
|
|
static_object_params=static_object_params, |
|
|
noise=sample_noise(), |
|
|
) |
|
|
inference_time += time.time() - start |
|
|
anchor_pred = anchor_preds[0] |
|
|
predictions_by_idx[anchor_idx] = anchor_pred |
|
|
|
|
|
anchor_gt_objects = objects[:, anchor_idx, :, :self.object_param_dim] |
|
|
if anchor_gt_objects.shape[-1] < self.object_param_dim: |
|
|
pad_width = self.object_param_dim - anchor_gt_objects.shape[-1] |
|
|
pad = anchor_gt_objects.new_zeros(*anchor_gt_objects.shape[:-1], pad_width) |
|
|
anchor_gt_objects = torch.cat([anchor_gt_objects, pad], dim=-1) |
|
|
anchor_gt_world = world[:, anchor_idx, :8] |
|
|
anchor_pred_objects = anchor_pred['objects'] |
|
|
if static_object_params is not None: |
|
|
anchor_pred_objects[:, :, :6] = static_object_params[:, :, :6] |
|
|
anchor_pred_world = anchor_pred['world'] |
|
|
|
|
|
teacher_mask_objs = teacher_mask.view(batch_size, 1, 1) |
|
|
teacher_mask_world = teacher_mask.view(batch_size, 1) |
|
|
|
|
|
blended_objects = anchor_pred_objects * (1.0 - teacher_mask_objs) + anchor_gt_objects * teacher_mask_objs |
|
|
blended_world = anchor_pred_world * (1.0 - teacher_mask_world) + anchor_gt_world * teacher_mask_world |
|
|
|
|
|
history_objects = blended_objects.unsqueeze(1) |
|
|
history_world = blended_world.unsqueeze(1) |
|
|
history_physics = physics.clone() |
|
|
|
|
|
def make_history_seed(): |
|
|
return { |
|
|
'objects': history_objects.clone(), |
|
|
'world': history_world.clone(), |
|
|
'physics': history_physics.clone(), |
|
|
} |
|
|
|
|
|
backward_indices = list(range(anchor_idx - 1, -1, -1)) |
|
|
forward_indices = list(range(anchor_idx + 1, num_frames)) |
|
|
|
|
|
def run_direction(target_indices: List[int]): |
|
|
nonlocal inference_time |
|
|
if not target_indices: |
|
|
return |
|
|
|
|
|
rel_times = torch.tensor( |
|
|
[(idx - anchor_idx) / half_span for idx in target_indices], |
|
|
dtype=torch.float32, |
|
|
device=objects.device, |
|
|
).unsqueeze(0).repeat(batch_size, 1) |
|
|
|
|
|
target_tensor = torch.tensor( |
|
|
target_indices, |
|
|
dtype=torch.long, |
|
|
device=objects.device, |
|
|
).unsqueeze(0).repeat(batch_size, 1) |
|
|
|
|
|
history_frames = make_history_seed() |
|
|
|
|
|
start_time = time.time() |
|
|
preds = self.model( |
|
|
input_text=text, |
|
|
target_frames=target_tensor, |
|
|
history_frames=history_frames, |
|
|
relative_times=rel_times, |
|
|
static_object_params=static_object_params, |
|
|
noise=sample_noise(), |
|
|
) |
|
|
inference_time += time.time() - start_time |
|
|
|
|
|
for idx, pred in zip(target_indices, preds): |
|
|
if static_object_params is not None: |
|
|
pred['objects'][:, :, :6] = static_object_params[:, :, :6] |
|
|
predictions_by_idx[idx] = pred |
|
|
|
|
|
run_direction(backward_indices) |
|
|
run_direction(forward_indices) |
|
|
|
|
|
ordered_indices = list(range(num_frames)) |
|
|
predictions = [predictions_by_idx[idx] for idx in ordered_indices] |
|
|
return predictions, ordered_indices, inference_time |
|
|
|
|
|
def _compute_losses( |
|
|
self, |
|
|
batch: Dict[str, torch.Tensor], |
|
|
) -> Tuple[Dict[str, torch.Tensor], float, int]: |
|
|
"""Shared logic for computing losses and metadata.""" |
|
|
text = batch['text'] |
|
|
objects = batch['objects'] |
|
|
world = batch['world'] |
|
|
physics = batch['physics'] |
|
|
|
|
|
batch_size, num_frames = objects.shape[:2] |
|
|
anchor_idx = self._select_anchor_frame(num_frames) |
|
|
teacher_prob = float(self.config['training'].get('initial_teacher_forcing_prob', 0.5)) |
|
|
|
|
|
targets = { |
|
|
'objects': objects, |
|
|
'world': world, |
|
|
'physics': physics, |
|
|
} |
|
|
|
|
|
attempts = self.sample_attempts if self.model.training else 1 |
|
|
use_noise = attempts > 1 |
|
|
best_losses: Optional[Dict[str, torch.Tensor]] = None |
|
|
best_predictions: Optional[List[Dict[str, torch.Tensor]]] = None |
|
|
best_frame_indices: Optional[List[int]] = None |
|
|
best_inference_time: float = 0.0 |
|
|
best_total_value: Optional[float] = None |
|
|
|
|
|
for attempt in range(attempts): |
|
|
predictions, frame_indices, inference_time = self._generate_full_sequence( |
|
|
text=text, |
|
|
objects=objects, |
|
|
world=world, |
|
|
physics=physics, |
|
|
teacher_prob=teacher_prob, |
|
|
anchor_idx=anchor_idx, |
|
|
use_noise=use_noise, |
|
|
) |
|
|
|
|
|
losses = self.compute_loss(predictions, targets, frame_indices) |
|
|
total_value = float(losses['total'].detach()) |
|
|
if best_total_value is None or total_value < best_total_value: |
|
|
if best_losses is not None: |
|
|
del best_losses |
|
|
if best_predictions is not None: |
|
|
del best_predictions |
|
|
best_total_value = total_value |
|
|
best_losses = losses |
|
|
best_predictions = predictions |
|
|
best_frame_indices = frame_indices |
|
|
best_inference_time = inference_time |
|
|
else: |
|
|
del losses |
|
|
del predictions |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
assert best_losses is not None and best_predictions is not None and best_frame_indices is not None |
|
|
num_predicted_frames = len(best_predictions) |
|
|
frames_per_second = num_predicted_frames / best_inference_time if best_inference_time > 0 else 0.0 |
|
|
|
|
|
return best_losses, frames_per_second, num_predicted_frames |
|
|
|
|
|
def train_step(
    self,
    batch: Dict[str, torch.Tensor],
    step: int,
) -> Dict[str, float]:
    """Run the forward and backward pass for one training batch.

    The model is switched to training mode, losses are obtained from
    ``self._compute_losses`` and the ``'total'`` loss is back-propagated
    through ``self.accelerator``. This method does not step the optimizer.

    Args:
        batch: Batch passed unchanged to ``self._compute_losses``.
        step: Current global step; accepted for the caller's convenience
            but not read by this method.

    Returns:
        Every loss term converted to a Python float, plus
        ``'inference_fps'`` and ``'frames_predicted'``.
    """
    self.model.train()

    loss_terms, fps, frame_count = self._compute_losses(batch)

    # Back-propagate before converting anything to plain floats.
    self.accelerator.backward(loss_terms['total'])

    metrics = {}
    for name, value in loss_terms.items():
        if torch.is_tensor(value):
            metrics[name] = value.item()
        else:
            metrics[name] = float(value)
    metrics['inference_fps'] = fps
    metrics['frames_predicted'] = frame_count

    return metrics
|
|
|
|
|
def evaluate_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Compute losses for one batch in eval mode with gradients disabled.

    The model is put into eval mode for the duration of the call and is
    returned to training mode afterwards if it was training on entry,
    including when ``self._compute_losses`` raises.

    Args:
        batch: Batch passed unchanged to ``self._compute_losses``.

    Returns:
        Every loss term converted to a Python float, plus
        ``'inference_fps'`` and ``'frames_predicted'``.
    """
    was_training = self.model.training
    self.model.eval()
    try:
        with torch.no_grad():
            losses, frames_per_second, num_predicted_frames = self._compute_losses(batch)
    finally:
        # Restore the caller's mode even on failure so that an exception
        # during evaluation cannot leave the model stuck in eval mode.
        if was_training:
            self.model.train()

    loss_dict = {k: v.item() if torch.is_tensor(v) else float(v) for k, v in losses.items()}
    loss_dict['inference_fps'] = frames_per_second
    loss_dict['frames_predicted'] = num_predicted_frames
    return loss_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Train a Text2WaveModel from a YAML configuration file.

    Command-line flags:
        --train_config: YAML file with ``training``, ``data`` and optional
            ``text2wave_model`` sections.
        --data_root: Dataset root handed to ``create_dataloader``.
        --output_dir: Directory passed to ``save_generation_results``.
        --resume_step: Resume from ``checkpoints_text2wave/step<N>.pt``.
            Without it, ``latest.pt`` is tried first and ``best.pt`` second.

    Side effects: writes checkpoints, ``training_log.txt``,
    ``training_stats.npz`` and ``losses.png`` under ``checkpoints_text2wave``
    and generation samples under ``--output_dir``. File writes are performed
    by the main process only.

    Raises:
        RuntimeError: If the training dataloader yields no batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_config', type=str, default='configs/default.yaml',
                        help='Training configuration file')
    parser.add_argument('--data_root', type=str,
                        default='../data/movi_a_128x128',
                        help='Root directory of MOVi dataset')
    parser.add_argument('--output_dir', type=str, default='core_space',
                        help='Directory to save checkpoints and generation results')
    parser.add_argument('--resume_step', type=int, default=None,
                        help='Resume training from specific step')
    args = parser.parse_args()

    with open(args.train_config, 'r') as f:
        config = yaml.safe_load(f)

    from accelerate import DistributedDataParallelKwargs

    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters=True,
        broadcast_buffers=False
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        kwargs_handlers=[ddp_kwargs]
    )

    set_seed(42)

    model_name = config.get('text2wave_model', {}).get('model_name', "google/t5-v1_1-small")
    model = Text2WaveModel(
        model_name=model_name,
        max_objects=10,
        num_frames=24,
        max_history_frames=config['training']['max_history_frames'],
        random_history_sampling=config['training'].get('random_history_sampling', True),
        decoder_noise_std=config['training'].get('decoder_noise_std', 0.0),
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=0.01,
    )

    train_dataloader = create_dataloader(
        data_root=args.data_root,
        split='train',
        batch_size=config['training']['batch_size'],
        num_workers=config['data']['num_workers'],
        shuffle=True,
        max_samples=config['data'].get('max_sequences', -1),
    )

    val_dataloader = create_dataloader(
        data_root=args.data_root,
        split='validation',
        batch_size=config['training']['batch_size'],
        num_workers=config['data']['num_workers'],
        shuffle=False,
        max_samples=10,
    )

    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )

    checkpoint_dir = Path("checkpoints_text2wave")
    if accelerator.is_main_process:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    log_file_path = checkpoint_dir / "training_log.txt"

    def log_message(message: str):
        """Log to stdout and append to training_log.txt from main process."""
        if not accelerator.is_main_process:
            return
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        formatted = f"{timestamp} {message}"
        print(formatted)
        try:
            with open(log_file_path, 'a') as fp:
                fp.write(formatted + "\n")
        except Exception:
            # Deliberately best-effort: the message was already printed and
            # a failing log file must not interrupt training.
            pass

    best_metrics_path = checkpoint_dir / "best_metrics.json"
    if best_metrics_path.exists():
        try:
            best_metrics_path.unlink()
        except OSError as exc:
            log_message(f"Warning: failed to remove legacy best_metrics.json due to {exc}")

    best_train_loss = float('inf')
    best_val_loss = float('inf')

    evaluation_cfg = config['training'].get('evaluation', {})
    eval_max_batches = evaluation_cfg.get('max_batches', 5)

    training_stats_path = checkpoint_dir / "training_stats.npz"
    loaded_step_history: Optional[List[int]] = None
    loaded_loss_history: Dict[str, List[float]] = {}
    if training_stats_path.exists():
        try:
            # NOTE: allow_pickle=True is needed for the object arrays written
            # by save_training_stats(); only load stats files this script
            # produced. The context manager closes the archive afterwards.
            with np.load(training_stats_path, allow_pickle=True) as stats:
                best_train_loss = float(stats.get('best_train_loss', best_train_loss))
                best_val_loss = float(stats.get('best_val_loss', best_val_loss))
                if 'step_history' in stats:
                    loaded_step_history = stats['step_history'].tolist()
                if 'loss_history_keys' in stats and 'loss_history_values' in stats:
                    keys = stats['loss_history_keys'].tolist()
                    values = stats['loss_history_values'].tolist()
                    for key, value in zip(keys, values):
                        loaded_loss_history[str(key)] = list(np.asarray(value, dtype=float))
        except Exception as exc:
            log_message(f"Warning: failed to load training_stats.npz due to {exc}")

    # A single worker keeps background saves strictly ordered.
    executor = ThreadPoolExecutor(max_workers=1)
    pending_futures: List = []

    def cleanup_futures():
        """Drop finished background tasks, reporting any that failed."""
        still_running = []
        for future in pending_futures:
            if not future.done():
                still_running.append(future)
                continue
            exc = future.exception()
            if exc is not None:
                # Without this a failed checkpoint or sample write would go
                # completely unnoticed.
                log_message(f"Warning: background save task failed: {exc!r}")
        pending_futures[:] = still_running

    def submit_task(fn, *args, **kwargs):
        """Queue ``fn`` on the background executor and track its future."""
        cleanup_futures()
        future = executor.submit(fn, *args, **kwargs)
        pending_futures.append(future)
        return future

    def recursive_to_cpu(obj):
        """Return ``obj`` with every tensor in nested dicts/lists/tuples detached on CPU."""
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu()
        if isinstance(obj, dict):
            return {k: recursive_to_cpu(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [recursive_to_cpu(v) for v in obj]
        if isinstance(obj, tuple):
            return tuple(recursive_to_cpu(v) for v in obj)
        return obj

    def save_checkpoint_async(path: Path, payload: Dict):
        """Write ``payload`` to ``path`` with torch.save on the background thread."""
        def _task():
            torch.save(payload, path)
        submit_task(_task)

    def save_generation_async(predictions: List[Dict], targets: Dict[str, torch.Tensor], texts: List[str], step: int, save_config: Dict, metadata: Dict, batch_data: Dict, data_root: str, data_split: str):
        """Call save_generation_results on the background thread."""
        def _task():
            save_generation_results(
                predictions=predictions,
                targets=targets,
                texts=texts,
                step=step,
                output_dir=args.output_dir,
                save_config=save_config,
                metadata=metadata,
                batch_data=batch_data,
                data_root=data_root,
                data_split=data_split
            )
        submit_task(_task)

    def compute_validation_loss(max_batches: Optional[int]) -> Optional[float]:
        """Average the 'total' loss over validation batches.

        ``max_batches`` of None or a negative value means no limit; 0
        disables validation. Returns None when nothing was evaluated.
        """
        limit = -1 if max_batches is None else max_batches
        if limit == 0:
            return None
        total = 0.0
        count = 0
        for batch_idx, val_batch in enumerate(val_dataloader):
            val_losses = trainer.evaluate_batch(val_batch)
            total += val_losses['total']
            count += 1
            if limit > 0 and (batch_idx + 1) >= limit:
                break
        if count == 0:
            return None
        return total / count

    trainer = BidirectionalTrainer(model, config, accelerator)

    max_steps = config['training']['max_steps']

    if accelerator.is_main_process:
        steps_per_epoch = len(train_dataloader)
        # Guard the division; an empty dataloader is rejected with a clear
        # error once the training loop starts.
        total_epochs = max_steps / steps_per_epoch if steps_per_epoch else float('inf')
        log_message("=" * 60)
        log_message("Dataset Information:")
        log_message(f"- Training samples: {len(train_dataloader.dataset) if hasattr(train_dataloader, 'dataset') else 'N/A'}")
        log_message(f"- Batch size: {config['training']['batch_size']}")
        log_message(f"- Steps per epoch (full dataset): {steps_per_epoch}")
        log_message(f"- Total training steps: {max_steps}")
        log_message(f"- Will traverse dataset: {total_epochs:.2f} times")
        log_message("=" * 60)

    def load_checkpoint(path: Path, default_step: int) -> int:
        """Load model and optimizer state from ``path``; return the stored step."""
        checkpoint = torch.load(path, map_location='cpu')
        accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint.get('step', default_step)

    start_step = 0
    resumed_from = None
    if args.resume_step is not None:
        # An explicitly requested checkpoint is loaded without a fallback, so
        # a broken file fails loudly.
        checkpoint_path = checkpoint_dir / f"step{args.resume_step}.pt"
        if checkpoint_path.exists():
            log_message(f"Resuming from checkpoint step {args.resume_step}")
            start_step = load_checkpoint(checkpoint_path, args.resume_step)
            resumed_from = checkpoint_path
        else:
            log_message(f"Warning: Checkpoint for step {args.resume_step} not found, starting from scratch")
    else:
        latest_checkpoint_path = checkpoint_dir / "latest.pt"
        if latest_checkpoint_path.exists():
            try:
                log_message("Resuming from latest checkpoint")
                start_step = load_checkpoint(latest_checkpoint_path, 0)
                resumed_from = latest_checkpoint_path
            except Exception as exc:
                log_message(f"Warning: failed to load latest checkpoint due to {exc}; attempting best checkpoint")
                try:
                    # Move the unreadable file aside so it is not retried on
                    # the next start.
                    corrupt_path = latest_checkpoint_path.with_suffix(latest_checkpoint_path.suffix + ".corrupt")
                    latest_checkpoint_path.rename(corrupt_path)
                    log_message(f"Renamed corrupt latest checkpoint to {corrupt_path.name}")
                except Exception as rename_exc:
                    log_message(f"Warning: could not rename corrupt latest checkpoint: {rename_exc}")
        if resumed_from is None:
            best_checkpoint_path = checkpoint_dir / "best.pt"
            if best_checkpoint_path.exists():
                try:
                    log_message("Resuming from best checkpoint")
                    start_step = load_checkpoint(best_checkpoint_path, 0)
                    resumed_from = best_checkpoint_path
                except Exception as exc:
                    log_message(f"Warning: failed to load best checkpoint due to {exc}; starting from scratch")

    loss_history = defaultdict(list)
    step_history: List[int] = []
    if loaded_step_history:
        step_history.extend(int(s) for s in loaded_step_history)
    if loaded_loss_history:
        for key, values in loaded_loss_history.items():
            loss_history[key].extend(values)
    last_plot_time = time.time()
    plot_path = checkpoint_dir / "losses.png"

    def save_training_stats():
        """Persist best losses and the loss history to training_stats.npz (main process only)."""
        if not accelerator.is_main_process:
            return
        keys = sorted(loss_history.keys())
        loss_arrays = [np.array(loss_history[k], dtype=np.float32) for k in keys]
        np.savez(
            training_stats_path,
            best_train_loss=best_train_loss,
            best_val_loss=best_val_loss,
            step_history=np.array(step_history, dtype=np.int64),
            loss_history_keys=np.array(keys, dtype=object),
            loss_history_values=np.array(loss_arrays, dtype=object),
        )

    def update_loss_plot():
        """Redraw losses.png (combined axis plus one axis per loss) and save the stats."""
        if not accelerator.is_main_process or not step_history:
            return
        x_values = np.array(step_history, dtype=np.int64)
        keys = [k for k, v in sorted(loss_history.items()) if v]
        if not keys:
            return

        def align_series(series: List[float]) -> np.ndarray:
            # Make the series as long as x_values: keep the most recent
            # entries, or left-pad with NaN when it is shorter.
            y_vals = np.array(series, dtype=np.float32)
            if len(y_vals) > len(x_values):
                y_vals = y_vals[-len(x_values):]
            elif len(y_vals) < len(x_values):
                pad = np.full(len(x_values) - len(y_vals), np.nan, dtype=np.float32)
                y_vals = np.concatenate([pad, y_vals])
            return y_vals

        fig_height = 3 * (len(keys) + 1)
        fig, axes = plt.subplots(len(keys) + 1, 1, figsize=(10, fig_height), sharex=True)
        if not isinstance(axes, np.ndarray):
            axes = np.array([axes])

        cmap = plt.get_cmap('tab10', len(keys))

        aggregated_ax = axes[0]
        aggregated_ax.set_title("Training Losses (all)")
        aggregated_ax.set_ylabel("Loss")
        aggregated_ax.grid(True, alpha=0.3)

        for idx, key in enumerate(keys):
            y_aligned = align_series(loss_history[key])
            if np.all(np.isnan(y_aligned)):
                continue
            color = cmap(idx % cmap.N)
            aggregated_ax.plot(x_values, y_aligned, label=key, color=color)
            ax = axes[idx + 1]
            ax.plot(x_values, y_aligned, color=color)
            ax.set_ylabel(key)
            ax.grid(True, alpha=0.3)

        axes[-1].set_xlabel("Step")
        aggregated_ax.legend()
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
        save_training_stats()

    if accelerator.is_main_process and step_history:
        update_loss_plot()

    global_step = start_step

    with tqdm(total=max_steps, initial=start_step, disable=not accelerator.is_local_main_process, position=0, leave=True) as pbar:
        while global_step < max_steps:
            batches_this_epoch = 0
            for batch in train_dataloader:
                batches_this_epoch += 1

                # Forward and backward; the optimizer step happens at the end
                # of this iteration.
                losses = trainer.train_step(batch, global_step)

                if accelerator.is_local_main_process:
                    pbar.update(1)

                    display_losses = losses.copy()
                    display_losses['fps'] = losses['inference_fps']
                    pbar.set_postfix(display_losses)

                    loss_str = f"Step {global_step}: "
                    for k, v in losses.items():
                        if k not in ['inference_fps', 'frames_predicted']:
                            loss_str += f"{k}={v:.4f} "
                    loss_str += f"| {losses['frames_predicted']} frames @ {losses['inference_fps']:.1f} fps (training speed, inference faster)"
                    tqdm.write(loss_str)

                if accelerator.is_main_process:
                    step_history.append(global_step)
                    for k, v in losses.items():
                        if k in ['inference_fps', 'frames_predicted']:
                            continue
                        loss_history[k].append(v)
                    current_time = time.time()
                    # Redraw the plot at most once every 10 seconds.
                    if current_time - last_plot_time >= 10:
                        update_loss_plot()
                        last_plot_time = current_time

                save_condition = (global_step == 5) or (global_step > 0 and global_step % config['training']['save_generation']['save_interval'] == 0)
                if save_condition:
                    if accelerator.is_main_process:
                        generation_save_dir = Path(args.output_dir)
                        generation_save_dir.mkdir(parents=True, exist_ok=True)

                        current_train_loss = losses['total']
                        val_loss = compute_validation_loss(eval_max_batches)

                        # Copy the state to CPU first so the background
                        # writer does not share tensors with the live model.
                        model_state = recursive_to_cpu(accelerator.get_state_dict(model))
                        optimizer_state = recursive_to_cpu(optimizer.state_dict())
                        payload = {
                            'step': global_step,
                            'model_state_dict': model_state,
                            'optimizer_state_dict': optimizer_state,
                            'config': config,
                        }

                        latest_checkpoint_path = checkpoint_dir / "latest.pt"
                        save_checkpoint_async(latest_checkpoint_path, dict(payload))
                        save_training_stats()

                        # Validation loss decides "best" when available;
                        # otherwise fall back to the training loss.
                        is_new_best = False
                        if val_loss is not None:
                            if val_loss < best_val_loss:
                                best_val_loss = val_loss
                                best_train_loss = min(best_train_loss, current_train_loss)
                                is_new_best = True
                        else:
                            if current_train_loss < best_train_loss:
                                best_train_loss = current_train_loss
                                is_new_best = True

                        if is_new_best:
                            best_checkpoint_path = checkpoint_dir / "best.pt"
                            save_checkpoint_async(best_checkpoint_path, dict(payload))
                            save_training_stats()
                            if val_loss is not None:
                                log_message(f"New best checkpoint at step {global_step}: train_loss={current_train_loss:.6f}, val_loss={val_loss:.6f}")
                            else:
                                log_message(f"New best checkpoint at step {global_step}: train_loss={current_train_loss:.6f}")

                            if config['training']['save_generation']['enabled']:
                                # Generate samples in eval mode, as
                                # evaluate_batch does, so train-only layers
                                # such as dropout do not perturb them; the
                                # previous mode is restored afterwards.
                                was_training = model.training
                                model.eval()
                                try:
                                    with torch.no_grad():
                                        val_batch = next(iter(val_dataloader))
                                        texts = val_batch['text'][:5]
                                        val_objects = val_batch['objects'][:5]
                                        val_world = val_batch['world'][:5]
                                        val_physics = val_batch.get('physics')
                                        if val_physics is not None:
                                            val_physics = val_physics[:5]
                                        else:
                                            val_physics = torch.zeros_like(val_objects[:, 0, :, :3])
                                        val_num_frames = val_objects.shape[1]
                                        anchor_idx = trainer._select_anchor_frame(val_num_frames)
                                        predictions, generated_indices, _ = trainer._generate_full_sequence(
                                            text=texts,
                                            objects=val_objects,
                                            world=val_world,
                                            physics=val_physics,
                                            teacher_prob=0.0,
                                            anchor_idx=anchor_idx,
                                        )
                                finally:
                                    if was_training:
                                        model.train()

                                val_batch_cpu = recursive_to_cpu(val_batch)
                                predictions_cpu = [{
                                    'objects': pred['objects'].detach().cpu(),
                                    'world': pred['world'].detach().cpu(),
                                    'physics': pred['physics'].detach().cpu(),
                                } for pred in predictions]
                                targets_cpu = {
                                    'objects': val_objects.detach().cpu(),
                                    'world': val_world.detach().cpu(),
                                    'physics': val_physics.detach().cpu(),
                                }
                                metadata = {
                                    'sequence_names': val_batch.get('sequence_names', None)[:5] if 'sequence_names' in val_batch else None,
                                    'generated_indices': generated_indices,
                                }
                                save_generation_async(
                                    predictions=predictions_cpu,
                                    targets=targets_cpu,
                                    texts=list(texts),
                                    step=global_step,
                                    save_config=config['training']['save_generation'],
                                    metadata=metadata,
                                    batch_data=val_batch_cpu,
                                    data_root=args.data_root,
                                    data_split='validation'
                                )
                        else:
                            msg = f"No improvement at step {global_step}: train_loss={current_train_loss:.6f}"
                            if val_loss is not None:
                                msg += f", val_loss={val_loss:.6f}"
                            log_message(msg)

                if accelerator.sync_gradients:
                    clip_val = config['training'].get('gradient_clip_val', 1.0)
                    accelerator.clip_grad_norm_(model.parameters(), max_norm=clip_val)

                optimizer.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step >= max_steps:
                    break

            if batches_this_epoch == 0:
                # Without this check an empty dataloader would spin in the
                # outer while-loop forever.
                raise RuntimeError("Training dataloader yielded no batches; check --data_root and the data config")

    if accelerator.is_main_process:
        update_loss_plot()

    executor.shutdown(wait=True)
    # Every queued task has finished by now; report any that failed.
    cleanup_futures()

    if accelerator.is_main_process:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        final_checkpoint_path = checkpoint_dir / f"step{global_step}_final.pt"

        # Save the unwrapped model's weights. The prepared model may be a
        # DistributedDataParallel wrapper whose state-dict keys carry a
        # "module." prefix, which the resume path (it loads into the
        # unwrapped model) cannot read.
        torch.save({
            'step': global_step,
            'model_state_dict': accelerator.unwrap_model(model).state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': config,
        }, final_checkpoint_path)

        # NOTE: best.pt is re-pointed at the final checkpoint, replacing the
        # checkpoint selected during training.
        best_path = checkpoint_dir / "best.pt"
        if best_path.exists() or best_path.is_symlink():
            best_path.unlink()
        best_path.symlink_to(final_checkpoint_path.name)

        log_message(f"Saved final checkpoint: {final_checkpoint_path}")
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|