Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tiny_router.constants import FEATURE_MODES, HEAD_LABELS | |
| from tiny_router.data import RouterCollator, build_dataset_dict, tokenize_dataset_dict | |
| from tiny_router.io import load_checkpoint, load_temperature_scaling | |
| from tiny_router.metrics import evaluate_multitask | |
| from tiny_router.runtime import dump_json, get_device | |
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for checkpoint evaluation.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        Namespace with ``model_dir``, ``data_file``, ``output_file``,
        ``device``, ``batch_size``, ``feature_mode``,
        ``confidence_threshold`` and ``run_ablations``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained tiny-router checkpoint.")
    parser.add_argument("--model-dir", required=True, help="Checkpoint directory to evaluate.")
    parser.add_argument("--data-file", required=True, help="Data file used as the test split.")
    parser.add_argument(
        "--output-file",
        help="Write the metrics to this file; they are printed to stdout when omitted.",
    )
    parser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
    # Reject batch sizes < 1 up front instead of failing inside DataLoader
    # after the checkpoint has already been loaded.
    parser.add_argument("--batch-size", type=_positive_int, default=16)
    # Restrict to the modes the project defines so a typo is reported here
    # rather than deep inside tokenization. Omitted -> checkpoint's own mode.
    parser.add_argument("--feature-mode", choices=FEATURE_MODES)
    parser.add_argument("--confidence-threshold", type=float, default=0.8)
    parser.add_argument(
        "--run-ablations",
        action="store_true",
        help="Also report overall metrics for every feature mode.",
    )
    return parser.parse_args(argv)
def collect_metrics(
    model,
    dataloader,
    threshold: float,
    device: torch.device,
    temperatures: dict[str, float] | None = None,
) -> dict:
    """Run ``model`` over ``dataloader`` and score every head with ``evaluate_multitask``.

    Args:
        model: Multitask model. It is called as ``model(**batch)`` and its
            output is indexed as ``outputs["logits"][head]`` for every head in
            ``HEAD_LABELS``.
        dataloader: Iterable of batches. Each batch is a mapping whose values
            are moved with ``.to(device)`` and which contains a
            ``labels_<head>`` entry for every head.
        threshold: Confidence threshold forwarded to ``evaluate_multitask``.
        device: Device every batch value is moved to before the forward pass.
        temperatures: Optional per-head temperatures, forwarded unchanged to
            ``evaluate_multitask``.

    Returns:
        The result of ``evaluate_multitask`` on the per-head logits and
        labels, concatenated across batches and converted to NumPy arrays.

    Raises:
        ValueError: If ``dataloader`` yields no batches, so there is nothing
            to score.
    """
    logits_by_head: dict[str, list[torch.Tensor]] = {head: [] for head in HEAD_LABELS}
    labels_by_head: dict[str, list[torch.Tensor]] = {head: [] for head in HEAD_LABELS}
    model.eval()
    # Inference only: without no_grad every forward pass would record an
    # autograd graph for the whole batch, wasting memory and time.
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.to(device) for key, value in batch.items()}
            outputs = model(**batch)
            for head in HEAD_LABELS:
                logits_by_head[head].append(outputs["logits"][head].detach().cpu())
                labels_by_head[head].append(batch[f"labels_{head}"].detach().cpu())
    # torch.cat on an empty list fails with an opaque message; report the real cause.
    if any(not chunks for chunks in logits_by_head.values()):
        raise ValueError("collect_metrics: dataloader yielded no batches; nothing to evaluate.")
    return evaluate_multitask(
        {head: torch.cat(values).numpy() for head, values in logits_by_head.items()},
        {head: torch.cat(values).numpy() for head, values in labels_by_head.items()},
        threshold=threshold,
        temperatures=temperatures,
    )
def evaluate_mode(
    model_dir: str,
    data_file: str,
    feature_mode: str | None,
    batch_size: int,
    threshold: float,
    requested_device: str,
) -> dict:
    """Load the checkpoint in ``model_dir`` and evaluate it on ``data_file``.

    Stored temperature scaling is passed on only when the evaluated feature
    mode equals the checkpoint's ``config.feature_mode``. Any other mode is
    scored with ``temperatures=None``.

    Args:
        model_dir: Directory given to ``load_checkpoint`` and
            ``load_temperature_scaling``.
        data_file: Path used as the ``test_file`` for ``build_dataset_dict``.
        feature_mode: Feature mode to tokenize with. A falsy value falls back
            to the checkpoint's ``config.feature_mode``.
        batch_size: Batch size of the evaluation ``DataLoader``.
        threshold: Confidence threshold forwarded to ``collect_metrics``.
        requested_device: Device name resolved through ``get_device``.

    Returns:
        The dict from ``collect_metrics`` plus a ``"feature_mode"`` key
        holding the mode that was actually evaluated.
    """
    device = get_device(requested_device=requested_device)
    model, tokenizer, config = load_checkpoint(model_dir, device=device)
    saved_temperatures = load_temperature_scaling(model_dir)

    mode = feature_mode if feature_mode else config.feature_mode
    # Only reuse the stored temperatures under the checkpoint's own mode
    # (presumably the mode they were calibrated with).
    temperatures = None
    if mode == config.feature_mode:
        temperatures = saved_temperatures

    tokenized = tokenize_dataset_dict(
        build_dataset_dict(None, None, test_file=data_file),
        tokenizer=tokenizer,
        feature_mode=mode,
        max_length=config.max_length,
        recency_max=config.recency_max,
    )
    test_loader = DataLoader(
        tokenized["test"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=RouterCollator(tokenizer),
    )

    results = collect_metrics(
        model,
        test_loader,
        threshold=threshold,
        device=device,
        temperatures=temperatures,
    )
    results["feature_mode"] = mode
    return results
def main() -> None:
    """Command-line entry point: evaluate a checkpoint and emit its metrics.

    Evaluates the requested feature mode, or the checkpoint's own mode when
    none is given. With ``--run-ablations`` it also records the
    ``"overall"`` metrics for every mode in ``FEATURE_MODES`` under
    ``metrics["ablations"]``. The result goes to ``--output-file`` via
    ``dump_json`` when given; otherwise it is printed.
    """
    args = parse_args()
    # Keyword arguments shared by every evaluate_mode call in this run.
    common = {
        "model_dir": args.model_dir,
        "data_file": args.data_file,
        "batch_size": args.batch_size,
        "threshold": args.confidence_threshold,
        "requested_device": args.device,
    }
    metrics = evaluate_mode(feature_mode=args.feature_mode, **common)
    if args.run_ablations:
        evaluated_mode = metrics["feature_mode"]
        ablations = {}
        for mode in FEATURE_MODES:
            if mode == evaluated_mode:
                # Same checkpoint, data, mode and temperature decision as the
                # run above, so its result is identical: reuse it instead of
                # reloading the checkpoint and re-running inference.
                ablations[mode] = metrics["overall"]
            else:
                ablations[mode] = evaluate_mode(feature_mode=mode, **common)["overall"]
        metrics["ablations"] = ablations
    if args.output_file:
        dump_json(args.output_file, metrics)
    else:
        print(metrics)


if __name__ == "__main__":
    main()