"""Training loop. Frozen encoder, head-only optimization, multi-task loss.""" from __future__ import annotations import json import math from dataclasses import dataclass, field from pathlib import Path from greenrouting.classifier.model import DEFAULT_ENCODER, Encoder, ModelSpec, build_head from greenrouting.data.schema import LENGTH_BUCKETS from greenrouting.routing.registry import CAPABILITY_KEYS LENGTH_TO_INDEX: dict[str, int] = {b: i for i, b in enumerate(LENGTH_BUCKETS)} @dataclass class TrainConfig: encoder_name: str = DEFAULT_ENCODER hidden_dim: int = 256 dropout: float = 0.1 max_seq_len: int = 256 epochs: int = 8 batch_size: int = 32 learning_rate: float = 1e-3 weight_decay: float = 1e-4 cap_weight: float = 1.0 diff_weight: float = 0.5 len_weight: float = 0.3 val_split: float = 0.15 seed: int = 42 huber_delta: float = 1.0 cap_pos_weight: float = 2.0 diff_target_center: float = field(default_factory=lambda: math.log(8e9)) def _load_split(parquet_path: str | Path): import pandas as pd return pd.read_parquet(parquet_path) def _build_targets(df, cfg: TrainConfig): import numpy as np cap_cols = [f"cap_{k}" for k in CAPABILITY_KEYS] caps = df[cap_cols].fillna(0.0).to_numpy(dtype=np.float32) diff = (df["difficulty_log_params"].fillna(cfg.diff_target_center).to_numpy(dtype=np.float32)) diff_centered = diff - cfg.diff_target_center lens = df["length_bucket"].fillna("medium").map(LENGTH_TO_INDEX).fillna(1).to_numpy(dtype=np.int64) texts = df["text"].astype(str).tolist() return texts, caps, diff_centered, lens def _split_train_val(texts, caps, diff, lens, val_split: float, seed: int): import numpy as np rng = np.random.default_rng(seed) n = len(texts) indices = np.arange(n) rng.shuffle(indices) n_val = max(1, int(n * val_split)) val_idx = indices[:n_val] train_idx = indices[n_val:] return ( ([texts[i] for i in train_idx], caps[train_idx], diff[train_idx], lens[train_idx]), ([texts[i] for i in val_idx], caps[val_idx], diff[val_idx], lens[val_idx]), ) def _iterate_batches(texts, caps, diff, lens, batch_size: int, encoder: Encoder, shuffle: bool, seed: int): import numpy as np import torch n = len(texts) indices = np.arange(n) if shuffle: np.random.default_rng(seed).shuffle(indices) for start in range(0, n, batch_size): idx = indices[start:start + batch_size] batch_texts = [texts[i] for i in idx] emb = encoder.embed(batch_texts) cap_t = torch.tensor(caps[idx], dtype=torch.float32, device=emb.device) diff_t = torch.tensor(diff[idx], dtype=torch.float32, device=emb.device) len_t = torch.tensor(lens[idx], dtype=torch.long, device=emb.device) yield emb, cap_t, diff_t, len_t def train( train_parquet: str | Path, output_dir: str | Path, cfg: TrainConfig | None = None, ) -> dict: import torch import torch.nn as nn from torch.optim import AdamW cfg = cfg or TrainConfig() out_dir = Path(output_dir) out_dir.mkdir(parents=True, exist_ok=True) df = _load_split(train_parquet) texts, caps, diff, lens = _build_targets(df, cfg) train_set, val_set = _split_train_val(texts, caps, diff, lens, cfg.val_split, cfg.seed) encoder = Encoder(cfg.encoder_name, cfg.max_seq_len) embed_dim = encoder.embed(["probe"]).shape[-1] spec = ModelSpec( encoder_name=cfg.encoder_name, embedding_dim=embed_dim, hidden_dim=cfg.hidden_dim, dropout=cfg.dropout, max_seq_len=cfg.max_seq_len, ) head = build_head(spec).to(encoder.device) pos_weight = torch.full((spec.n_capabilities,), cfg.cap_pos_weight, device=encoder.device) cap_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) diff_loss_fn = nn.HuberLoss(delta=cfg.huber_delta) len_loss_fn = nn.CrossEntropyLoss() optimizer = AdamW(head.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay) history = [] for epoch in range(cfg.epochs): head.train() train_loss_sum = 0.0 n_train = 0 for emb, cap_t, diff_t, len_t in _iterate_batches( *train_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=True, seed=cfg.seed + epoch, ): out = head(emb) loss = ( cfg.cap_weight * cap_loss_fn(out["cap_logits"], cap_t) + cfg.diff_weight * diff_loss_fn(out["diff"], diff_t) + cfg.len_weight * len_loss_fn(out["len_logits"], len_t) ) optimizer.zero_grad() loss.backward() optimizer.step() train_loss_sum += loss.item() * emb.shape[0] n_train += emb.shape[0] val_metrics = _evaluate(head, encoder, val_set, cfg) history.append({ "epoch": epoch, "train_loss": train_loss_sum / max(n_train, 1), **val_metrics, }) print( f"epoch {epoch+1}/{cfg.epochs} " f"train_loss={train_loss_sum/max(n_train,1):.4f} " f"val_cap_f1={val_metrics['cap_f1']:.3f} " f"val_diff_mae={val_metrics['diff_mae']:.3f} " f"val_len_acc={val_metrics['len_acc']:.3f}" ) head.eval() torch.save(head.state_dict(), out_dir / "head.pt") (out_dir / "encoder_name.txt").write_text(cfg.encoder_name) (out_dir / "metadata.json").write_text(json.dumps({ "capability_keys": list(CAPABILITY_KEYS), "length_buckets": list(LENGTH_BUCKETS), "embedding_dim": int(spec.embedding_dim), "hidden_dim": int(spec.hidden_dim), "max_seq_len": int(spec.max_seq_len), "diff_target_center": float(cfg.diff_target_center), }, indent=2)) (out_dir / "training_history.json").write_text(json.dumps(history, indent=2)) train_embeddings = _collect_embeddings(encoder, train_set[0], batch_size=cfg.batch_size) val_cap_logits = _collect_logits(head, encoder, val_set, cfg.batch_size) from greenrouting.classifier.calibration import fit_temperature temperature = fit_temperature(val_cap_logits, val_set[1]) (out_dir / "calibration.json").write_text(json.dumps({"temperature": float(temperature)}, indent=2)) from greenrouting.classifier.ood import calibrate_thresholds, fit_ood_stats ood_stats = fit_ood_stats(train_embeddings, k=5) thresholds = calibrate_thresholds(train_embeddings, ood_stats, percentile=99.0) import numpy as np np.savez( out_dir / "ood_stats.npz", centroid=ood_stats["centroid"], reference=ood_stats["reference"], k=ood_stats.get("k", 5), centroid_threshold=thresholds["centroid_threshold"], knn_threshold=thresholds["knn_threshold"], ) return { "history": history, "temperature": float(temperature), "n_train": len(train_set[0]), "n_val": len(val_set[0]), } def _evaluate(head, encoder: Encoder, val_set, cfg: TrainConfig) -> dict: import torch head.eval() all_cap_pred, all_cap_true = [], [] all_diff_pred, all_diff_true = [], [] all_len_pred, all_len_true = [], [] with torch.no_grad(): for emb, cap_t, diff_t, len_t in _iterate_batches( *val_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=False, seed=cfg.seed, ): out = head(emb) all_cap_pred.append(torch.sigmoid(out["cap_logits"]).cpu().numpy()) all_cap_true.append(cap_t.cpu().numpy()) all_diff_pred.append(out["diff"].cpu().numpy()) all_diff_true.append(diff_t.cpu().numpy()) all_len_pred.append(out["len_logits"].argmax(dim=-1).cpu().numpy()) all_len_true.append(len_t.cpu().numpy()) head.train() import numpy as np cap_pred = np.concatenate(all_cap_pred) cap_true = np.concatenate(all_cap_true) diff_pred = np.concatenate(all_diff_pred) diff_true = np.concatenate(all_diff_true) len_pred = np.concatenate(all_len_pred) len_true = np.concatenate(all_len_true) cap_pred_bin = (cap_pred >= 0.5).astype(np.float32) cap_true_bin = (cap_true >= 0.5).astype(np.float32) tp = ((cap_pred_bin == 1) & (cap_true_bin == 1)).sum() fp = ((cap_pred_bin == 1) & (cap_true_bin == 0)).sum() fn = ((cap_pred_bin == 0) & (cap_true_bin == 1)).sum() precision = tp / max(tp + fp, 1) recall = tp / max(tp + fn, 1) f1 = 2 * precision * recall / max(precision + recall, 1e-9) diff_mae = float(np.abs(diff_pred - diff_true).mean()) len_acc = float((len_pred == len_true).mean()) return { "cap_precision": float(precision), "cap_recall": float(recall), "cap_f1": float(f1), "diff_mae": diff_mae, "len_acc": len_acc, } def _collect_embeddings(encoder: Encoder, texts: list[str], batch_size: int): import numpy as np chunks = [] for start in range(0, len(texts), batch_size): chunk = texts[start:start + batch_size] emb = encoder.embed(chunk).cpu().numpy() chunks.append(emb) return np.concatenate(chunks, axis=0) if chunks else np.zeros((0, 384), dtype=np.float32) def _collect_logits(head, encoder: Encoder, val_set, batch_size: int): import numpy as np import torch head.eval() out_logits = [] with torch.no_grad(): for emb, _cap, _diff, _len in _iterate_batches( *val_set, batch_size=batch_size, encoder=encoder, shuffle=False, seed=0, ): out = head(emb) out_logits.append(out["cap_logits"].cpu().numpy()) head.train() return np.concatenate(out_logits, axis=0) if out_logits else np.zeros((0, 8), dtype=np.float32)