# Source: Controller/training/training.py (Gen-HVAC repository)
# Commit 1641a08 (verified), "Upload 4 files", 14.3 kB
#train.py
import os
import time
import math
import glob
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import traceback
import matplotlib.pyplot as plt
from collections import Counter
# --- New Modules ---
import dataloader as dl
from embeddings import GeneralistComfortDT
from losses import compute_generalist_loss, GeneralistLossConfig
import plots
# ============================================================
# CONFIGURATION
# ============================================================

# --- Filesystem locations (relative to the working directory) ---
DATA_DIR: str = "TrajectoryData_from_docker"  # root searched for episode .npz files
RUNS_DIR: str = "training-runs"  # parent directory of numbered run_XXX folders

# --- Model architecture (forwarded to GeneralistComfortDT via model_config) ---
VOCAB_SIZE: int = 512
D_MODEL: int = 256
N_LAYERS: int = 6
N_HEADS: int = 8
DROPOUT: float = 0.1
MAX_ZONES: int = 32

# --- Optimisation ---
BATCH_SIZE: int = 16
EPOCHS: int = 50
LR: float = 3e-4  # peak learning rate reached at the end of warmup
WARMUP_STEPS: int = 1000
WEIGHT_DECAY: float = 1e-2  # AdamW weight decay
GRAD_CLIP: float = 1.0  # max global grad norm for clip_grad_norm_
MAX_TOKENS_PER_STEP: int = 64  # passed to the dataset as max_tokens
CONTEXT_LEN: int = 48
CONTEXT_DIM: int = 10
RTG_DIM: int = 2  # Energy + Comfort

# --- Loss weights (passed to GeneralistLossConfig) ---
W_ACTION: float = 1.0
W_PHYSICS: float = 1.0
W_VALUE: float = 1.0

# --- Generalist stitching: top-k trajectory selection in the dataset ---
# NOTE(review): USE_TOPK is not referenced by the training code in this file.
USE_TOPK: bool = True
TOPK_FRACTION: float = 1.0  # dataset topk_frac
TOPK_MODE: str = "filter"  # dataset topk_mode
TOPK_ON: str = "pareto"  # dataset topk_on
RTG_SCALE: float = 1.0  # multiplier applied to the rtg input before the model

# --- Robustness / reproducibility ---
RTG_DROPOUT_PROB: float = 0.2  # rtg_dropout_prob used for training forward passes
SEED: int = 42
NUM_WORKERS: int = 12  # DataLoader worker processes
# ============================================================
# UTILITIES
# ============================================================
def set_seed(s):
    """Seed every global RNG used by this script so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch on the CPU
    and on all CUDA devices. Previously ``random`` was left unseeded.

    Args:
        s: Integer seed value.
    """
    import random  # local import keeps the module's import block untouched

    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    # Per the PyTorch docs this is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(s)
def list_episode_npzs(data_dir: str, subdir: str = "TrajectoryData_officesmall"):
    """Return sorted paths of episode files ``traj_ep*_seed*.npz`` under ``data_dir``.

    Searches ``data_dir/subdir`` recursively. Files whose path *relative to
    ``data_dir``* contains ``"norm_stats"`` or ``"cache"`` are skipped. Matching
    is relative so that a data root that happens to live under e.g. ``~/.cache``
    is not filtered out wholesale.

    Args:
        data_dir: Dataset root. The original ignored this argument and always
            read the module-level ``DATA_DIR``.
        subdir: Sub-folder of ``data_dir`` to search.

    Returns:
        Sorted list of matching file paths (possibly empty).
    """
    pattern = os.path.join(data_dir, subdir, "**", "traj_ep*_seed*.npz")
    paths = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        rel = os.path.relpath(path, data_dir)
        if "norm_stats" in rel or "cache" in rel:
            continue
        paths.append(path)
    return paths
def load_checkpoint_if_available(run_dir, model, opt, scaler, device):
last_path = os.path.join(run_dir, "last.pt")
if not os.path.exists(last_path):
return 1, 0
ckpt = torch.load(last_path, map_location=device)
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["opt"])
scaler.load_state_dict(ckpt["scaler"])
start_epoch = int(ckpt.get("epoch", 0)) + 1
global_step = int(ckpt.get("global_step", 0))
print(f"[Resume] Loaded {last_path} | start_epoch={start_epoch} global_step={global_step}")
return start_epoch, global_step
def save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, name):
ckpt = {
"epoch": epoch,
"global_step": global_step,
"model": model.state_dict(),
"opt": opt.state_dict(),
"scaler": scaler.state_dict(),
}
torch.save(ckpt, os.path.join(run_dir, name))
def get_run_dir(runs_dir=None):
    """Create and return a new, uniquely numbered run directory.

    Directories are named ``run_001``, ``run_002``, ... The next index is one
    past the highest existing numeric suffix. The original used the *count* of
    ``run_*`` entries, which reused (and overwrote) an existing run once an
    older run had been deleted. A ``plots`` subdirectory is created inside the
    new run directory.

    Args:
        runs_dir: Parent directory for runs; ``None`` means ``RUNS_DIR``.

    Returns:
        Path of the freshly created run directory.
    """
    if runs_dir is None:
        runs_dir = RUNS_DIR
    os.makedirs(runs_dir, exist_ok=True)
    highest = 0
    for entry in os.listdir(runs_dir):
        suffix = entry[len("run_"):]
        if entry.startswith("run_") and suffix.isdecimal():
            highest = max(highest, int(suffix))
    next_idx = highest + 1
    while True:
        path = os.path.join(runs_dir, f"run_{next_idx:03d}")
        try:
            # exist_ok=False: never silently adopt another run's directory.
            os.makedirs(path)
            break
        except FileExistsError:
            next_idx += 1
    os.makedirs(os.path.join(path, "plots"), exist_ok=True)
    return path
def _atomic_write_json(path, obj):
tmp = path + ".tmp"
with open(tmp, "w") as f:
json.dump(obj, f, indent=2)
os.replace(tmp, path)
# ============================================================
# MAIN LOOP
# ============================================================
def _to_float(value):
    """Return ``value`` as a Python float, calling ``.item()`` first if it is a tensor."""
    return float(value.item()) if torch.is_tensor(value) else float(value)


def _safe_mean(values):
    """Return the mean of ``values`` as a float, or NaN for an empty list.

    ``np.mean([])`` also gives NaN but emits a RuntimeWarning; this avoids it.
    """
    return float(np.mean(values)) if values else float("nan")


def _lr_at(step, total_steps, min_lr=5e-5):
    """Learning rate for 0-based optimizer step ``step``.

    Linear warmup over the first ``WARMUP_STEPS`` steps up to ``LR``. After
    that, a cosine curve over the whole ``total_steps`` horizon that reaches
    ``min_lr`` at ``step == total_steps``.
    """
    if step < WARMUP_STEPS:
        # step + 1 so the very first update is not taken with lr == 0.
        return LR * (step + 1) / WARMUP_STEPS
    progress = step / max(1, total_steps)
    return min_lr + 0.5 * (LR - min_lr) * (1 + math.cos(math.pi * progress))


def _dump_debug_distributions(model, loader, device, out_path):
    """Save masked action/RTG targets and predictions for one batch to ``out_path``.

    Switches ``model`` to eval mode and runs one no-grad forward pass with RTG
    dropout disabled on the first batch from ``loader``. Writes
    ``target_actions``, ``pred_actions``, ``target_rtg`` and ``pred_rtg`` via
    ``np.savez_compressed``. Does nothing if ``loader`` yields no batch.
    """
    model.eval()
    batch = next(iter(loader), None)
    if batch is None:
        return
    with torch.no_grad():
        b = {k: v.to(device) for k, v in batch.items()}
        out = model(
            feature_ids=b["feature_ids"],
            feature_vals=b["feature_values"],
            zone_ids=b["zone_ids"],
            attn_mask=b["attention_mask"],
            rtg=b["rtg"] * RTG_SCALE,
            context=b["context"],
            rtg_dropout_prob=0.0,
        )
        pred_bins = torch.argmax(out["action_logits"], dim=-1).cpu().numpy()
        target_bins = b["target_action_tokens"].cpu().numpy()
        # Expected shapes: time_mask [B, T], target_mask [B, T, K] — TODO confirm
        # against the dataloader's collate function.
        t_mask = b["time_mask"].cpu().numpy().astype(bool)
        a_mask = b["target_mask"].cpu().numpy().astype(bool)
        target_rtg = b["rtg"].cpu().numpy()
        pred_rtg = out["return_preds"].cpu().numpy()
    np.savez_compressed(
        out_path,
        target_actions=target_bins[a_mask],
        pred_actions=pred_bins[a_mask],
        target_rtg=target_rtg[t_mask],
        pred_rtg=pred_rtg[t_mask],
    )


def main(resume_dir=None):
    """Train GeneralistComfortDT and record all artifacts in a run directory.

    The run directory receives:
      * ``report.json``: status and progress, updated every epoch.
      * ``model_config.json``.
      * ``metrics.csv``: one row per optimizer step.
      * Checkpoints: ``last.pt`` every epoch and ``ckpt_<epoch>.pt`` every
        5 epochs.
      * ``plot_data/distributions.npz``.
      * Plots.

    On an exception it writes ``crash.json``, flushes pending CSV rows, marks
    ``report.json`` as ``"crashed"`` and re-raises.

    Args:
        resume_dir: Existing run directory to continue. Training resumes from
            its ``last.pt`` when present. ``None`` (the default) creates a new
            run directory, as before. Resuming was previously unreachable
            because a brand-new, empty directory was always used.
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    if resume_dir is None:
        run_dir = get_run_dir()
    else:
        run_dir = resume_dir
        os.makedirs(os.path.join(run_dir, "plots"), exist_ok=True)
    os.makedirs(os.path.join(run_dir, "plot_data"), exist_ok=True)
    report_path = os.path.join(run_dir, "report.json")
    metrics_csv = os.path.join(run_dir, "metrics.csv")

    hist = {"step": [], "loss": [], "acc": [], "phy": [], "val": [], "lr": [], "grad_norm": [], "loss_action": []}
    epoch_hist = {"epoch": [], "loss_mean": [], "acc_mean": [], "phy_mean": [], "val_mean": []}
    report = {
        "run_dir": run_dir,
        "started_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "config": {
            "DATA_DIR": DATA_DIR, "MAX_TOKENS": MAX_TOKENS_PER_STEP,
            "BATCH_SIZE": BATCH_SIZE, "LR": LR, "SEED": SEED
        },
        "status": "running",
        "progress": {"epoch": 0, "global_step": 0},
    }
    _atomic_write_json(report_path, report)

    csv_header = ["timestamp", "epoch", "step", "loss", "loss_action", "accuracy", "loss_physics", "loss_value", "lr", "grad_norm"]
    csv_buffer = []

    def flush_csv():
        """Append buffered rows to metrics.csv, writing the header for a new file."""
        if not csv_buffer:
            return
        write_header = not os.path.exists(metrics_csv)
        with open(metrics_csv, "a") as f:
            if write_header:
                f.write(",".join(csv_header) + "\n")
            for row in csv_buffer:
                f.write(",".join(str(row.get(k, "")) for k in csv_header) + "\n")
        csv_buffer.clear()

    try:
        print(f"Loading data from {DATA_DIR}...")
        all_paths = list_episode_npzs(DATA_DIR)
        if not all_paths:
            raise RuntimeError(f"No valid npz files found in {DATA_DIR}")
        train_ds = dl.GeneralistDataset(
            all_paths, seed=SEED,
            max_tokens=MAX_TOKENS_PER_STEP,
            topk_frac=TOPK_FRACTION,
            topk_mode=TOPK_MODE,
            topk_on=TOPK_ON,
        )
        train_ds.is_train = True
        use_workers = NUM_WORKERS > 0
        train_loader = DataLoader(
            train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
            # Pin host memory only when there is a GPU to copy batches to.
            pin_memory=use_cuda,
            # Both options are only valid when worker processes are used.
            persistent_workers=use_workers,
            prefetch_factor=4 if use_workers else None,
            collate_fn=dl.generalist_collate_fn, drop_last=True,
        )

        model_config = {
            "VOCAB_SIZE": VOCAB_SIZE, "D_MODEL": D_MODEL,
            "N_LAYERS": N_LAYERS, "N_HEADS": N_HEADS,
            "DROPOUT": DROPOUT, "MAX_ZONES": MAX_ZONES,
            "CONTEXT_LEN": CONTEXT_LEN,
            "NUM_ACTION_BINS": dl.NUM_ACTION_BINS,
            "CONTEXT_DIM": CONTEXT_DIM,
            "RTG_DIM": RTG_DIM
        }
        model = GeneralistComfortDT(model_config).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"\n{'='*40}\nModel Params: {total_params:,}\n{'='*40}\n")

        opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        # Mixed precision only on CUDA; on CPU the scaler/autocast become no-ops.
        scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)
        start_epoch, global_step = load_checkpoint_if_available(run_dir, model, opt, scaler, device)

        loss_cfg = GeneralistLossConfig(
            w_action=W_ACTION,
            w_physics=W_PHYSICS,
            w_value=W_VALUE,
            use_rtg_weighting=True,
            rtg_weight_mode="exp",
            rtg_weight_beta=2.0
        )
        _atomic_write_json(os.path.join(run_dir, "model_config.json"), model_config)

        total_steps = len(train_loader) * EPOCHS
        print(f"Starting Training | Steps: {total_steps}")

        for epoch in range(start_epoch, EPOCHS + 1):
            model.train()
            # NOTE(review): with persistent_workers=True, worker processes keep the
            # dataset copy made when they started, so this update may not reach them
            # after the first epoch. Confirm whether set_epoch affects sampling.
            train_ds.set_epoch(epoch)
            pbar = tqdm(train_loader, desc=f"Ep {epoch}", dynamic_ncols=True)
            stats = {"loss": [], "acc": [], "phy": [], "val": []}

            for batch in pbar:
                # 1. LR schedule (warmup + cosine).
                curr_lr = _lr_at(global_step, total_steps)
                for pg in opt.param_groups:
                    pg["lr"] = curr_lr

                b_gpu = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
                # 2. RTG prep: rtg is [B, T, 2] (Energy, Comfort).
                rtg_input = b_gpu["rtg"] * RTG_SCALE

                with torch.amp.autocast(device.type, enabled=use_cuda):
                    out = model(
                        feature_ids=b_gpu["feature_ids"],
                        feature_vals=b_gpu["feature_values"],
                        zone_ids=b_gpu["zone_ids"],
                        attn_mask=b_gpu["attention_mask"],
                        rtg=rtg_input,
                        context=b_gpu["context"],
                        rtg_dropout_prob=RTG_DROPOUT_PROB
                    )
                    # 3. Loss.
                    loss, metrics = compute_generalist_loss(out, b_gpu, loss_cfg)

                # 4. Optimizer step (unscale before clipping so the clip threshold
                # applies to true gradient magnitudes).
                opt.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                grad_norm = _to_float(torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP))
                scaler.step(opt)
                if global_step % 500 == 0:
                    print(f"DEBUG: Step {global_step} | Grad Norm: {grad_norm:.4f} | LR: {curr_lr:.2e}")
                scaler.update()
                global_step += 1

                # 5. Logging. Missing metrics are logged as 0.0, as before.
                step_metrics = {
                    k: _to_float(metrics.get(k, 0.0))
                    for k in ("total_loss", "loss_action", "loss_physics", "loss_value", "accuracy")
                }
                stats["loss"].append(step_metrics["total_loss"])
                stats["acc"].append(step_metrics["accuracy"])
                stats["phy"].append(step_metrics["loss_physics"])
                stats["val"].append(step_metrics["loss_value"])

                hist["step"].append(global_step)
                hist["loss"].append(step_metrics["total_loss"])
                hist["loss_action"].append(step_metrics["loss_action"])
                hist["acc"].append(step_metrics["accuracy"])
                hist["phy"].append(step_metrics["loss_physics"])
                hist["val"].append(step_metrics["loss_value"])
                hist["lr"].append(curr_lr)
                hist["grad_norm"].append(grad_norm)

                csv_buffer.append({
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "epoch": epoch, "step": global_step,
                    "loss": step_metrics["total_loss"],
                    "loss_action": step_metrics["loss_action"],
                    "accuracy": step_metrics["accuracy"],
                    "loss_physics": step_metrics["loss_physics"],
                    "loss_value": step_metrics["loss_value"],
                    "lr": float(curr_lr), "grad_norm": grad_norm,
                })
                if global_step % 50 == 0:
                    flush_csv()
                if global_step % 20 == 0:
                    pbar.set_postfix(
                        act=f"{step_metrics['loss_action']:.2f}",       # Action CE
                        phy=f"{np.mean(stats['phy'][-20:]):.4f}",      # Physics Delta MSE
                        val=f"{np.mean(stats['val'][-20:]):.2f}",      # Rescaled Value MSE
                        acc=f"{np.mean(stats['acc'][-20:]):.2f}"
                    )

            _dump_debug_distributions(
                model, train_loader, device,
                os.path.join(run_dir, "plot_data", "distributions.npz"),
            )

            flush_csv()
            save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, "last.pt")
            if epoch % 5 == 0:
                save_checkpoint(run_dir, model, opt, scaler, epoch, global_step, f"ckpt_{epoch}.pt")

            epoch_hist["epoch"].append(epoch)
            epoch_hist["loss_mean"].append(_safe_mean(stats["loss"]))
            epoch_hist["acc_mean"].append(_safe_mean(stats["acc"]))
            epoch_hist["phy_mean"].append(_safe_mean(stats["phy"]))
            epoch_hist["val_mean"].append(_safe_mean(stats["val"]))

            report["progress"] = {"epoch": epoch, "global_step": global_step}
            _atomic_write_json(report_path, report)

            # Plotting is best-effort: a failure here must not abort training.
            try:
                plots.save_plot_arrays(run_dir, hist, epoch_hist)
                plots.make_plots(run_dir)
            except Exception as e:
                print(f"Plotting failed: {e}")

        report["status"] = "complete"
        report["finished_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
        _atomic_write_json(report_path, report)
        print("Training Complete.")
    except Exception as e:
        _atomic_write_json(os.path.join(run_dir, "crash.json"), {"error": str(e), "traceback": traceback.format_exc()})
        # Keep metrics logged since the last periodic flush, and record the failure.
        flush_csv()
        report["status"] = "crashed"
        _atomic_write_json(report_path, report)
        raise
# Start training only when run as a script; importing this module has no side effects beyond definitions.
if __name__ == "__main__":
    main()