| | import os |
| | import sys |
| | import yaml |
| | import argparse |
| | import wandb |
| | import torch |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| | from torch.utils.data import DataLoader, DistributedSampler, Subset |
| | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | from tqdm import tqdm |
| | import time |
| | import numpy as np |
| | import signal |
| |
|
| | |
# Module-level training state. main() writes these as training progresses and
# signal_handler() reads them to save an emergency checkpoint.
_model = None  # model object handed to save_checkpoint(); set by main()
_optimizer = None  # optimizer whose state is checkpointed; set by main()
_step = 0  # most recent training step recorded by main()
_epoch = 0  # epoch currently being trained
_ckpt_dir = ""  # checkpoint directory; empty until main() sets it
_wandb_run_id = None  # wandb run id stored in checkpoints so a resume can reuse it
| |
|
def signal_handler(sig, frame):
    """Write an emergency checkpoint on SIGTERM (Slurm timeout/preemption), then exit.

    A checkpoint is only written once main() has registered a model and a
    checkpoint directory in the module-level state, and only by rank 0. The
    process always terminates with exit status 0.

    Args:
        sig: Number of the received signal; echoed in the log line.
        frame: Interrupted stack frame (unused).
    """
    state_ready = _model is not None and _ckpt_dir
    if state_ready:
        # Without an initialised process group this is treated as rank 0.
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank == 0:
            print(f"\n[SIGNAL {sig}] Saving emergency checkpoint at step {_step}...")
            ckpt_path = os.path.join(_ckpt_dir, f"checkpoint_signal_{_step}.pt")
            save_checkpoint(
                _model,
                _optimizer,
                _step,
                _epoch,
                ckpt_path,
                wandb_run_id=_wandb_run_id,
            )
            print(f"--- Emergency Checkpoint Saved: {ckpt_path} ---")
            wandb.finish()
    sys.exit(0)
| |
|
| | |
# Install the SIGTERM handler at import time so the signal is handled even
# before main() has finished setting up.
signal.signal(signal.SIGTERM, signal_handler)

# Put the directory two levels above this file on sys.path before the `wm`
# imports below (presumably the repository root that contains `wm`).
_this_dir = os.path.dirname(__file__)
_two_levels_up = os.path.abspath(os.path.join(_this_dir, "..", ".."))
sys.path.append(_two_levels_up)

from wm.model.interface import get_dynamics_class
from wm.dataset.dataset import RoboticsDatasetWrapper
from wm.utils.visualization import visualize_layout
| |
|
def setup_ddp():
    """Initialise torch.distributed when started by a distributed launcher.

    The presence of ``RANK`` in the environment is the switch: when it is set,
    the default process group is created and ``RANK``, ``LOCAL_RANK`` and
    ``WORLD_SIZE`` are read from the environment.

    Returns:
        tuple[int, int, int]: ``(rank, local_rank, world_size)``;
        ``(0, 0, 1)`` for a plain single-process run.
    """
    if 'RANK' not in os.environ:
        return 0, 0, 1

    use_cuda = torch.cuda.is_available()
    # NCCL needs CUDA devices; fall back to Gloo so CPU-only launches still
    # work (main() already falls back to a CPU device in that case).
    dist.init_process_group("nccl" if use_cuda else "gloo")
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    if use_cuda:
        torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
| |
|
def cleanup_ddp():
    """Destroy the default process group if one has been initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
| |
|
def _atomic_torch_save(obj, target_path):
    """Serialise *obj* to *target_path* without ever exposing a partial file."""
    tmp_path = f"{target_path}.tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace swaps the file in one step, so a reader sees either the
        # previous checkpoint or the complete new one.
        os.replace(tmp_path, target_path)
    finally:
        # Only still present if saving or replacing failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(model, optimizer, step, epoch, path, wandb_run_id=None, save_numbered=True):
    """Save training state to ``path`` (optional) and always to ``latest.pt``.

    Args:
        model: Model to save; if it has a ``.module`` attribute (wrapped
            model) the inner module's state dict is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        step: Training step recorded in the checkpoint.
        epoch: Epoch recorded in the checkpoint.
        path: Destination of the numbered checkpoint. May be ``None``/empty,
            in which case only ``latest.pt`` is written, into the module-level
            ``_ckpt_dir``.
        wandb_run_id: Run id stored so a resumed job can reuse it.
        save_numbered: When false, skip the file at ``path`` and only refresh
            ``latest.pt``.
    """
    unwrapped = model.module if hasattr(model, 'module') else model
    checkpoint = {
        'model_state_dict': unwrapped.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
        'epoch': epoch,
        'wandb_run_id': wandb_run_id,
    }
    if save_numbered and path:
        _atomic_torch_save(checkpoint, path)

    # latest.pt lives next to the numbered checkpoint, or in the registered
    # checkpoint directory when no path was given.
    ckpt_dir = os.path.dirname(path) if path else _ckpt_dir
    _atomic_torch_save(checkpoint, os.path.join(ckpt_dir, "latest.pt"))
| |
|
def load_checkpoint(model, optimizer, path, device):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: Target model; if it has a ``.module`` attribute the weights are
            loaded into that inner module.
        optimizer: Optimizer that receives the saved optimizer state.
        path: Checkpoint file. A missing file is not an error.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        tuple: ``(step, epoch, wandb_run_id)`` from the checkpoint, or
        ``(0, 0, None)`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return 0, 0, None

    print(f"--- Loading Checkpoint from {path} ---")
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    state_dict = checkpoint['model_state_dict']

    # These scheduler buffers are dropped from the checkpoint, so the model
    # keeps whatever values it already holds for them (strict=False below
    # tolerates the missing keys).
    scheduler_buffers = (
        'scheduler.sigmas',
        'scheduler.timesteps',
        'scheduler.linear_timesteps_weights',
    )
    for name in scheduler_buffers:
        state_dict.pop(name, None)
        state_dict.pop(f"module.{name}", None)

    if hasattr(model, 'module'):
        model.module.load_state_dict(state_dict, strict=False)
    else:
        # The checkpoint may carry a wrapper's "module." prefix. Strip only
        # the LEADING prefix: str.replace would also mangle keys that contain
        # "module." further in (e.g. "enc.module.weight"), and strict=False
        # would then silently skip those weights.
        prefix = 'module.'
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        model.load_state_dict(state_dict, strict=False)

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['step'], checkpoint['epoch'], checkpoint.get('wandb_run_id')
| |
|
def log_videos_to_wandb(model, val_loader, device, step, dataset_name, gen_mode="parallel", num_inference_steps=50):
    """Generate rollouts for one validation batch and log GT-vs-prediction videos to wandb.

    The model is put in eval mode for the duration of the call and is always
    switched back to train mode afterwards, including when the loader is empty
    or generation raises.

    Args:
        model: Model, optionally wrapped (``.module`` is unwrapped); must
            provide ``generate``.
        val_loader: Iterable of batches with ``'obs'`` and ``'action'`` tensors.
        device: Device the batch tensors are moved to.
        step: Training step used as the wandb step and in the captions.
        dataset_name: Passed through to ``visualize_layout``.
        gen_mode: Forwarded to ``generate`` as ``mode``.
        num_inference_steps: Forwarded to ``generate``.

    Returns:
        None. Nothing is logged when ``val_loader`` yields no batch.
    """
    model.eval()
    video_logs = []

    try:
        with torch.no_grad():
            try:
                batch = next(iter(val_loader))
            except StopIteration:
                return

            obs = batch['obs'].to(device)
            action = batch['action'].to(device)

            # First frame of every sequence with axis 1 moved last
            # (presumably channels-first -> channels-last; confirm against
            # the dataset).
            o_0 = obs[:, 0].permute(0, 2, 3, 1).contiguous()

            curr_model = model.module if hasattr(model, 'module') else model

            try:
                pred_video = curr_model.generate(o_0, action, mode=gen_mode, num_inference_steps=num_inference_steps)
            except TypeError:
                # Fallback for generate() implementations without the extra
                # keywords. NOTE(review): this also retries on a TypeError
                # raised inside generate itself.
                pred_video = curr_model.generate(o_0, action)

            # At most 8 samples per call.
            for b in range(min(obs.shape[0], 8)):
                action_b = action[b].cpu().numpy()
                gt_with_layout = visualize_layout(obs[b].cpu().numpy(), action_b, dataset_name)

                # Move the last axis of the prediction to position 1 so it
                # matches the layout of obs[b] that visualize_layout is given.
                pred_obs_b = pred_video[b].permute(0, 3, 1, 2).cpu().numpy()
                pred_with_layout = visualize_layout(pred_obs_b, action_b, dataset_name)

                # Join GT and prediction along axis 2 (assumed to be width of
                # a (T, H, W, C) array - TODO confirm), then move the last
                # axis to position 1 for wandb.Video.
                combined = np.concatenate([gt_with_layout, pred_with_layout], axis=2)
                combined = combined.transpose(0, 3, 1, 2)

                video_logs.append(wandb.Video(combined, fps=10, format="mp4", caption=f"Step {step} - Sample {b} (GT vs Pred)"))

        if video_logs:
            wandb.log({"val/videos": video_logs}, step=step)
    finally:
        # Previously skipped by the early return above, which left the caller
        # training a model stuck in eval mode.
        model.train()
| |
|
def evaluate_mse(model, val_loader, device, step, num_batches=1, gen_mode="parallel", num_inference_steps=50):
    """Compute the mean rollout MSE over a few validation batches and log it to wandb.

    The model is put in eval mode for the duration of the call and is always
    switched back to train mode afterwards, even if generation raises.

    Args:
        model: Model, optionally wrapped (``.module`` is unwrapped); must
            provide ``generate``.
        val_loader: Iterable of batches with ``'obs'`` and ``'action'`` tensors.
        device: Device the batch tensors are moved to.
        step: Training step used as the wandb step.
        num_batches: Maximum number of batches to evaluate.
        gen_mode: Forwarded to ``generate`` as ``mode``.
        num_inference_steps: Forwarded to ``generate``.

    Returns:
        float: Average of the per-batch MSE values, or ``0`` when no batch was
        evaluated. The same value is logged as ``val/mse_rollout``.
    """
    from itertools import islice

    model.eval()
    all_mse = []

    try:
        curr_model = model.module if hasattr(model, 'module') else model

        with torch.no_grad():
            # islice stops after num_batches items; enumerate + break would
            # pull one more batch from the loader only to discard it.
            for batch in islice(val_loader, num_batches):
                obs = batch['obs'].to(device)
                action = batch['action'].to(device)

                # First frame of every sequence with axis 1 moved last.
                o_0 = obs[:, 0].permute(0, 2, 3, 1).contiguous()

                try:
                    pred_video = curr_model.generate(o_0, action, mode=gen_mode, num_inference_steps=num_inference_steps)
                except TypeError:
                    # Fallback for generate() implementations without the
                    # extra keywords.
                    pred_video = curr_model.generate(o_0, action)

                # Move axis 2 of the ground truth last so it lines up with
                # pred_video (assumes generate returns that layout - TODO
                # confirm).
                gt_video = obs.permute(0, 1, 3, 4, 2).contiguous()

                mse = torch.mean((pred_video - gt_video) ** 2)
                all_mse.append(mse.item())

        avg_mse = sum(all_mse) / len(all_mse) if all_mse else 0
        wandb.log({"val/mse_rollout": avg_mse}, step=step)
    finally:
        model.train()
    return avg_mse
| |
|
def main():
    """Train a dynamics model as described by a YAML config.

    Parses the CLI, loads the config, sets up (optional) distributed training,
    builds the model and optimizer, optionally resumes from a checkpoint,
    splits the dataset into train/val by trajectory, and runs the training
    loop. Logging, evaluation, video logging and checkpointing happen on
    rank 0 only.

    CLI:
        --config     Path to the YAML config (required).
        --resume     Resume from ``checkpoints/<run_name>/latest.pt`` if present.
        --ckpt_path  Explicit checkpoint to resume from; takes precedence over
                     ``--resume``.
    """
    global _model, _optimizer, _step, _epoch, _ckpt_dir, _wandb_run_id
    import random

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to yaml config")
    parser.add_argument("--resume", action="store_true", help="Resume from latest checkpoint in wandb dir")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Explicit path to checkpoint to resume from")
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    rank, local_rank, world_size = setup_ddp()
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")

    # Checkpoints live under checkpoints/<run_name>; registered globally so
    # the SIGTERM handler can find the directory.
    ckpt_dir = os.path.join("checkpoints", config['wandb']['run_name'])
    os.makedirs(ckpt_dir, exist_ok=True)
    _ckpt_dir = ckpt_dir

    dynamics_class_name = config['dynamics_class']
    model_name = config['model_name']
    model_config = config['model_config']

    if rank == 0:
        print(f"--- Initializing Dynamics Model: {dynamics_class_name} ({model_name}) ---")

    dynamics_class = get_dynamics_class(dynamics_class_name)
    dynamics_model = dynamics_class(model_name, model_config).to(device)

    optimizer = torch.optim.AdamW(dynamics_model.parameters(), lr=float(config['training']['learning_rate']))
    _optimizer = optimizer

    # Resume: an explicit --ckpt_path wins; --resume alone falls back to
    # latest.pt in the run's checkpoint directory when it exists.
    start_step = 0
    start_epoch = 0
    wandb_run_id = None

    resume_path = args.ckpt_path
    if args.resume and resume_path is None:
        potential_latest = os.path.join(ckpt_dir, "latest.pt")
        if os.path.exists(potential_latest):
            resume_path = potential_latest

    if resume_path:
        # Loaded into the bare model, before any parallel wrapper is applied.
        start_step, start_epoch, wandb_run_id = load_checkpoint(dynamics_model, optimizer, resume_path, device)
        _wandb_run_id = wandb_run_id

    # A missing or empty `distributed` section simply means "no FSDP".
    use_fsdp = (config.get('distributed') or {}).get('use_fsdp', False)
    if use_fsdp:
        model = FSDP(dynamics_model)
    elif world_size > 1:
        # device_ids is only meaningful for CUDA modules.
        ddp_device_ids = [local_rank] if device.type == "cuda" else None
        model = DDP(dynamics_model, device_ids=ddp_device_ids, find_unused_parameters=True)
    else:
        model = dynamics_model
    _model = model

    if rank == 0:
        params = sum(p.numel() for p in dynamics_model.model.parameters() if p.requires_grad)
        print(f"Model Parameters: {params / 1e6:.2f}M")
        print("--- Distributed Setup Finished ---")
        print(f"World Size: {world_size}")
        print(f"Device: {device}")

    # wandb is initialised on rank 0 only.
    if rank == 0:
        # Export the key only when the config holds a real one; a missing,
        # empty or placeholder value must not clobber WANDB_API_KEY.
        api_key = config['wandb'].get('api_key')
        if api_key and api_key != "YOUR_WANDB_API_KEY_HERE":
            os.environ["WANDB_API_KEY"] = api_key

        wandb.init(
            project=config['wandb']['project'],
            name=config['wandb']['run_name'],
            config=config,
            id=wandb_run_id,
            resume="allow"
        )
        wandb_run_id = wandb.run.id
        _wandb_run_id = wandb_run_id

    dataset_name = config['dataset']['name']
    if rank == 0:
        print(f"--- Loading Dataset: {dataset_name} ---")

    train_seq_len = config['dataset'].get('train_seq_len', config['dataset'].get('seq_len'))
    eval_seq_len = config['dataset'].get('eval_seq_len', config['dataset'].get('seq_len'))
    gen_mode = config['training'].get('gen_mode', 'parallel')
    inference_steps = config['training'].get('inference_steps', 50)

    # Same dataset twice, windowed with the train and the eval sequence length.
    train_dataset_full = RoboticsDatasetWrapper.get_dataset(dataset_name, seq_len=train_seq_len)
    val_dataset_full = RoboticsDatasetWrapper.get_dataset(dataset_name, seq_len=eval_seq_len)

    # Split by trajectory id (first element of each index entry) so windows
    # from one trajectory never land in both train and val.
    unique_traj_ids = sorted({idx[0] for idx in train_dataset_full.indices})
    num_total_trajs = len(unique_traj_ids)

    split_ratio = config['dataset'].get('train_test_split', 10)
    num_val_trajs = max(1, num_total_trajs // (split_ratio + 1))

    # Fixed seed so every rank (and every resume) gets the same split.
    random.seed(42)
    random.shuffle(unique_traj_ids)

    val_traj_ids = set(unique_traj_ids[:num_val_trajs])

    train_indices = [
        i for i, (traj_idx, _start_f) in enumerate(train_dataset_full.indices)
        if traj_idx not in val_traj_ids
    ]
    val_indices = [
        i for i, (traj_idx, _start_f) in enumerate(val_dataset_full.indices)
        if traj_idx in val_traj_ids
    ]

    if rank == 0:
        print(f"Split: {len(train_indices)} train windows (T={train_seq_len}), "
              f"{len(val_indices)} val windows (T={eval_seq_len})")

    train_dataset = Subset(train_dataset_full, train_indices)
    val_dataset = Subset(val_dataset_full, val_indices)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=config['training']['num_workers'],
        pin_memory=True
    )

    # Validation is shuffled with its own seeded generator.
    val_g = torch.Generator()
    val_g.manual_seed(42)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        num_workers=config['training']['num_workers'],
        generator=val_g
    )

    # Loop-invariant settings, read once instead of on every step.
    train_cfg = config['training']
    num_epochs = train_cfg['num_epochs']
    grad_clip = train_cfg['grad_clip']
    log_freq = train_cfg['log_freq']
    val_freq = train_cfg['val_freq']
    eval_freq = train_cfg.get('eval_freq', 50)
    ckpt_freq = train_cfg.get('checkpoint_freq', 2000)
    latest_freq = train_cfg.get('latest_freq', 500)

    # NOTE(review): encode_obs/training_loss are called on the unwrapped
    # module, not through the DDP/FSDP wrapper's forward(). Confirm that
    # gradients are still synchronised across ranks in this setup.
    core_model = model.module if hasattr(model, 'module') else model

    step = start_step
    _step = step
    _epoch = start_epoch

    if rank == 0:
        print(f"--- Starting Training Loop: {num_epochs} Epochs (Starting from Epoch {start_epoch}, Step {start_step}) ---")

    for epoch in range(start_epoch, num_epochs):
        _epoch = epoch
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", disable=(rank != 0))

        last_step_end = time.time()
        for batch in pbar:
            # Time spent waiting for the loader since the previous step ended.
            data_time = time.time() - last_step_end

            obs = batch['obs'].to(device)
            action = batch['action'].to(device)

            optimizer.zero_grad()

            # Synchronise around each timed section so the wall-clock numbers
            # include the queued GPU work.
            if use_cuda:
                torch.cuda.synchronize()

            t_enc_start = time.time()

            # Observations are encoded without gradient tracking.
            with torch.no_grad():
                z = core_model.encode_obs(obs)
            if use_cuda:
                torch.cuda.synchronize()
            vae_time = time.time() - t_enc_start

            t_update_start = time.time()

            loss = core_model.training_loss(z, action)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            if use_cuda:
                torch.cuda.synchronize()
            update_time = time.time() - t_update_start

            step += 1

            if rank == 0:
                step_time = time.time() - last_step_end
                # One host sync for the loss value, reused below.
                loss_value = loss.item()
                pbar.set_postfix({
                    "loss": f"{loss_value:.4f}",
                    "dt": f"{data_time:.2f}s",
                    "vae": f"{vae_time:.2f}s",
                    "up": f"{update_time:.2f}s",
                    "st": f"{step_time:.2f}s"
                })

                if step % log_freq == 0:
                    print(f"Step {step} (Epoch {epoch}) | Loss: {loss_value:.4f} | "
                          f"Data: {data_time:.3f}s | VAE: {vae_time:.3f}s | "
                          f"Update: {update_time:.3f}s | Step: {step_time:.3f}s")
                    wandb.log({
                        "train/loss": loss_value,
                        "train/epoch": epoch,
                        "time/data_loading": data_time,
                        "time/vae_encoding": vae_time,
                        "time/model_update": update_time,
                        "time/seconds_per_step": step_time,
                    }, step=step)

                if step % eval_freq == 0:
                    print(f"\n--- Calculating Val MSE at Step {step} ---")
                    evaluate_mse(model, val_loader, device, step, num_batches=2, gen_mode=gen_mode, num_inference_steps=inference_steps)

                if step % val_freq == 0:
                    print(f"\n--- Logging Validation Videos at Step {step} ---")
                    log_videos_to_wandb(model, val_loader, device, step, dataset_name, gen_mode=gen_mode, num_inference_steps=inference_steps)
                    print("--- Video Logging Finished ---")

                # A numbered checkpoint also refreshes latest.pt, so the
                # latest-only save is needed only on the other steps.
                if step % ckpt_freq == 0:
                    ckpt_path = os.path.join(ckpt_dir, f"checkpoint_{step}.pt")
                    save_checkpoint(model, optimizer, step, epoch, ckpt_path, wandb_run_id=wandb_run_id, save_numbered=True)
                    print(f"--- Numbered Checkpoint Saved: {ckpt_path} ---")
                elif step % latest_freq == 0:
                    save_checkpoint(model, optimizer, step, epoch, None, wandb_run_id=wandb_run_id, save_numbered=False)
                    print(f"--- Latest Checkpoint Updated (Step {step}) ---")

            # Publish progress for the SIGTERM handler and restart the
            # data-loading timer.
            _step = step
            last_step_end = time.time()

        # Re-align all ranks at the end of every epoch.
        if world_size > 1:
            dist.barrier()

    if rank == 0:
        wandb.finish()
    cleanup_ddp()
| |
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|