Spaces:
Sleeping
Sleeping
| """Training loop. Frozen encoder, head-only optimization, multi-task loss.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from greenrouting.classifier.model import DEFAULT_ENCODER, Encoder, ModelSpec, build_head | |
| from greenrouting.data.schema import LENGTH_BUCKETS | |
| from greenrouting.routing.registry import CAPABILITY_KEYS | |
# Maps each length-bucket name to its position in LENGTH_BUCKETS; these
# positions are the class indices used as targets for the length head.
LENGTH_TO_INDEX: dict[str, int] = {
    bucket: position for position, bucket in enumerate(LENGTH_BUCKETS)
}
@dataclass
class TrainConfig:
    """Hyper-parameters for head-only training of the multi-task classifier.

    The ``@dataclass`` decorator is required: ``diff_target_center`` uses
    ``field(default_factory=...)``, which only produces a float default (and
    keyword construction such as ``TrainConfig(epochs=3)`` only works) when
    the class is processed as a dataclass.

    ``cap_weight``, ``diff_weight`` and ``len_weight`` scale the capability
    BCE, difficulty Huber and length-bucket cross-entropy terms of the
    training loss. ``diff_target_center`` fills missing
    ``difficulty_log_params`` values and is subtracted from that target so
    the head regresses a centred value.
    """

    encoder_name: str = DEFAULT_ENCODER
    hidden_dim: int = 256
    dropout: float = 0.1
    max_seq_len: int = 256
    epochs: int = 8
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    cap_weight: float = 1.0
    diff_weight: float = 0.5
    len_weight: float = 0.3
    # Fraction of rows held out for validation (at least one row is always held out).
    val_split: float = 0.15
    # Seeds the train/validation split and the per-epoch shuffles.
    seed: int = 42
    huber_delta: float = 1.0
    # BCE pos_weight, applied identically to every capability column.
    cap_pos_weight: float = 2.0
    # log(8e9) — presumably the log parameter count of an ~8B model; TODO confirm.
    diff_target_center: float = field(default_factory=lambda: math.log(8e9))
def _load_split(parquet_path: str | Path):
    """Read one dataset split stored at ``parquet_path`` into a DataFrame.

    pandas is imported inside the function, keeping it out of module import time.
    """
    from pandas import read_parquet

    frame = read_parquet(parquet_path)
    return frame
def _build_targets(df, cfg: TrainConfig):
    """Turn a training DataFrame into input texts and per-task targets.

    Args:
        df: DataFrame with a ``text`` column, one ``cap_<key>`` column per
            entry of ``CAPABILITY_KEYS``, a ``difficulty_log_params`` column
            and a ``length_bucket`` column.
        cfg: Training config; ``diff_target_center`` fills missing difficulty
            values and is subtracted from the difficulty target.

    Returns:
        ``(texts, caps, diff_centered, lens)``:
        texts as a list of ``str``; a float32 capability matrix with missing
        values set to 0.0; a float32 centred difficulty vector; and int64
        length-bucket indices.
    """
    import numpy as np

    cap_cols = [f"cap_{key}" for key in CAPABILITY_KEYS]
    caps = df[cap_cols].fillna(0.0).to_numpy(dtype=np.float32)

    diff = df["difficulty_log_params"].fillna(cfg.diff_target_center).to_numpy(dtype=np.float32)
    diff_centered = diff - cfg.diff_target_center

    # Missing buckets become "medium"; unknown names map to NaN and fall back
    # to the same bucket. Look the index up rather than hard-coding 1 so the
    # fallback tracks the order of LENGTH_BUCKETS.
    default_bucket = "medium"
    default_index = LENGTH_TO_INDEX.get(default_bucket, 1)
    lens = (
        df["length_bucket"]
        .fillna(default_bucket)
        .map(LENGTH_TO_INDEX)
        .fillna(default_index)
        .to_numpy(dtype=np.int64)
    )

    texts = df["text"].astype(str).tolist()
    return texts, caps, diff_centered, lens
| def _split_train_val(texts, caps, diff, lens, val_split: float, seed: int): | |
| import numpy as np | |
| rng = np.random.default_rng(seed) | |
| n = len(texts) | |
| indices = np.arange(n) | |
| rng.shuffle(indices) | |
| n_val = max(1, int(n * val_split)) | |
| val_idx = indices[:n_val] | |
| train_idx = indices[n_val:] | |
| return ( | |
| ([texts[i] for i in train_idx], caps[train_idx], diff[train_idx], lens[train_idx]), | |
| ([texts[i] for i in val_idx], caps[val_idx], diff[val_idx], lens[val_idx]), | |
| ) | |
| def _iterate_batches(texts, caps, diff, lens, batch_size: int, encoder: Encoder, shuffle: bool, seed: int): | |
| import numpy as np | |
| import torch | |
| n = len(texts) | |
| indices = np.arange(n) | |
| if shuffle: | |
| np.random.default_rng(seed).shuffle(indices) | |
| for start in range(0, n, batch_size): | |
| idx = indices[start:start + batch_size] | |
| batch_texts = [texts[i] for i in idx] | |
| emb = encoder.embed(batch_texts) | |
| cap_t = torch.tensor(caps[idx], dtype=torch.float32, device=emb.device) | |
| diff_t = torch.tensor(diff[idx], dtype=torch.float32, device=emb.device) | |
| len_t = torch.tensor(lens[idx], dtype=torch.long, device=emb.device) | |
| yield emb, cap_t, diff_t, len_t | |
def train(
    train_parquet: str | Path,
    output_dir: str | Path,
    cfg: TrainConfig | None = None,
) -> dict:
    """Train the multi-task head on encoder embeddings and save all artifacts.

    Only ``head.parameters()`` are handed to the optimizer, so the encoder is
    used purely to produce embeddings. Each epoch minimises a weighted sum of
    the capability BCE, difficulty Huber and length-bucket cross-entropy
    losses, then scores the held-out split.

    Files written to ``output_dir``: ``head.pt``, ``encoder_name.txt``,
    ``metadata.json``, ``training_history.json``, ``calibration.json``
    (temperature fitted on validation capability logits) and
    ``ood_stats.npz`` (OOD statistics and thresholds fitted on
    training-set embeddings).

    Args:
        train_parquet: Parquet file with the columns read by ``_build_targets``.
        output_dir: Artifact directory; created if it does not exist.
        cfg: Training configuration; ``TrainConfig()`` when ``None``.

    Returns:
        Dict with the per-epoch ``history``, the fitted ``temperature`` and
        the ``n_train`` / ``n_val`` row counts.

    Note:
        Seeds torch's global RNG with ``cfg.seed`` so head initialisation and
        dropout are reproducible.
    """
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    # Imported up front so a missing dependency fails before, not after, a
    # full training run.
    from greenrouting.classifier.calibration import fit_temperature
    from greenrouting.classifier.ood import calibrate_thresholds, fit_ood_stats

    cfg = cfg or TrainConfig()
    # cfg.seed previously drove only the NumPy split/shuffles; head weight
    # init and dropout used an unseeded torch RNG, so runs were not
    # reproducible.
    torch.manual_seed(cfg.seed)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = _load_split(train_parquet)
    texts, caps, diff, lens = _build_targets(df, cfg)
    train_set, val_set = _split_train_val(texts, caps, diff, lens, cfg.val_split, cfg.seed)

    encoder = Encoder(cfg.encoder_name, cfg.max_seq_len)
    # A single probe embedding reveals the encoder's output width for the head.
    embed_dim = encoder.embed(["probe"]).shape[-1]
    spec = ModelSpec(
        encoder_name=cfg.encoder_name,
        embedding_dim=embed_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        max_seq_len=cfg.max_seq_len,
    )
    head = build_head(spec).to(encoder.device)

    # The same positive-class weight is applied to every capability.
    pos_weight = torch.full((spec.n_capabilities,), cfg.cap_pos_weight, device=encoder.device)
    cap_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    diff_loss_fn = nn.HuberLoss(delta=cfg.huber_delta)
    len_loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(head.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    history = []
    for epoch in range(cfg.epochs):
        head.train()
        train_loss_sum = 0.0
        n_train = 0
        # Offset the shuffle seed by the epoch so each epoch sees a new order.
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *train_set, batch_size=cfg.batch_size, encoder=encoder,
            shuffle=True, seed=cfg.seed + epoch,
        ):
            out = head(emb)
            loss = (
                cfg.cap_weight * cap_loss_fn(out["cap_logits"], cap_t)
                + cfg.diff_weight * diff_loss_fn(out["diff"], diff_t)
                + cfg.len_weight * len_loss_fn(out["len_logits"], len_t)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight each batch's mean loss by its size so a short final
            # batch does not skew the epoch average.
            train_loss_sum += loss.item() * emb.shape[0]
            n_train += emb.shape[0]

        train_loss = train_loss_sum / max(n_train, 1)
        val_metrics = _evaluate(head, encoder, val_set, cfg)
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            **val_metrics,
        })
        print(
            f"epoch {epoch+1}/{cfg.epochs} "
            f"train_loss={train_loss:.4f} "
            f"val_cap_f1={val_metrics['cap_f1']:.3f} "
            f"val_diff_mae={val_metrics['diff_mae']:.3f} "
            f"val_len_acc={val_metrics['len_acc']:.3f}"
        )

    head.eval()
    torch.save(head.state_dict(), out_dir / "head.pt")
    (out_dir / "encoder_name.txt").write_text(cfg.encoder_name)
    (out_dir / "metadata.json").write_text(json.dumps({
        "capability_keys": list(CAPABILITY_KEYS),
        "length_buckets": list(LENGTH_BUCKETS),
        "embedding_dim": int(spec.embedding_dim),
        "hidden_dim": int(spec.hidden_dim),
        "max_seq_len": int(spec.max_seq_len),
        "diff_target_center": float(cfg.diff_target_center),
    }, indent=2))
    (out_dir / "training_history.json").write_text(json.dumps(history, indent=2))

    # Post-hoc artifacts: temperature scaling fitted on validation capability
    # logits, OOD reference statistics fitted on training embeddings.
    train_embeddings = _collect_embeddings(encoder, train_set[0], batch_size=cfg.batch_size)
    val_cap_logits = _collect_logits(head, encoder, val_set, cfg.batch_size)
    temperature = fit_temperature(val_cap_logits, val_set[1])
    (out_dir / "calibration.json").write_text(json.dumps({"temperature": float(temperature)}, indent=2))

    ood_k = 5
    ood_stats = fit_ood_stats(train_embeddings, k=ood_k)
    thresholds = calibrate_thresholds(train_embeddings, ood_stats, percentile=99.0)
    np.savez(
        out_dir / "ood_stats.npz",
        centroid=ood_stats["centroid"],
        reference=ood_stats["reference"],
        k=ood_stats.get("k", ood_k),
        centroid_threshold=thresholds["centroid_threshold"],
        knn_threshold=thresholds["knn_threshold"],
    )
    return {
        "history": history,
        "temperature": float(temperature),
        "n_train": len(train_set[0]),
        "n_val": len(val_set[0]),
    }
def _evaluate(head, encoder: Encoder, val_set, cfg: TrainConfig) -> dict:
    """Score ``head`` on ``val_set`` without tracking gradients.

    Capability metrics are micro-averaged over every (example, capability)
    cell after thresholding both sigmoid probabilities and targets at 0.5.
    ``diff_mae`` is the mean absolute error on the centred difficulty
    targets, and ``len_acc`` is argmax accuracy on length buckets.

    The head is switched to eval mode for scoring. On exit it is restored to
    whatever mode it was in on entry; previously it was always forced back
    into train mode, and not at all if scoring raised.

    Returns:
        Dict with ``cap_precision``, ``cap_recall``, ``cap_f1``, ``diff_mae``
        and ``len_acc`` as Python floats.
    """
    import numpy as np
    import torch

    all_cap_pred, all_cap_true = [], []
    all_diff_pred, all_diff_true = [], []
    all_len_pred, all_len_true = [], []

    was_training = head.training
    head.eval()
    try:
        with torch.no_grad():
            for emb, cap_t, diff_t, len_t in _iterate_batches(
                *val_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=False, seed=cfg.seed,
            ):
                out = head(emb)
                all_cap_pred.append(torch.sigmoid(out["cap_logits"]).cpu().numpy())
                all_cap_true.append(cap_t.cpu().numpy())
                all_diff_pred.append(out["diff"].cpu().numpy())
                all_diff_true.append(diff_t.cpu().numpy())
                all_len_pred.append(out["len_logits"].argmax(dim=-1).cpu().numpy())
                all_len_true.append(len_t.cpu().numpy())
    finally:
        head.train(was_training)

    cap_pred = np.concatenate(all_cap_pred)
    cap_true = np.concatenate(all_cap_true)
    diff_pred = np.concatenate(all_diff_pred)
    diff_true = np.concatenate(all_diff_true)
    len_pred = np.concatenate(all_len_pred)
    len_true = np.concatenate(all_len_true)

    # Binarise at 0.5 on both sides so soft targets are also handled.
    cap_pred_bin = (cap_pred >= 0.5).astype(np.float32)
    cap_true_bin = (cap_true >= 0.5).astype(np.float32)
    tp = ((cap_pred_bin == 1) & (cap_true_bin == 1)).sum()
    fp = ((cap_pred_bin == 1) & (cap_true_bin == 0)).sum()
    fn = ((cap_pred_bin == 0) & (cap_true_bin == 1)).sum()
    # max(..., 1) / max(..., 1e-9) avoid division by zero when a count is 0.
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)
    diff_mae = float(np.abs(diff_pred - diff_true).mean())
    len_acc = float((len_pred == len_true).mean())
    return {
        "cap_precision": float(precision),
        "cap_recall": float(recall),
        "cap_f1": float(f1),
        "diff_mae": diff_mae,
        "len_acc": len_acc,
    }
| def _collect_embeddings(encoder: Encoder, texts: list[str], batch_size: int): | |
| import numpy as np | |
| chunks = [] | |
| for start in range(0, len(texts), batch_size): | |
| chunk = texts[start:start + batch_size] | |
| emb = encoder.embed(chunk).cpu().numpy() | |
| chunks.append(emb) | |
| return np.concatenate(chunks, axis=0) if chunks else np.zeros((0, 384), dtype=np.float32) | |
def _collect_logits(head, encoder: Encoder, val_set, batch_size: int):
    """Return raw capability logits for every row of ``val_set``, in order.

    Rows are visited unshuffled, so row ``i`` of the result lines up with
    ``val_set[1][i]`` (the capability targets). Scoring runs under
    ``torch.no_grad`` with the head in eval mode, and the head's previous
    mode is restored afterwards, even on error. An empty split yields a
    ``(0, n_caps)`` float32 array. Its width is taken from the target
    matrix instead of a hard-coded 8, so it matches the targets it is
    paired with.
    """
    import numpy as np
    import torch

    was_training = head.training
    head.eval()
    out_logits = []
    try:
        with torch.no_grad():
            for emb, _cap, _diff, _len in _iterate_batches(
                *val_set, batch_size=batch_size, encoder=encoder, shuffle=False, seed=0,
            ):
                out = head(emb)
                out_logits.append(out["cap_logits"].cpu().numpy())
    finally:
        head.train(was_training)

    if out_logits:
        return np.concatenate(out_logits, axis=0)
    n_caps = np.shape(val_set[1])[-1]
    return np.zeros((0, n_caps), dtype=np.float32)