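"""Training entry point for GeneralistComfortDT.

Trains a decision-transformer-style comfort model on trajectory episodes
stored as .npz files, with mixed-precision training, a warmup + cosine LR
schedule, resumable checkpointing, buffered CSV metrics, and end-of-run plots.
"""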
import glob
import json
import math
import os
import time
import traceback

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import dataloader as dl
from embeddings import GeneralistComfortDT
from losses import compute_generalist_loss, GeneralistLossConfig
import plots


# ----- Paths -----
DATA_DIR = "TrajectoryData_from_docker"
RUNS_DIR = "training-runs"

# ----- Model -----
VOCAB_SIZE = 512
D_MODEL = 256
N_LAYERS = 6
N_HEADS = 8
DROPOUT = 0.1
MAX_ZONES = 32

# ----- Optimization -----
BATCH_SIZE = 16
EPOCHS = 50
LR = 3e-4
MIN_LR = 5e-5  # floor of the cosine decay
WARMUP_STEPS = 1000
WEIGHT_DECAY = 1e-2
GRAD_CLIP = 1.0

# ----- Sequence dimensions -----
MAX_TOKENS_PER_STEP = 64
CONTEXT_LEN = 48
CONTEXT_DIM = 10
RTG_DIM = 2

# ----- Loss weights -----
W_ACTION = 1.0
W_PHYSICS = 1.0
W_VALUE = 1.0

# ----- Trajectory selection (applied inside dl.GeneralistDataset) -----
USE_TOPK = True
TOPK_FRACTION = 1.0
TOPK_MODE = "filter"
TOPK_ON = "pareto"
RTG_SCALE = 1.0

# Probability of masking the return-to-go conditioning during training.
RTG_DROPOUT_PROB = 0.2

SEED = 42
NUM_WORKERS = 12

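# Note: USE_TOPK is not read in this script; the top-k arguments above are
# passed straight to dl.GeneralistDataset, which presumably interprets
# TOPK_MODE/TOPK_ON (e.g. filtering to the Pareto-best fraction of episodes).
# RTG_SCALE multiplies returns-to-go before they are fed to the model, and
# RTG_DROPOUT_PROB randomly masks that conditioning signal during training
# (the per-epoch debug pass runs with it forced to 0.0).

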
def set_seed(s):
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    np.random.seed(s)


def list_episode_npzs(data_dir: str):
    # Recursively collect episode files, skipping stats and cache artifacts.
    pattern = os.path.join(data_dir, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz")
    paths = sorted(glob.glob(pattern, recursive=True))
    return [p for p in paths if "norm_stats" not in p and "cache" not in p]

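# Example of a matched path (hypothetical, based on the glob pattern above):
#   TrajectoryData_from_docker/TrajectoryData_officesmall/run1/traj_ep003_seed42.npz

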
def load_checkpoint_if_available(run_dir, model, opt, scaler, device):
    last_path = os.path.join(run_dir, "last.pt")
    if not os.path.exists(last_path):
        return 1, 0  # fresh run: start at epoch 1, step 0
    ckpt = torch.load(last_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["opt"])
    scaler.load_state_dict(ckpt["scaler"])
    start_epoch = int(ckpt.get("epoch", 0)) + 1
    global_step = int(ckpt.get("global_step", 0))
    print(f"[Resume] Loaded {last_path} | start_epoch={start_epoch} global_step={global_step}")
    return start_epoch, global_step


def save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, name):
    ckpt = {
        "epoch": epoch,
        "global_step": global_step,
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
    }
    torch.save(ckpt, os.path.join(run_dir, name))


def get_run_dir():
    os.makedirs(RUNS_DIR, exist_ok=True)
    # Advance past any existing run_NNN directories so a deleted run can
    # never cause a new run to silently reuse an old directory.
    idx = len(glob.glob(os.path.join(RUNS_DIR, "run_*"))) + 1
    while os.path.exists(os.path.join(RUNS_DIR, f"run_{idx:03d}")):
        idx += 1
    path = os.path.join(RUNS_DIR, f"run_{idx:03d}")
    os.makedirs(path)
    os.makedirs(os.path.join(path, "plots"), exist_ok=True)
    return path


def _atomic_write_json(path, obj):
    # Write to a temp file then rename, so readers never see a partial file.
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)


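# main() flow: seeding and device setup -> fresh run directory -> dataset and
# DataLoader -> model + AdamW + GradScaler (resuming from last.pt if present)
# -> epoch loop (warmup+cosine LR, AMP forward/backward, buffered CSV metrics)
# -> per-epoch debug forward pass dumped to plot_data/ -> checkpoints, epoch
# aggregates, and plots.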
def main():
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    # TF32 matmuls trade a little precision for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    run_dir = get_run_dir()
    os.makedirs(os.path.join(run_dir, "plot_data"), exist_ok=True)

    report_path = os.path.join(run_dir, "report.json")
    metrics_csv = os.path.join(run_dir, "metrics.csv")

    hist = {"step": [], "loss": [], "acc": [], "phy": [], "val": [], "lr": [], "grad_norm": [], "loss_action": []}
    epoch_hist = {"epoch": [], "loss_mean": [], "acc_mean": [], "phy_mean": [], "val_mean": []}

    report = {
        "run_dir": run_dir,
        "started_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "config": {
            "DATA_DIR": DATA_DIR, "MAX_TOKENS": MAX_TOKENS_PER_STEP,
            "BATCH_SIZE": BATCH_SIZE, "LR": LR, "SEED": SEED,
        },
        "status": "running",
        "progress": {"epoch": 0, "global_step": 0},
    }
    _atomic_write_json(report_path, report)

    try:
        print(f"Loading data from {DATA_DIR}...")
        all_paths = list_episode_npzs(DATA_DIR)
        if not all_paths:
            raise RuntimeError(f"No valid npz files found in {DATA_DIR}")

        train_ds = dl.GeneralistDataset(
            all_paths, seed=SEED,
            max_tokens=MAX_TOKENS_PER_STEP,
            topk_frac=TOPK_FRACTION,
            topk_mode=TOPK_MODE,
            topk_on=TOPK_ON,
        )
        train_ds.is_train = True
        train_loader = DataLoader(
            train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
            pin_memory=use_cuda, persistent_workers=True,
            prefetch_factor=4, collate_fn=dl.generalist_collate_fn, drop_last=True,
        )

        model_config = {
            "VOCAB_SIZE": VOCAB_SIZE, "D_MODEL": D_MODEL,
            "N_LAYERS": N_LAYERS, "N_HEADS": N_HEADS,
            "DROPOUT": DROPOUT, "MAX_ZONES": MAX_ZONES,
            "CONTEXT_LEN": CONTEXT_LEN,
            "NUM_ACTION_BINS": dl.NUM_ACTION_BINS,
            "CONTEXT_DIM": CONTEXT_DIM,
            "RTG_DIM": RTG_DIM,
        }

        model = GeneralistComfortDT(model_config).to(device)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"\n{'='*40}\nModel Params: {total_params:,}\n{'='*40}\n")

        opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        # GradScaler is a no-op on CPU so the script still runs without a GPU.
        scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)
        start_epoch, global_step = load_checkpoint_if_available(run_dir, model, opt, scaler, device)

        loss_cfg = GeneralistLossConfig(
            w_action=W_ACTION,
            w_physics=W_PHYSICS,
            w_value=W_VALUE,
            use_rtg_weighting=True,
            rtg_weight_mode="exp",
            rtg_weight_beta=2.0,
        )

        _atomic_write_json(os.path.join(run_dir, "model_config.json"), model_config)

        total_steps = len(train_loader) * EPOCHS
        print(f"Starting Training | Steps: {total_steps}")

        csv_header = ["timestamp", "epoch", "step", "loss", "loss_action", "accuracy", "loss_physics", "loss_value", "lr", "grad_norm"]
        csv_buffer = []

        def to_float(x):
            # Metrics may arrive as 0-dim tensors or plain floats.
            return x.item() if torch.is_tensor(x) else float(x)

        def flush_csv():
            # Metrics are buffered and appended in chunks to limit file I/O.
            nonlocal csv_buffer
            if not csv_buffer:
                return
            write_header = not os.path.exists(metrics_csv)
            with open(metrics_csv, "a") as f:
                if write_header:
                    f.write(",".join(csv_header) + "\n")
                for row in csv_buffer:
                    f.write(",".join(str(row.get(k, "")) for k in csv_header) + "\n")
            csv_buffer = []

        for epoch in range(start_epoch, EPOCHS + 1):
            model.train()
            train_ds.set_epoch(epoch)
            pbar = tqdm(train_loader, desc=f"Ep {epoch}", dynamic_ncols=True)
            stats = {"loss": [], "acc": [], "phy": [], "val": []}

            for batch in pbar:
                # Linear warmup to LR for WARMUP_STEPS, then cosine decay
                # from LR down to MIN_LR over the remainder of the run.
                curr_lr = MIN_LR + 0.5 * (LR - MIN_LR) * (1 + math.cos(math.pi * global_step / total_steps))
                if global_step < WARMUP_STEPS:
                    curr_lr = LR * (global_step / WARMUP_STEPS)
                for pg in opt.param_groups:
                    pg["lr"] = curr_lr

                b_gpu = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
                # Scale returns-to-go before conditioning the model on them.
                rtg_input = b_gpu["rtg"] * RTG_SCALE

                with torch.amp.autocast("cuda", enabled=use_cuda):
                    out = model(
                        feature_ids=b_gpu["feature_ids"],
                        feature_vals=b_gpu["feature_values"],
                        zone_ids=b_gpu["zone_ids"],
                        attn_mask=b_gpu["attention_mask"],
                        rtg=rtg_input,
                        context=b_gpu["context"],
                        rtg_dropout_prob=RTG_DROPOUT_PROB,
                    )
                    loss, metrics = compute_generalist_loss(out, b_gpu, loss_cfg)

                # Unscale before clipping so GRAD_CLIP applies to the true
                # gradient magnitudes rather than the loss-scaled ones.
                opt.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(opt)
                scaler.update()
                if global_step % 500 == 0:
                    print(f"DEBUG: Step {global_step} | Grad Norm: {grad_norm:.4f} | LR: {curr_lr:.2e}")

                global_step += 1

                stats["loss"].append(to_float(metrics.get("total_loss", 0.0)))
                stats["acc"].append(to_float(metrics.get("accuracy", 0.0)))
                stats["phy"].append(to_float(metrics.get("loss_physics", 0.0)))
                stats["val"].append(to_float(metrics.get("loss_value", 0.0)))
                hist["loss_action"].append(to_float(metrics.get("loss_action", 0.0)))

                hist["step"].append(global_step)
                hist["loss"].append(stats["loss"][-1])
                hist["acc"].append(stats["acc"][-1])
                hist["phy"].append(stats["phy"][-1])
                hist["val"].append(stats["val"][-1])
                hist["lr"].append(curr_lr)
                hist["grad_norm"].append(to_float(grad_norm))

                csv_buffer.append({
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "epoch": epoch, "step": global_step,
                    "loss": stats["loss"][-1],
                    "loss_action": hist["loss_action"][-1],
                    "accuracy": stats["acc"][-1],
                    "loss_physics": stats["phy"][-1], "loss_value": stats["val"][-1],
                    "lr": float(curr_lr), "grad_norm": hist["grad_norm"][-1],
                })

                if global_step % 50 == 0:
                    flush_csv()
                if global_step % 20 == 0:
                    pbar.set_postfix(
                        act=f"{hist['loss_action'][-1]:.2f}",
                        phy=f"{np.mean(stats['phy'][-20:]):.4f}",
                        val=f"{np.mean(stats['val'][-20:]):.2f}",
                        acc=f"{np.mean(stats['acc'][-20:]):.2f}",
                    )

            # End-of-epoch debug pass: run one batch without RTG dropout and
            # dump prediction/target distributions for offline inspection.
            model.eval()
            with torch.no_grad():
                debug_batch = next(iter(train_loader))
                b_debug = {k: v.to(device) for k, v in debug_batch.items()}
                rtg_input_debug = b_debug["rtg"] * RTG_SCALE

                out_debug = model(
                    feature_ids=b_debug["feature_ids"],
                    feature_vals=b_debug["feature_values"],
                    zone_ids=b_debug["zone_ids"],
                    attn_mask=b_debug["attention_mask"],
                    rtg=rtg_input_debug,
                    context=b_debug["context"],
                    rtg_dropout_prob=0.0,
                )

                logits = out_debug["action_logits"]
                pred_bins = torch.argmax(logits, dim=-1).cpu().numpy()
                target_bins = b_debug["target_action_tokens"].cpu().numpy()

                # Action tokens are selected by target_mask; RTG entries by time_mask.
                t_mask = b_debug["time_mask"].cpu().numpy().astype(bool)
                a_mask = b_debug["target_mask"].cpu().numpy().astype(bool)
                valid_preds = pred_bins[a_mask]
                valid_targets = target_bins[a_mask]
                target_rtg_raw = b_debug["rtg"].cpu().numpy()
                pred_rtg_raw = out_debug["return_preds"].cpu().numpy()
                valid_target_rtg = target_rtg_raw[t_mask]
                valid_pred_rtg = pred_rtg_raw[t_mask]

                # Overwritten every epoch; only the latest snapshot is kept.
                np.savez_compressed(
                    os.path.join(run_dir, "plot_data", "distributions.npz"),
                    target_actions=valid_targets,
                    pred_actions=valid_preds,
                    target_rtg=valid_target_rtg,
                    pred_rtg=valid_pred_rtg,
                )

            flush_csv()
            save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, "last.pt")
            if epoch % 5 == 0:
                save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, f"ckpt_{epoch}.pt")

            epoch_hist["epoch"].append(epoch)
            epoch_hist["loss_mean"].append(np.mean(stats["loss"]))
            epoch_hist["acc_mean"].append(np.mean(stats["acc"]))
            epoch_hist["phy_mean"].append(np.mean(stats["phy"]))
            epoch_hist["val_mean"].append(np.mean(stats["val"]))

            # Keep report.json's progress field in sync for external monitors.
            report["progress"] = {"epoch": epoch, "global_step": global_step}
            _atomic_write_json(report_path, report)

        try:
            plots.save_plot_arrays(run_dir, hist, epoch_hist)
            plots.make_plots(run_dir)
        except Exception as e:
            print(f"Plotting failed: {e}")

        report["status"] = "complete"
        _atomic_write_json(report_path, report)
        print("Training Complete.")

    except Exception as e:
        # Record the failure both in the report and in a dedicated crash file.
        report["status"] = "crashed"
        _atomic_write_json(report_path, report)
        _atomic_write_json(os.path.join(run_dir, "crash.json"), {"error": str(e), "traceback": traceback.format_exc()})
        raise


if __name__ == "__main__":
    main()