dexifried
Replace with tiny-router trainer (ZeroGPU/H200)
3bfff54
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tiny_router.constants import FEATURE_MODES, HEAD_LABELS
from tiny_router.data import RouterCollator, build_dataset_dict, tokenize_dataset_dict
from tiny_router.io import load_checkpoint, load_temperature_scaling
from tiny_router.metrics import evaluate_multitask
from tiny_router.runtime import dump_json, get_device
def parse_args() -> argparse.Namespace:
    """Parse the evaluation script's command-line options.

    Options:
        --model-dir: required; checkpoint directory to evaluate.
        --data-file: required; data file to evaluate on.
        --output-file: optional destination for the metrics.
        --device: one of ``auto``, ``cpu``, ``cuda``, ``mps`` (default ``auto``).
        --batch-size: integer batch size (default 16).
        --feature-mode: optional feature-mode override.
        --confidence-threshold: float threshold (default 0.8).
        --run-ablations: flag; off unless given.

    Returns:
        The populated ``argparse.Namespace`` built from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate a trained tiny-router checkpoint."
    )

    # The two mandatory path-like inputs share the same declaration shape.
    for mandatory_flag in ("--model-dir", "--data-file"):
        cli.add_argument(mandatory_flag, required=True)

    cli.add_argument("--output-file")
    cli.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
    )
    cli.add_argument("--batch-size", default=16, type=int)
    cli.add_argument("--feature-mode")
    cli.add_argument("--confidence-threshold", default=0.8, type=float)
    cli.add_argument("--run-ablations", action="store_true")

    namespace = cli.parse_args()
    return namespace
@torch.no_grad()
def collect_metrics(
    model,
    dataloader,
    threshold: float,
    device: torch.device,
    temperatures: dict[str, float] | None = None,
) -> dict:
    """Run ``model`` over ``dataloader`` and score the pooled per-head outputs.

    The model is switched to eval mode, every value of each batch is moved to
    ``device``, and the whole batch (label entries included) is passed to the
    model as keyword arguments. For each head in ``HEAD_LABELS`` the logits
    from ``outputs["logits"][head]`` and the labels from
    ``batch["labels_<head>"]`` are gathered on CPU, concatenated across
    batches, converted to NumPy arrays and handed to ``evaluate_multitask``.

    Args:
        model: Callable model returning a mapping with a ``"logits"`` entry
            keyed by head name.
        dataloader: Iterable of dict batches whose values support ``.to``.
        threshold: Forwarded to ``evaluate_multitask`` as ``threshold``.
        device: Device each batch value is moved to before the forward pass.
        temperatures: Optional per-head values forwarded unchanged to
            ``evaluate_multitask``.

    Returns:
        Whatever ``evaluate_multitask`` returns for the pooled arrays.
    """
    pooled_logits = {name: [] for name in HEAD_LABELS}
    pooled_labels = {name: [] for name in HEAD_LABELS}

    def _to_arrays(chunks_by_head):
        # Join the per-batch CPU tensors of each head into one NumPy array.
        return {
            name: torch.cat(chunks).numpy()
            for name, chunks in chunks_by_head.items()
        }

    model.eval()
    for raw_batch in dataloader:
        on_device = {}
        for field, tensor in raw_batch.items():
            on_device[field] = tensor.to(device)

        result = model(**on_device)

        for name in HEAD_LABELS:
            # Detach and move to CPU right away so accelerator memory is not
            # held for the whole evaluation pass.
            pooled_logits[name].append(result["logits"][name].detach().cpu())
            pooled_labels[name].append(on_device[f"labels_{name}"].detach().cpu())

    logit_arrays = _to_arrays(pooled_logits)
    label_arrays = _to_arrays(pooled_labels)
    return evaluate_multitask(
        logit_arrays,
        label_arrays,
        threshold=threshold,
        temperatures=temperatures,
    )
def evaluate_mode(
    model_dir: str,
    data_file: str,
    feature_mode: str | None,
    batch_size: int,
    threshold: float,
    requested_device: str,
) -> dict:
    """Load a checkpoint and evaluate it on ``data_file`` under one feature mode.

    Args:
        model_dir: Directory passed to ``load_checkpoint`` and
            ``load_temperature_scaling``.
        data_file: File used as the ``test`` split.
        feature_mode: Mode to tokenize with; a falsy value falls back to the
            checkpoint config's ``feature_mode``.
        batch_size: Batch size for the evaluation ``DataLoader``.
        threshold: Forwarded to ``collect_metrics``.
        requested_device: Device name resolved through ``get_device``.

    Returns:
        The metrics dict from ``collect_metrics`` with an extra
        ``"feature_mode"`` key recording the mode that was actually used.
    """
    device = get_device(requested_device=requested_device)
    model, tokenizer, config = load_checkpoint(model_dir, device=device)
    saved_temperatures = load_temperature_scaling(model_dir)

    active_mode = feature_mode if feature_mode else config.feature_mode

    # Stored temperatures are only used when evaluating in the checkpoint's
    # own feature mode (presumably the mode they were fitted on; confirm).
    if active_mode == config.feature_mode:
        calibration = saved_temperatures
    else:
        calibration = None

    raw_splits = build_dataset_dict(None, None, test_file=data_file)
    tokenized_splits = tokenize_dataset_dict(
        raw_splits,
        tokenizer=tokenizer,
        feature_mode=active_mode,
        max_length=config.max_length,
        recency_max=config.recency_max,
    )

    test_split = tokenized_splits["test"]
    collator = RouterCollator(tokenizer)
    test_loader = DataLoader(
        test_split,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
    )

    report = collect_metrics(
        model,
        test_loader,
        threshold=threshold,
        device=device,
        temperatures=calibration,
    )
    report["feature_mode"] = active_mode
    return report
def main() -> None:
    """Command-line entry point: evaluate, optionally ablate, then report.

    Evaluates the checkpoint once with the requested feature mode. When
    ``--run-ablations`` is set, it additionally evaluates every mode in
    ``FEATURE_MODES`` and stores each run's ``"overall"`` entry under
    ``metrics["ablations"][mode]``. The final metrics are written with
    ``dump_json`` when ``--output-file`` is given, otherwise printed.
    """
    args = parse_args()

    # Everything except the feature mode is identical across evaluation runs.
    common = {
        "model_dir": args.model_dir,
        "data_file": args.data_file,
        "batch_size": args.batch_size,
        "threshold": args.confidence_threshold,
        "requested_device": args.device,
    }

    report = evaluate_mode(feature_mode=args.feature_mode, **common)

    if args.run_ablations:
        per_mode = {}
        for mode in FEATURE_MODES:
            per_mode[mode] = evaluate_mode(feature_mode=mode, **common)["overall"]
        report["ablations"] = per_mode

    if not args.output_file:
        print(report)
        return
    dump_json(args.output_file, report)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()