#train.py import os import time import math import glob import json import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm import traceback import matplotlib.pyplot as plt from collections import Counter # --- New Modules --- import dataloader as dl from embeddings import GeneralistComfortDT from losses import compute_generalist_loss, GeneralistLossConfig import plots # ============================================================ # CONFIGURATION # ============================================================ DATA_DIR = "TrajectoryData_from_docker" RUNS_DIR = "training-runs" # Architecture VOCAB_SIZE = 512 D_MODEL = 256 N_LAYERS = 6 N_HEADS = 8 DROPOUT = 0.1 MAX_ZONES = 32 # Training BATCH_SIZE = 16 EPOCHS = 50 LR = 3e-4 WARMUP_STEPS = 1000 WEIGHT_DECAY = 1e-2 GRAD_CLIP = 1.0 MAX_TOKENS_PER_STEP = 64 CONTEXT_LEN = 48 CONTEXT_DIM = 10 RTG_DIM = 2 # Energy + Comfort # Loss Weights W_ACTION = 1.0 W_PHYSICS = 1.0 W_VALUE = 1.0 # Generalist Stitching Config USE_TOPK = True TOPK_FRACTION = 1.0 TOPK_MODE = "filter" TOPK_ON = "pareto" RTG_SCALE = 1.0 # Robustness RTG_DROPOUT_PROB = 0.2 SEED = 42 NUM_WORKERS = 12 # ============================================================ # UTILITIES # ============================================================ def set_seed(s): torch.manual_seed(s) torch.cuda.manual_seed_all(s) np.random.seed(s) def list_episode_npzs(data_dir: str): paths = sorted(glob.glob(os.path.join(DATA_DIR, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz"), recursive=True)) paths = [p for p in paths if "norm_stats" not in p and "cache" not in p] return paths def load_checkpoint_if_available(run_dir, model, opt, scaler, device): last_path = os.path.join(run_dir, "last.pt") if not os.path.exists(last_path): return 1, 0 ckpt = torch.load(last_path, map_location=device) model.load_state_dict(ckpt["model"]) opt.load_state_dict(ckpt["opt"]) scaler.load_state_dict(ckpt["scaler"]) start_epoch = int(ckpt.get("epoch", 0)) + 1 global_step = int(ckpt.get("global_step", 0)) print(f"[Resume] Loaded {last_path} | start_epoch={start_epoch} global_step={global_step}") return start_epoch, global_step def save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, name): ckpt = { "epoch": epoch, "global_step": global_step, "model": model.state_dict(), "opt": opt.state_dict(), "scaler": scaler.state_dict(), } torch.save(ckpt, os.path.join(run_dir, name)) def get_run_dir(): os.makedirs(RUNS_DIR, exist_ok=True) existing = len(glob.glob(os.path.join(RUNS_DIR, "run_*"))) path = os.path.join(RUNS_DIR, f"run_{existing+1:03d}") os.makedirs(path, exist_ok=True) os.makedirs(os.path.join(path, "plots"), exist_ok=True) return path def _atomic_write_json(path, obj): tmp = path + ".tmp" with open(tmp, "w") as f: json.dump(obj, f, indent=2) os.replace(tmp, path) # ============================================================ # MAIN LOOP # ============================================================ def main(): set_seed(SEED) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") run_dir = get_run_dir() os.makedirs(os.path.join(run_dir, "plot_data"), exist_ok=True) report_path = os.path.join(run_dir, "report.json") metrics_csv = os.path.join(run_dir, "metrics.csv") hist = {"step": [], "loss": [], "acc": [], "phy": [], "val": [], "lr": [], "grad_norm": [], "loss_action": []} epoch_hist = {"epoch": [], "loss_mean": [], "acc_mean": [], "phy_mean": [], "val_mean": []} report = { "run_dir": run_dir, "started_at": time.strftime("%Y-%m-%d %H:%M:%S"), "config": { "DATA_DIR": DATA_DIR, "MAX_TOKENS": MAX_TOKENS_PER_STEP, "BATCH_SIZE": BATCH_SIZE, "LR": LR, "SEED": SEED }, "status": "running", "progress": {"epoch": 0, "global_step": 0}, } _atomic_write_json(report_path, report) try: print(f"Loading data from {DATA_DIR}...") all_paths = list_episode_npzs(DATA_DIR) if not all_paths: raise RuntimeError(f"No valid npz files found in {DATA_DIR}") train_ds = dl.GeneralistDataset( all_paths, seed=SEED, max_tokens=MAX_TOKENS_PER_STEP, topk_frac=TOPK_FRACTION, topk_mode=TOPK_MODE, topk_on=TOPK_ON ) train_ds.is_train = True train_loader = DataLoader( train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, pin_memory_device="cuda", persistent_workers=True, prefetch_factor=4, collate_fn=dl.generalist_collate_fn, drop_last=True ) model_config = { "VOCAB_SIZE": VOCAB_SIZE, "D_MODEL": D_MODEL, "N_LAYERS": N_LAYERS, "N_HEADS": N_HEADS, "DROPOUT": DROPOUT, "MAX_ZONES": MAX_ZONES, "CONTEXT_LEN": CONTEXT_LEN, "NUM_ACTION_BINS": dl.NUM_ACTION_BINS, "CONTEXT_DIM": CONTEXT_DIM, "RTG_DIM": RTG_DIM } model = GeneralistComfortDT(model_config).to(device) total_params = sum(p.numel() for p in model.parameters()) print(f"\n{'='*40}\nModel Params: {total_params:,}\n{'='*40}\n") opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY) scaler = torch.amp.GradScaler("cuda") start_epoch, global_step = load_checkpoint_if_available(run_dir, model, opt, scaler, device) loss_cfg = GeneralistLossConfig( w_action=W_ACTION, w_physics=W_PHYSICS, w_value=W_VALUE, use_rtg_weighting=True, rtg_weight_mode="exp", rtg_weight_beta=2.0 ) _atomic_write_json(os.path.join(run_dir, "model_config.json"), model_config) total_steps = len(train_loader) * EPOCHS print(f"Starting Training | Steps: {total_steps}") csv_header = ["timestamp", "epoch", "step", "loss", "loss_action", "accuracy", "loss_physics", "loss_value", "lr", "grad_norm"] csv_buffer = [] def flush_csv(): nonlocal csv_buffer if not csv_buffer: return write_header = not os.path.exists(metrics_csv) with open(metrics_csv, "a") as f: if write_header: f.write(",".join(csv_header) + "\n") for row in csv_buffer: f.write(",".join(str(row.get(k, "")) for k in csv_header) + "\n") csv_buffer = [] for epoch in range(start_epoch, EPOCHS + 1): model.train() train_ds.set_epoch(epoch) pbar = tqdm(train_loader, desc=f"Ep {epoch}", dynamic_ncols=True) stats = {"loss": [], "acc": [], "phy": [], "val": []} for batch in pbar: # 1. LR Schedule MIN_LR = 5e-5 curr_lr = MIN_LR + 0.5 * (LR - MIN_LR) * (1 + math.cos(math.pi * global_step / total_steps)) # Warmup check stays the same if global_step < WARMUP_STEPS: curr_lr = LR * (global_step / WARMUP_STEPS) for pg in opt.param_groups: pg['lr'] = curr_lr b_gpu = {k: v.to(device, non_blocking=True) for k, v in batch.items()} # 2. RTG Prep # rtg is [B, T, 2] (Energy, Comfort) rtg_input = b_gpu["rtg"] * RTG_SCALE with torch.amp.autocast("cuda"): out = model( feature_ids=b_gpu["feature_ids"], feature_vals=b_gpu["feature_values"], zone_ids=b_gpu["zone_ids"], attn_mask=b_gpu["attention_mask"], rtg=rtg_input, context=b_gpu["context"], rtg_dropout_prob=RTG_DROPOUT_PROB ) # 3. Loss Calculation loss, metrics = compute_generalist_loss(out, b_gpu, loss_cfg) opt.zero_grad(set_to_none=True) scaler.scale(loss).backward() scaler.unscale_(opt) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP) scaler.step(opt) if global_step % 500 == 0: print(f"DEBUG: Step {global_step} | Grad Norm: {grad_norm:.4f} | LR: {curr_lr:.2e}") scaler.update() global_step += 1 # 5. Logging for k in ["loss_action", "loss_physics", "loss_value", "accuracy", "total_loss"]: val = metrics.get(k, 0.0) if torch.is_tensor(val): val = val.item() if k == "total_loss": stats["loss"].append(val) elif k == "accuracy": stats["acc"].append(val) elif k == "loss_physics": stats["phy"].append(val) elif k == "loss_value": stats["val"].append(val) elif k == "loss_action": hist["loss_action"].append(val) hist["step"].append(global_step) hist["loss"].append(stats["loss"][-1]) hist["acc"].append(stats["acc"][-1]) hist["phy"].append(stats["phy"][-1]) hist["val"].append(stats["val"][-1]) hist["lr"].append(curr_lr) hist["grad_norm"].append(float(grad_norm.item()) if torch.is_tensor(grad_norm) else grad_norm) csv_buffer.append({ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "epoch": epoch, "step": global_step, "loss": stats["loss"][-1], "loss_action": metrics.get("loss_action", 0.0).item() if torch.is_tensor(metrics.get("loss_action", 0.0)) else metrics.get("loss_action", 0.0), # <--- ADDED "accuracy": stats["acc"][-1], "loss_physics": stats["phy"][-1], "loss_value": stats["val"][-1], "lr": float(curr_lr), "grad_norm": hist["grad_norm"][-1] }) if global_step % 50 == 0: flush_csv() if global_step % 20 == 0: pbar.set_postfix( act=f"{metrics.get('loss_action', 0):.2f}", # Action CE phy=f"{np.mean(stats['phy'][-20:]):.4f}", # Physics Delta MSE val=f"{np.mean(stats['val'][-20:]):.2f}", # Rescaled Value MSE acc=f"{np.mean(stats['acc'][-20:]):.2f}" ) model.eval() with torch.no_grad(): try: debug_batch = next(iter(train_loader)) except StopIteration: debug_batch = next(iter(train_loader)) b_debug = {k: v.to(device) for k, v in debug_batch.items()} rtg_input_debug = b_debug["rtg"] * RTG_SCALE # 3. Forward Pass out_debug = model( feature_ids=b_debug["feature_ids"], feature_vals=b_debug["feature_values"], zone_ids=b_debug["zone_ids"], attn_mask=b_debug["attention_mask"], rtg=rtg_input_debug, context=b_debug["context"], rtg_dropout_prob=0.0 ) # 4. Process Data logits = out_debug["action_logits"] pred_bins = torch.argmax(logits, dim=-1).cpu().numpy() target_bins = b_debug["target_action_tokens"].cpu().numpy() # Create masks # [B, T, K] -> [B, T] t_mask = b_debug["time_mask"].cpu().numpy().astype(bool) # [B, T] # [B, T, K] for actions a_mask = b_debug["target_mask"].cpu().numpy().astype(bool) # [B, T, K] valid_preds = pred_bins[a_mask] valid_targets = target_bins[a_mask] target_rtg_raw = b_debug["rtg"].cpu().numpy() pred_rtg_raw = out_debug["return_preds"].cpu().numpy() valid_target_rtg = target_rtg_raw[t_mask] valid_pred_rtg = pred_rtg_raw[t_mask] np.savez_compressed( os.path.join(run_dir, "plot_data", "distributions.npz"), target_actions=valid_targets, pred_actions=valid_preds, target_rtg=valid_target_rtg, pred_rtg=valid_pred_rtg ) # ==================================== flush_csv() save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, "last.pt") if epoch % 5 == 0: save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, f"ckpt_{epoch}.pt") epoch_hist["epoch"].append(epoch) epoch_hist["loss_mean"].append(np.mean(stats["loss"])) epoch_hist["acc_mean"].append(np.mean(stats["acc"])) epoch_hist["phy_mean"].append(np.mean(stats["phy"])) epoch_hist["val_mean"].append(np.mean(stats["val"])) try: plots.save_plot_arrays(run_dir, hist, epoch_hist) plots.make_plots(run_dir) except Exception as e: print(f"Plotting failed: {e}") report["status"] = "complete" _atomic_write_json(report_path, report) print("Training Complete.") except Exception as e: _atomic_write_json(os.path.join(run_dir, "crash.json"), {"error": str(e), "traceback": traceback.format_exc()}) raise if __name__ == "__main__": main()