"""Command-line entry point for evaluating a trained tiny-router checkpoint."""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from tiny_router.constants import FEATURE_MODES, HEAD_LABELS
from tiny_router.data import RouterCollator, build_dataset_dict, tokenize_dataset_dict
from tiny_router.io import load_checkpoint, load_temperature_scaling
from tiny_router.metrics import evaluate_multitask
from tiny_router.runtime import dump_json, get_device


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for an evaluation run.

    Returns:
        The parsed namespace with ``model_dir``, ``data_file``, ``output_file``,
        ``device``, ``batch_size``, ``feature_mode``, ``confidence_threshold``
        and ``run_ablations`` attributes.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained tiny-router checkpoint.")
    parser.add_argument(
        "--model-dir",
        required=True,
        help="Directory holding the trained checkpoint.",
    )
    parser.add_argument(
        "--data-file",
        required=True,
        help="Data file to evaluate on (used as the test split).",
    )
    parser.add_argument(
        "--output-file",
        help="Write the metrics to this file; when omitted they are printed instead.",
    )
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda", "mps"],
        default="auto",
        help="Device to run the evaluation on.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Number of examples per evaluation batch.",
    )
    parser.add_argument(
        "--feature-mode",
        help="Feature mode to evaluate with; defaults to the one stored in the checkpoint config.",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.8,
        help="Threshold forwarded to the metric computation.",
    )
    parser.add_argument(
        "--run-ablations",
        action="store_true",
        help="Also evaluate every known feature mode and report each one's overall metrics.",
    )
    return parser.parse_args()


@torch.no_grad()
def collect_metrics(
    model,
    dataloader,
    threshold: float,
    device: torch.device,
    temperatures: dict[str, float] | None = None,
) -> dict:
    """Run ``model`` over ``dataloader`` and score every head in ``HEAD_LABELS``.

    Each batch is moved to ``device`` and passed to the model as keyword
    arguments. Per-head logits (``outputs["logits"][head]``) and labels
    (``batch["labels_<head>"]``) are gathered on the CPU, concatenated, and
    handed to ``evaluate_multitask`` as NumPy arrays.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        dataloader: Iterable of batches, each a mapping whose values support ``.to(device)``.
        threshold: Forwarded to ``evaluate_multitask`` as ``threshold``.
        device: Device the batches are moved to before the forward pass.
        temperatures: Optional per-head temperatures forwarded to ``evaluate_multitask``.

    Returns:
        Whatever ``evaluate_multitask`` returns for the collected logits and labels.

    Raises:
        ValueError: If the dataloader yields no batches, so there is nothing to score.
    """
    logits_by_head = {head: [] for head in HEAD_LABELS}
    labels_by_head = {head: [] for head in HEAD_LABELS}

    model.eval()
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        outputs = model(**batch)
        for head in HEAD_LABELS:
            logits_by_head[head].append(outputs["logits"][head].detach().cpu())
            labels_by_head[head].append(batch[f"labels_{head}"].detach().cpu())

    # torch.cat rejects an empty list with an opaque error; fail early with a
    # message that points at the actual cause.
    if any(not chunks for chunks in logits_by_head.values()):
        raise ValueError("Cannot compute metrics: the dataloader produced no batches.")

    return evaluate_multitask(
        {head: torch.cat(chunks).numpy() for head, chunks in logits_by_head.items()},
        {head: torch.cat(chunks).numpy() for head, chunks in labels_by_head.items()},
        threshold=threshold,
        temperatures=temperatures,
    )


def _load_evaluation_state(model_dir: str, requested_device: str) -> dict:
    """Load the feature-mode-independent state as keyword arguments for ``_evaluate_loaded``."""
    device = get_device(requested_device=requested_device)
    model, tokenizer, config = load_checkpoint(model_dir, device=device)
    return {
        "model": model,
        "tokenizer": tokenizer,
        "config": config,
        "stored_temperatures": load_temperature_scaling(model_dir),
        "device": device,
    }


def _evaluate_loaded(
    *,
    model,
    tokenizer,
    config,
    stored_temperatures,
    device: torch.device,
    data_file: str,
    feature_mode: str | None,
    batch_size: int,
    threshold: float,
) -> dict:
    """Evaluate an already-loaded checkpoint on ``data_file`` under one feature mode."""
    chosen_mode = feature_mode or config.feature_mode
    # Stored temperatures are only applied when evaluating with the feature
    # mode recorded in the checkpoint config; any other mode runs without them.
    temperatures = stored_temperatures if chosen_mode == config.feature_mode else None

    dataset_dict = build_dataset_dict(None, None, test_file=data_file)
    dataset_dict = tokenize_dataset_dict(
        dataset_dict,
        tokenizer=tokenizer,
        feature_mode=chosen_mode,
        max_length=config.max_length,
        recency_max=config.recency_max,
    )
    loader = DataLoader(
        dataset_dict["test"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=RouterCollator(tokenizer),
    )
    metrics = collect_metrics(
        model,
        loader,
        threshold=threshold,
        device=device,
        temperatures=temperatures,
    )
    metrics["feature_mode"] = chosen_mode
    return metrics


def evaluate_mode(
    model_dir: str,
    data_file: str,
    feature_mode: str | None,
    batch_size: int,
    threshold: float,
    requested_device: str,
) -> dict:
    """Load the checkpoint in ``model_dir`` and evaluate it on ``data_file``.

    Args:
        model_dir: Directory passed to ``load_checkpoint`` and ``load_temperature_scaling``.
        data_file: File used as the test split.
        feature_mode: Feature mode to tokenize with; a falsy value selects
            ``config.feature_mode`` from the checkpoint.
        batch_size: Batch size for the evaluation ``DataLoader``.
        threshold: Forwarded to ``collect_metrics``.
        requested_device: Device request passed to ``get_device``.

    Returns:
        The metrics dict from ``collect_metrics`` with an added ``"feature_mode"``
        entry naming the mode that was actually used.
    """
    state = _load_evaluation_state(model_dir, requested_device)
    return _evaluate_loaded(
        data_file=data_file,
        feature_mode=feature_mode,
        batch_size=batch_size,
        threshold=threshold,
        **state,
    )


def main() -> None:
    """Evaluate the requested checkpoint and write or print the resulting metrics."""
    args = parse_args()

    # The device, checkpoint and temperatures do not depend on the feature
    # mode, so load them once and share them with every ablation run.
    state = _load_evaluation_state(args.model_dir, args.device)

    def run(feature_mode: str | None) -> dict:
        return _evaluate_loaded(
            data_file=args.data_file,
            feature_mode=feature_mode,
            batch_size=args.batch_size,
            threshold=args.confidence_threshold,
            **state,
        )

    metrics = run(args.feature_mode)
    if args.run_ablations:
        metrics["ablations"] = {mode: run(mode)["overall"] for mode in FEATURE_MODES}

    if args.output_file:
        dump_json(args.output_file, metrics)
    else:
        print(metrics)


if __name__ == "__main__":
    main()