# Author: spectralman — commit 6f0ff99 (verified): "Initial deploy: classifier + FastAPI router"
"""Training loop. Frozen encoder, head-only optimization, multi-task loss."""
from __future__ import annotations
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from greenrouting.classifier.model import DEFAULT_ENCODER, Encoder, ModelSpec, build_head
from greenrouting.data.schema import LENGTH_BUCKETS
from greenrouting.routing.registry import CAPABILITY_KEYS
# Maps each length-bucket label to its position in LENGTH_BUCKETS; _build_targets
# uses it to turn the "length_bucket" column into class indices for the length loss.
LENGTH_TO_INDEX: dict[str, int] = {bucket: position for position, bucket in enumerate(LENGTH_BUCKETS)}
@dataclass
class TrainConfig:
    """Hyper-parameters for ``train``.

    A few settings that would otherwise fail late are rejected up front in
    ``__post_init__`` with ``ValueError``. For example, a ``val_split`` of 1 or
    more would leave no training rows at all.
    """

    # Encoder identifier passed to Encoder(...) and written to encoder_name.txt.
    encoder_name: str = DEFAULT_ENCODER
    # Head architecture, forwarded to ModelSpec.
    hidden_dim: int = 256
    dropout: float = 0.1
    # Maximum sequence length given to the encoder; also recorded in metadata.json.
    max_seq_len: int = 256
    # Passes over the training split, and examples per optimization step.
    epochs: int = 8
    batch_size: int = 32
    # AdamW settings (only the head's parameters are optimized).
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    # Multipliers on the capability (BCE), difficulty (Huber) and length (CE) loss terms.
    cap_weight: float = 1.0
    diff_weight: float = 0.5
    len_weight: float = 0.3
    # Fraction of rows held out for validation (at least one row is always held out).
    val_split: float = 0.15
    # Seeds the train/val shuffle; epoch e shuffles its batches with seed + e.
    seed: int = 42
    # `delta` of the Huber loss on the difficulty target.
    huber_delta: float = 1.0
    # BCE pos_weight applied identically to every capability.
    cap_pos_weight: float = 2.0
    # Subtracted from difficulty_log_params, and used to fill missing values
    # (i.e. a centred target of 0). Equals ln(8e9) — presumably a reference
    # model size of 8B parameters; confirm.
    diff_target_center: float = field(default_factory=lambda: math.log(8e9))

    def __post_init__(self) -> None:
        # Fail fast. batch_size < 1 breaks batching. val_split >= 1 holds out
        # every row, leaving an empty training split.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size!r}")
        if not 0.0 <= self.val_split < 1.0:
            raise ValueError(f"val_split must be in [0, 1), got {self.val_split!r}")
        if self.epochs < 0:
            raise ValueError(f"epochs must be >= 0, got {self.epochs!r}")
def _load_split(parquet_path: str | Path):
    """Read a dataset split stored as Parquet and return it as a pandas DataFrame.

    pandas is imported inside the function, so importing this module does not
    itself import pandas.
    """
    from pandas import read_parquet

    frame = read_parquet(parquet_path)
    return frame
def _build_targets(df, cfg: TrainConfig):
    """Convert a training DataFrame into model inputs and per-task targets.

    Args:
        df: DataFrame with ``text``, ``cap_<key>`` (one per CAPABILITY_KEYS entry),
            ``difficulty_log_params`` and ``length_bucket`` columns.
        cfg: Supplies ``diff_target_center``.

    Returns:
        ``(texts, caps, diff_centered, lens)``:

        * ``texts`` — ``list[str]`` from ``text``. Missing values become ``""``
          rather than the literal strings ``"nan"``/``"None"``.
        * ``caps`` — float32 array with one column per capability key; missing
          values become 0.0.
        * ``diff_centered`` — float32 ``difficulty_log_params`` minus
          ``cfg.diff_target_center``. Missing values are filled with the centre,
          giving a target of 0.
        * ``lens`` — int64 bucket indices via LENGTH_TO_INDEX. Missing values are
          treated as ``"medium"``. Unknown labels fall back to the ``"medium"``
          index, or to 1 if ``"medium"`` is not a bucket.
    """
    import numpy as np

    cap_cols = [f"cap_{k}" for k in CAPABILITY_KEYS]
    caps = df[cap_cols].fillna(0.0).to_numpy(dtype=np.float32)

    diff = df["difficulty_log_params"].fillna(cfg.diff_target_center).to_numpy(dtype=np.float32)
    diff_centered = diff - cfg.diff_target_center

    # Unknown labels map to NaN in .map(). Send them to the same bucket used for
    # missing labels, instead of hard-coding an index that only coincides with
    # "medium" for one particular LENGTH_BUCKETS ordering.
    fallback_len = LENGTH_TO_INDEX.get("medium", 1)
    lens = (
        df["length_bucket"]
        .fillna("medium")
        .map(LENGTH_TO_INDEX)
        .fillna(fallback_len)
        .to_numpy(dtype=np.int64)
    )

    # fillna first: astype(str) alone would turn missing texts into "nan"/"None".
    texts = df["text"].fillna("").astype(str).tolist()
    return texts, caps, diff_centered, lens
def _split_train_val(texts, caps, diff, lens, val_split: float, seed: int):
    """Randomly partition aligned examples into ``(train, val)`` tuples.

    Row positions are shuffled with ``np.random.default_rng(seed)``. The first
    ``max(1, int(len(texts) * val_split))`` shuffled rows form the validation
    split and the rest form the training split. Each split is returned as
    ``(texts_list, caps, diff, lens)``, with rows kept aligned across the four
    inputs. Assumes ``caps``/``diff``/``lens`` support integer-array indexing
    (e.g. numpy arrays).
    """
    import numpy as np

    order = np.arange(len(texts))
    np.random.default_rng(seed).shuffle(order)
    held_out = max(1, int(len(texts) * val_split))
    val_rows, train_rows = order[:held_out], order[held_out:]

    def _gather(rows):
        # Pick the same rows from every aligned input.
        return [texts[r] for r in rows], caps[rows], diff[rows], lens[rows]

    return _gather(train_rows), _gather(val_rows)
def _iterate_batches(texts, caps, diff, lens, batch_size: int, encoder: Encoder, shuffle: bool, seed: int):
    """Yield ``(embeddings, cap_targets, diff_targets, len_targets)`` mini-batches.

    Rows are visited in input order, or in an order permuted by
    ``np.random.default_rng(seed)`` when ``shuffle`` is true. Each batch holds
    up to ``batch_size`` rows; the last one may be shorter. Texts are embedded
    with ``encoder.embed``. Targets are copied into new tensors (float32 for
    caps/diff, int64 for lens) on the same device as the embeddings.
    """
    import numpy as np
    import torch

    order = np.arange(len(texts))
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(order)

    def _as_tensor(values, dtype, device):
        return torch.tensor(values, dtype=dtype, device=device)

    for offset in range(0, len(order), batch_size):
        rows = order[offset:offset + batch_size]
        embeddings = encoder.embed([texts[r] for r in rows])
        device = embeddings.device
        yield (
            embeddings,
            _as_tensor(caps[rows], torch.float32, device),
            _as_tensor(diff[rows], torch.float32, device),
            _as_tensor(lens[rows], torch.long, device),
        )
def train(
    train_parquet: str | Path,
    output_dir: str | Path,
    cfg: TrainConfig | None = None,
) -> dict:
    """Fit the multi-task head on top of the encoder and write all artifacts.

    Pipeline:

    1. Read ``train_parquet``, build targets, and hold out ``cfg.val_split`` of
       the rows (at least one) for validation.
    2. Run ``cfg.epochs`` epochs of AdamW over the head's parameters only. The
       loss is ``cap_weight * BCE(capabilities) + diff_weight * Huber(difficulty)
       + len_weight * CE(length bucket)``. After every epoch the validation split
       is evaluated and a one-line summary is printed.
    3. Write into ``output_dir`` (created if needed):
       * ``head.pt`` — the head ``state_dict``.
       * ``encoder_name.txt`` and ``metadata.json``.
       * ``training_history.json``.
       * ``calibration.json`` — a temperature fitted on the validation
         capability logits.
       * ``ood_stats.npz`` — OOD statistics and thresholds fitted on the
         training-split embeddings.

    Args:
        train_parquet: Parquet file with ``text``, ``cap_<key>``,
            ``difficulty_log_params`` and ``length_bucket`` columns.
        output_dir: Directory that receives the artifacts.
        cfg: Hyper-parameters. ``TrainConfig()`` defaults are used when ``None``.

    Returns:
        ``{"history": [...], "temperature": float, "n_train": int, "n_val": int}``.
        Each history entry holds ``epoch``, ``train_loss`` and the validation
        metrics returned by ``_evaluate``.
    """
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import AdamW

    # Imported before training, so a missing module fails immediately rather
    # than after every epoch has already run.
    from greenrouting.classifier.calibration import fit_temperature
    from greenrouting.classifier.ood import calibrate_thresholds, fit_ood_stats

    # kNN neighbour count for the OOD statistics. Also the fallback value
    # recorded in ood_stats.npz, so the two can no longer drift apart.
    ood_k = 5

    cfg = cfg or TrainConfig()
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    df = _load_split(train_parquet)
    texts, caps, diff, lens = _build_targets(df, cfg)
    train_set, val_set = _split_train_val(texts, caps, diff, lens, cfg.val_split, cfg.seed)

    encoder = Encoder(cfg.encoder_name, cfg.max_seq_len)
    # Embed a dummy string once to discover the encoder's output width.
    embed_dim = encoder.embed(["probe"]).shape[-1]
    spec = ModelSpec(
        encoder_name=cfg.encoder_name,
        embedding_dim=embed_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        max_seq_len=cfg.max_seq_len,
    )
    head = build_head(spec).to(encoder.device)

    # The same positive-class weight applies to every capability.
    pos_weight = torch.full((spec.n_capabilities,), cfg.cap_pos_weight, device=encoder.device)
    cap_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    diff_loss_fn = nn.HuberLoss(delta=cfg.huber_delta)
    len_loss_fn = nn.CrossEntropyLoss()
    # Only the head's parameters are given to the optimizer.
    optimizer = AdamW(head.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    # NOTE(review): each epoch re-embeds every training text through
    # _iterate_batches. If Encoder.embed is deterministic (eval mode, no
    # dropout), the embeddings could be computed once and reused; confirm.
    history = []
    for epoch in range(cfg.epochs):
        head.train()
        loss_sum = 0.0
        n_seen = 0
        for emb, cap_t, diff_t, len_t in _iterate_batches(
            *train_set, batch_size=cfg.batch_size, encoder=encoder,
            shuffle=True, seed=cfg.seed + epoch,
        ):
            out = head(emb)
            # Flatten the difficulty prediction so a (B, 1) output cannot silently
            # broadcast against the (B,) target inside HuberLoss. This is a no-op
            # for a (B,) output.
            loss = (
                cfg.cap_weight * cap_loss_fn(out["cap_logits"], cap_t)
                + cfg.diff_weight * diff_loss_fn(out["diff"].reshape(-1), diff_t)
                + cfg.len_weight * len_loss_fn(out["len_logits"], len_t)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_n = emb.shape[0]
            # Weight by batch size so the epoch average is per example.
            loss_sum += loss.item() * batch_n
            n_seen += batch_n

        train_loss = loss_sum / max(n_seen, 1)
        val_metrics = _evaluate(head, encoder, val_set, cfg)
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            **val_metrics,
        })
        print(
            f"epoch {epoch+1}/{cfg.epochs} "
            f"train_loss={train_loss:.4f} "
            f"val_cap_f1={val_metrics['cap_f1']:.3f} "
            f"val_diff_mae={val_metrics['diff_mae']:.3f} "
            f"val_len_acc={val_metrics['len_acc']:.3f}"
        )

    head.eval()
    torch.save(head.state_dict(), out_dir / "head.pt")
    (out_dir / "encoder_name.txt").write_text(cfg.encoder_name, encoding="utf-8")
    metadata = {
        "capability_keys": list(CAPABILITY_KEYS),
        "length_buckets": list(LENGTH_BUCKETS),
        "embedding_dim": int(spec.embedding_dim),
        "hidden_dim": int(spec.hidden_dim),
        "max_seq_len": int(spec.max_seq_len),
        "diff_target_center": float(cfg.diff_target_center),
    }
    (out_dir / "metadata.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    (out_dir / "training_history.json").write_text(json.dumps(history, indent=2), encoding="utf-8")

    train_embeddings = _collect_embeddings(encoder, train_set[0], batch_size=cfg.batch_size)
    val_cap_logits = _collect_logits(head, encoder, val_set, cfg.batch_size)

    # Fit a temperature on the held-out capability logits against their targets.
    temperature = fit_temperature(val_cap_logits, val_set[1])
    (out_dir / "calibration.json").write_text(
        json.dumps({"temperature": float(temperature)}, indent=2), encoding="utf-8"
    )

    # NOTE(review): thresholds are calibrated on the same embeddings that serve
    # as the OOD reference set. If the kNN distance counts each point's
    # self-match, the knn threshold is biased low; confirm calibrate_thresholds
    # accounts for this.
    ood_stats = fit_ood_stats(train_embeddings, k=ood_k)
    thresholds = calibrate_thresholds(train_embeddings, ood_stats, percentile=99.0)
    np.savez(
        out_dir / "ood_stats.npz",
        centroid=ood_stats["centroid"],
        reference=ood_stats["reference"],
        k=ood_stats.get("k", ood_k),
        centroid_threshold=thresholds["centroid_threshold"],
        knn_threshold=thresholds["knn_threshold"],
    )

    return {
        "history": history,
        "temperature": float(temperature),
        "n_train": len(train_set[0]),
        "n_val": len(val_set[0]),
    }
def _evaluate(head, encoder: Encoder, val_set, cfg: TrainConfig) -> dict:
    """Compute validation metrics for ``head`` on ``val_set``.

    The head runs without gradients over ``val_set`` in its stored order, in
    batches of ``cfg.batch_size``. Returned metrics:

    * ``cap_precision`` / ``cap_recall`` / ``cap_f1`` — micro-averaged over
      every (example, capability) cell. Both the sigmoid of ``cap_logits`` and
      the targets are binarized at 0.5.
    * ``diff_mae`` — mean absolute error of the ``diff`` output against the
      difficulty targets in ``val_set``.
    * ``len_acc`` — accuracy of the arg-max length bucket.

    Every metric is NaN when ``val_set`` is empty. The head's train/eval mode on
    entry is restored on exit, including when an exception is raised.
    """
    import numpy as np
    import torch

    was_training = head.training
    head.eval()
    cap_pred_parts, cap_true_parts = [], []
    diff_pred_parts, diff_true_parts = [], []
    len_pred_parts, len_true_parts = [], []
    try:
        with torch.no_grad():
            for emb, cap_t, diff_t, len_t in _iterate_batches(
                *val_set, batch_size=cfg.batch_size, encoder=encoder, shuffle=False, seed=cfg.seed,
            ):
                out = head(emb)
                cap_pred_parts.append(torch.sigmoid(out["cap_logits"]).cpu().numpy())
                cap_true_parts.append(cap_t.cpu().numpy())
                diff_pred_parts.append(out["diff"].cpu().numpy())
                diff_true_parts.append(diff_t.cpu().numpy())
                len_pred_parts.append(out["len_logits"].argmax(dim=-1).cpu().numpy())
                len_true_parts.append(len_t.cpu().numpy())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        head.train(was_training)

    if not cap_pred_parts:
        # np.concatenate([]) would raise, so report undefined metrics instead.
        nan = float("nan")
        return {
            "cap_precision": nan,
            "cap_recall": nan,
            "cap_f1": nan,
            "diff_mae": nan,
            "len_acc": nan,
        }

    cap_pred = np.concatenate(cap_pred_parts)
    cap_true = np.concatenate(cap_true_parts)
    # Flatten both sides so a (N, 1) prediction cannot broadcast against a (N,)
    # target into an (N, N) difference matrix.
    diff_pred = np.concatenate(diff_pred_parts).reshape(-1)
    diff_true = np.concatenate(diff_true_parts).reshape(-1)
    len_pred = np.concatenate(len_pred_parts)
    len_true = np.concatenate(len_true_parts)

    pred_pos = cap_pred >= 0.5
    true_pos = cap_true >= 0.5
    tp = int(np.sum(pred_pos & true_pos))
    fp = int(np.sum(pred_pos & ~true_pos))
    fn = int(np.sum(~pred_pos & true_pos))
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    f1 = 2 * precision * recall / max(precision + recall, 1e-9)

    return {
        "cap_precision": float(precision),
        "cap_recall": float(recall),
        "cap_f1": float(f1),
        "diff_mae": float(np.abs(diff_pred - diff_true).mean()),
        "len_acc": float((len_pred == len_true).mean()),
    }
def _collect_embeddings(encoder: Encoder, texts: list[str], batch_size: int, embedding_dim: int = 384):
    """Embed ``texts`` in order, ``batch_size`` at a time, and stack the results.

    Args:
        encoder: Object whose ``embed(list[str])`` returns a torch tensor with one
            row per input text.
        texts: Texts to embed.
        batch_size: Number of texts per ``encoder.embed`` call.
        embedding_dim: Width of the empty array returned when ``texts`` is empty.
            Defaults to 384, the value previously hard-coded. Pass the real
            encoder width when using another encoder.

    Returns:
        numpy array with one row per text, concatenated along axis 0. If
        ``texts`` is empty, a float32 array of shape ``(0, embedding_dim)``.
    """
    import numpy as np
    import torch

    chunks = []
    # no_grad + detach: .numpy() refuses tensors that track gradients.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            emb = encoder.embed(texts[start:start + batch_size])
            chunks.append(emb.detach().cpu().numpy())
    if not chunks:
        return np.zeros((0, embedding_dim), dtype=np.float32)
    return np.concatenate(chunks, axis=0)
def _collect_logits(head, encoder: Encoder, val_set, batch_size: int):
    """Run ``head`` over ``val_set`` in stored order and return raw capability logits.

    Returns:
        numpy array of ``cap_logits`` rows concatenated along axis 0. When
        ``val_set`` is empty, a float32 array of shape
        ``(0, len(CAPABILITY_KEYS))``: one column per capability, the width the
        capability loss is trained against.

    The head's train/eval mode on entry is restored on exit.
    """
    import numpy as np
    import torch

    was_training = head.training
    head.eval()
    out_logits = []
    try:
        with torch.no_grad():
            for emb, _cap, _diff, _len in _iterate_batches(
                *val_set, batch_size=batch_size, encoder=encoder, shuffle=False, seed=0,
            ):
                out_logits.append(head(emb)["cap_logits"].cpu().numpy())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        head.train(was_training)
    if not out_logits:
        return np.zeros((0, len(CAPABILITY_KEYS)), dtype=np.float32)
    return np.concatenate(out_logits, axis=0)