Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| from pathlib import Path | |
| import torch | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from tqdm.auto import tqdm | |
| from transformers import get_linear_schedule_with_warmup | |
| from tiny_router.calibration import fit_temperature_by_head | |
| from tiny_router.config import RouterModelConfig | |
| from tiny_router.constants import ( | |
| ACTION_VOCAB, | |
| DEFAULT_ENCODER, | |
| DEFAULT_MAX_LENGTH, | |
| DEFAULT_RECENCY_MAX, | |
| HEAD_LABELS, | |
| OUTCOME_VOCAB, | |
| ) | |
| from tiny_router.data import RouterCollator, build_dataset_dict, tokenize_dataset_dict | |
| from tiny_router.io import load_checkpoint, load_tokenizer, save_checkpoint, save_temperature_scaling | |
| from tiny_router.metrics import evaluate_multitask | |
| from tiny_router.model import TinyRouterModel | |
| from tiny_router.runtime import dump_json, get_autocast, get_device, seed_everything | |
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError as err:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from err
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def _unit_interval_float(text: str) -> float:
    """argparse ``type=`` converter that accepts only floats in the closed range [0, 1]."""
    try:
        value = float(text)
    except ValueError as err:
        raise argparse.ArgumentTypeError(f"expected a number, got {text!r}") from err
    # Written as a chained comparison so NaN (which compares False) is rejected too.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"expected a value in [0, 1], got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a training run.

    Options whose valid range is unambiguous (``--batch-size``, ``--epochs``,
    ``--warmup-ratio``) are range-checked here. A bad value then fails
    immediately with an argparse usage error instead of deep inside data
    loading or training. Non-string defaults are not passed through the
    ``type=`` converters, so the defaults keep their original types.

    Returns:
        The parsed namespace. ``head_loss_weights`` stays a raw JSON string.
    """
    parser = argparse.ArgumentParser(description="Train the tiny-router multi-head classifier.")

    # Data locations and output directory.
    parser.add_argument("--train-file", required=True)
    parser.add_argument("--validation-file", required=True)
    parser.add_argument("--test-file")
    parser.add_argument("--output-dir", required=True)

    # Model / featurization.
    parser.add_argument("--encoder-name", default=DEFAULT_ENCODER)
    parser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
    parser.add_argument("--feature-mode", default="full_interaction")
    parser.add_argument("--pooling-type", choices=["mean", "attention"], default="attention")
    parser.add_argument(
        "--use-head-dependencies",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--dependency-hidden-dim", type=int, default=32)
    parser.add_argument("--max-length", type=int, default=DEFAULT_MAX_LENGTH)
    parser.add_argument("--recency-max", type=int, default=DEFAULT_RECENCY_MAX)

    # Optimization.
    parser.add_argument("--batch-size", type=_positive_int, default=16)
    parser.add_argument("--epochs", type=_positive_int, default=5)
    parser.add_argument("--encoder-lr", type=float, default=2e-5)
    parser.add_argument("--head-lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=_unit_interval_float, default=0.1)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=13)
    parser.add_argument("--patience", type=int, default=2)
    parser.add_argument("--mixed-precision", action="store_true")

    # Evaluation and loss weighting.
    parser.add_argument("--confidence-threshold", type=float, default=0.8)
    parser.add_argument(
        "--head-loss-weights",
        default="{}",
        help='JSON dict, for example {"urgency": 1.2, "retention": 0.8}',
    )
    return parser.parse_args()
def move_batch(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    """Return a new dict with every tensor in ``batch`` moved to ``device``.

    Keys and their order are preserved; the input dict is not modified.
    """
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(device)
    return on_device
def collect_eval_arrays(
    model: TinyRouterModel,
    dataloader: DataLoader,
    device: torch.device,
) -> tuple[dict[str, object], dict[str, object]]:
    """Run ``model`` over ``dataloader`` and gather per-head logits and labels.

    The model is switched to eval mode and is left in eval mode on return.
    The forward passes run under ``torch.no_grad()``.

    Args:
        model: Called as ``model(**batch)``. Its output must expose
            ``outputs["logits"][head]`` for every head in ``HEAD_LABELS``.
        dataloader: Yields dicts of tensors that include ``labels_<head>``
            for every head.
        device: Where each batch is moved before the forward pass.

    Returns:
        ``(logits, labels)``: two dicts keyed by head name. Each value is the
        batch-wise concatenation of that head's tensors, moved to CPU and
        converted with ``Tensor.numpy()``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    logits_by_head = {head: [] for head in HEAD_LABELS}
    labels_by_head = {head: [] for head in HEAD_LABELS}
    num_batches = 0
    model.eval()
    # Evaluation never back-propagates. Without no_grad, every forward pass would
    # build an autograd graph and keep its activations alive for nothing.
    with torch.no_grad():
        for batch in dataloader:
            num_batches += 1
            batch = move_batch(batch, device)
            outputs = model(**batch)
            for head in HEAD_LABELS:
                logits_by_head[head].append(outputs["logits"][head].cpu())
                labels_by_head[head].append(batch[f"labels_{head}"].cpu())
    if num_batches == 0:
        # torch.cat([]) would otherwise fail with an opaque error.
        raise ValueError("Evaluation dataloader produced no batches; is the split empty?")
    stacked_logits = {
        head: torch.cat(chunks).numpy() for head, chunks in logits_by_head.items()
    }
    stacked_labels = {
        head: torch.cat(chunks).numpy() for head, chunks in labels_by_head.items()
    }
    return stacked_logits, stacked_labels
def run_eval(
    model: TinyRouterModel,
    dataloader: DataLoader,
    device: torch.device,
    threshold: float,
    temperatures: dict[str, float] | None = None,
) -> dict:
    """Evaluate ``model`` on ``dataloader`` and return the multitask metrics dict.

    Logits and labels are gathered by ``collect_eval_arrays``, which also puts
    the model in eval mode. They are then scored by ``evaluate_multitask`` at
    ``threshold``. ``temperatures`` (per-head, optional) is forwarded to
    ``evaluate_multitask`` unchanged.
    """
    # collect_eval_arrays returns (logits, labels), which map positionally onto
    # evaluate_multitask's first two parameters.
    return evaluate_multitask(
        *collect_eval_arrays(model, dataloader, device=device),
        threshold=threshold,
        temperatures=temperatures,
    )
def _parse_head_loss_weights(raw: str) -> dict[str, float]:
    """Decode the ``--head-loss-weights`` JSON and check it against ``HEAD_LABELS``.

    Args:
        raw: The raw command-line string, e.g. ``'{"urgency": 1.2}'``.

    Returns:
        The decoded dict, unchanged, ready to pass as ``head_loss_weights`` to the model.

    Raises:
        ValueError: if ``raw`` is not a JSON object, names a head that is not in
            ``HEAD_LABELS`` (most likely a typo), or maps a head to something
            other than a finite number.
    """
    try:
        weights = json.loads(raw)
    except json.JSONDecodeError as err:
        raise ValueError(f"--head-loss-weights is not valid JSON: {err}") from err
    if not isinstance(weights, dict):
        raise ValueError(
            "--head-loss-weights must be a JSON object mapping head name to weight, "
            f"got {type(weights).__name__}."
        )
    unknown_heads = sorted(set(weights) - set(HEAD_LABELS))
    if unknown_heads:
        raise ValueError(
            f"--head-loss-weights names unknown head(s) {unknown_heads}; "
            f"valid heads are {sorted(HEAD_LABELS)}."
        )
    for head, weight in weights.items():
        # bool is a subclass of int, so reject it explicitly. json.loads also accepts
        # NaN/Infinity, which would make any loss they scale non-finite.
        if isinstance(weight, bool) or not isinstance(weight, (int, float)) or not math.isfinite(weight):
            raise ValueError(f"--head-loss-weights[{head!r}] must be a finite number, got {weight!r}.")
    return weights


def _make_loader(dataset, batch_size: int, collator: RouterCollator, *, shuffle: bool) -> DataLoader:
    """Wrap one dataset split in a DataLoader that batches with ``collator``."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)


def _build_optimizer(
    model: TinyRouterModel,
    *,
    encoder_lr: float,
    head_lr: float,
    weight_decay: float,
) -> AdamW:
    """Create AdamW with one param group for ``model.encoder`` and one for everything else.

    The second group contains every parameter whose name does not start with
    ``"encoder."``. Both groups share ``weight_decay``.
    """
    non_encoder_params = [
        param for name, param in model.named_parameters() if not name.startswith("encoder.")
    ]
    return AdamW(
        [
            {"params": model.encoder.parameters(), "lr": encoder_lr},
            {"params": non_encoder_params, "lr": head_lr},
        ],
        weight_decay=weight_decay,
    )


def _train_one_epoch(
    model: TinyRouterModel,
    train_loader: DataLoader,
    optimizer: AdamW,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    scaler: torch.amp.GradScaler,
    device: torch.device,
    *,
    epoch: int,
    mixed_precision: bool,
    head_loss_weights: dict[str, float],
) -> float:
    """Run one optimization pass over ``train_loader``.

    Each step clips gradients to a global norm of 1.0 and then advances the LR
    scheduler.

    Returns:
        The mean per-batch loss (0.0 if the loader is empty).

    Raises:
        RuntimeError: if a batch yields no loss (missing labels) or a non-finite loss.
    """
    model.train()
    epoch_loss = 0.0
    progress = tqdm(train_loader, desc=f"epoch {epoch}", leave=False)
    for batch in progress:
        batch = move_batch(batch, device)
        optimizer.zero_grad(set_to_none=True)
        with get_autocast(device, mixed_precision):
            outputs = model(**batch, head_loss_weights=head_loss_weights)
            loss = outputs["loss"]
        if loss is None:
            raise RuntimeError("Training batch is missing labels.")
        if not torch.isfinite(loss):
            raise RuntimeError(
                "Encountered non-finite loss during training. "
                "On Apple Silicon, retry with `--device cpu`."
            )
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()
        # One device->host sync per step, reused for the running total and the progress bar.
        loss_value = loss.item()
        epoch_loss += loss_value
        progress.set_postfix(loss=f"{loss_value:.4f}")
    return epoch_loss / max(len(train_loader), 1)


def main() -> None:
    """Train the multi-head router, then calibrate and evaluate the best checkpoint.

    All artifacts are written to ``--output-dir``:

    1. Train for up to ``--epochs`` epochs. After each epoch, evaluate on the
       validation split and call ``save_checkpoint`` whenever the validation
       ``overall.macro_average_f1`` improves.
    2. Dump the per-epoch validation metrics to ``history.json``.
    3. Reload the best checkpoint and fit per-head temperatures on its
       validation logits. Save them, and dump the temperature-calibrated
       validation metrics to ``metrics.json``.
    4. If the dataset dict has a ``test`` split (presumably when
       ``--test-file`` is given), dump calibrated test metrics to
       ``test_metrics.json``.

    Raises:
        ValueError: if ``--head-loss-weights`` is malformed.
        RuntimeError: on a training batch without labels, a non-finite training
            loss, or when no epoch produced an improving validation score.
    """
    args = parse_args()
    seed_everything(args.seed)
    # Validate cheap user input before loading the tokenizer and datasets.
    head_loss_weights = _parse_head_loss_weights(args.head_loss_weights)
    device = get_device(requested_device=args.device)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = load_tokenizer(args.encoder_name)
    dataset_dict = build_dataset_dict(args.train_file, args.validation_file, args.test_file)
    dataset_dict = tokenize_dataset_dict(
        dataset_dict,
        tokenizer=tokenizer,
        feature_mode=args.feature_mode,
        max_length=args.max_length,
        recency_max=args.recency_max,
    )
    collator = RouterCollator(tokenizer)
    train_loader = _make_loader(dataset_dict["train"], args.batch_size, collator, shuffle=True)
    validation_loader = _make_loader(dataset_dict["validation"], args.batch_size, collator, shuffle=False)

    model_config = RouterModelConfig(
        encoder_name=args.encoder_name,
        dropout=args.dropout,
        action_vocab=ACTION_VOCAB,
        outcome_vocab=OUTCOME_VOCAB,
        label_maps=HEAD_LABELS,
        pooling_type=args.pooling_type,
        use_head_dependencies=args.use_head_dependencies,
        dependency_hidden_dim=args.dependency_hidden_dim,
        feature_mode=args.feature_mode,
        max_length=args.max_length,
        recency_max=args.recency_max,
    )
    model = TinyRouterModel(model_config).to(device)
    optimizer = _build_optimizer(
        model,
        encoder_lr=args.encoder_lr,
        head_lr=args.head_lr,
        weight_decay=args.weight_decay,
    )
    total_steps = len(train_loader) * args.epochs
    warmup_steps = math.ceil(total_steps * args.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )
    # The scaler is enabled only for CUDA with --mixed-precision. Otherwise
    # _train_one_epoch takes the plain backward/step path.
    scaler = torch.amp.GradScaler("cuda", enabled=args.mixed_precision and device.type == "cuda")

    best_score = float("-inf")
    best_metrics = None
    epochs_without_improvement = 0
    history = []
    for epoch in range(1, args.epochs + 1):
        mean_loss = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            scaler,
            device,
            epoch=epoch,
            mixed_precision=args.mixed_precision,
            head_loss_weights=head_loss_weights,
        )
        metrics = run_eval(
            model,
            validation_loader,
            device=device,
            threshold=args.confidence_threshold,
        )
        metrics["training"] = {"epoch": epoch, "loss": round(mean_loss, 4)}
        history.append(metrics)
        score = metrics["overall"]["macro_average_f1"]
        if score > best_score:
            best_score = score
            best_metrics = metrics
            epochs_without_improvement = 0
            save_checkpoint(
                output_dir,
                model,
                tokenizer,
                model_config,
                training_args=vars(args),
                metrics=metrics,
            )
        else:
            epochs_without_improvement += 1
            # Stops once the counter *exceeds* --patience, i.e. after
            # patience + 1 consecutive epochs without improvement.
            if epochs_without_improvement > args.patience:
                break

    if best_metrics is None:
        raise RuntimeError("Training did not produce any validation metrics.")
    dump_json(output_dir / "history.json", {"epochs": history, "best_macro_average_f1": best_score})

    # `model` holds the last epoch's weights, not necessarily the best ones. Release it
    # and the optimizer state (also referenced by the scheduler) before reloading the
    # best checkpoint, so the two copies do not have to coexist in memory.
    del model, optimizer, scheduler
    best_model, _, _ = load_checkpoint(output_dir, device=device)
    validation_logits, validation_labels = collect_eval_arrays(
        best_model,
        validation_loader,
        device=device,
    )
    # Temperatures are fitted on the validation logits and reused for the test split below.
    temperature_payload = fit_temperature_by_head(validation_logits, validation_labels)
    save_temperature_scaling(output_dir, temperature_payload)
    calibrated_validation_metrics = evaluate_multitask(
        validation_logits,
        validation_labels,
        threshold=args.confidence_threshold,
        temperatures=temperature_payload["per_head"],
    )
    dump_json(output_dir / "metrics.json", calibrated_validation_metrics)

    if "test" in dataset_dict:
        test_loader = _make_loader(dataset_dict["test"], args.batch_size, collator, shuffle=False)
        test_metrics = run_eval(
            best_model,
            test_loader,
            device=device,
            threshold=args.confidence_threshold,
            temperatures=temperature_payload["per_head"],
        )
        dump_json(output_dir / "test_metrics.json", test_metrics)
if __name__ == "__main__":
    # Run a training job only when executed as a script, never on import.
    main()