Spaces:
Sleeping
Sleeping
File size: 4,088 Bytes
3bfff54 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | from __future__ import annotations
import argparse
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from tiny_router.constants import FEATURE_MODES, HEAD_LABELS
from tiny_router.data import RouterCollator, build_dataset_dict, tokenize_dataset_dict
from tiny_router.io import load_checkpoint, load_temperature_scaling
from tiny_router.metrics import evaluate_multitask
from tiny_router.runtime import dump_json, get_device
def _positive_int(text: str) -> int:
    """Parse *text* as a strictly positive integer (for argparse ``type=``)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the evaluation command-line interface and parse its arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Namespace with ``model_dir``, ``data_file``, ``output_file``,
        ``device``, ``batch_size``, ``feature_mode``,
        ``confidence_threshold`` and ``run_ablations``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained tiny-router checkpoint.")
    parser.add_argument(
        "--model-dir",
        required=True,
        help="Checkpoint location passed to load_checkpoint / load_temperature_scaling.",
    )
    parser.add_argument(
        "--data-file",
        required=True,
        help="Data file evaluated as the test split.",
    )
    parser.add_argument(
        "--output-file",
        help="Write the metrics as JSON to this path instead of printing them.",
    )
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda", "mps"],
        default="auto",
        help="Device to run evaluation on.",
    )
    parser.add_argument(
        "--batch-size",
        # Reject non-positive sizes up front instead of after the checkpoint loads.
        type=_positive_int,
        default=16,
        help="Evaluation batch size (positive integer).",
    )
    parser.add_argument(
        "--feature-mode",
        # Validate against the known modes so a typo fails immediately.
        choices=tuple(FEATURE_MODES),
        help="Feature mode to evaluate with; defaults to the checkpoint's own mode.",
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.8,
        help="Confidence threshold forwarded to the metric computation.",
    )
    parser.add_argument(
        "--run-ablations",
        action="store_true",
        help="Also report overall metrics for every feature mode.",
    )
    return parser.parse_args(argv)
@torch.no_grad()
def collect_metrics(
    model,
    dataloader,
    threshold: float,
    device: torch.device,
    temperatures: dict[str, float] | None = None,
) -> dict:
    """Run *model* over *dataloader* and score its per-head predictions.

    The model is put in eval mode and run without gradient tracking. Every
    batch is moved to *device* and fed to the model as keyword arguments.
    For each head in ``HEAD_LABELS``, the head's logits and the batch's
    ``labels_<head>`` tensor are copied to the CPU. They are then
    concatenated across batches and passed to ``evaluate_multitask`` as
    NumPy arrays.

    Args:
        model: Callable accepting the batch as keyword arguments and returning
            a mapping whose ``"logits"`` entry maps each head name to a tensor.
        dataloader: Iterable of dict batches whose values support ``.to``;
            each batch must contain a ``labels_<head>`` entry per head.
        threshold: Confidence threshold forwarded to ``evaluate_multitask``.
        device: Device each batch is moved to before the forward pass.
        temperatures: Optional per-head temperatures forwarded to
            ``evaluate_multitask``.

    Returns:
        The result of ``evaluate_multitask`` on the collected arrays.

    Raises:
        ValueError: If *dataloader* yields no batches.
    """
    logits_by_head: dict[str, list[torch.Tensor]] = {head: [] for head in HEAD_LABELS}
    labels_by_head: dict[str, list[torch.Tensor]] = {head: [] for head in HEAD_LABELS}
    model.eval()
    num_batches = 0
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        head_logits = model(**batch)["logits"]
        for head in HEAD_LABELS:
            logits_by_head[head].append(head_logits[head].detach().cpu())
            labels_by_head[head].append(batch[f"labels_{head}"].detach().cpu())
        num_batches += 1
    if num_batches == 0:
        # torch.cat rejects an empty list with an opaque RuntimeError; fail clearly instead.
        raise ValueError("dataloader yielded no batches; nothing to evaluate")
    return evaluate_multitask(
        {head: torch.cat(values).numpy() for head, values in logits_by_head.items()},
        {head: torch.cat(values).numpy() for head, values in labels_by_head.items()},
        threshold=threshold,
        temperatures=temperatures,
    )
def evaluate_mode(
    model_dir: str,
    data_file: str,
    feature_mode: str | None,
    batch_size: int,
    threshold: float,
    requested_device: str,
) -> dict:
    """Load the checkpoint in *model_dir* and score it on *data_file*.

    Args:
        model_dir: Location handed to ``load_checkpoint`` and
            ``load_temperature_scaling``.
        data_file: File loaded as the ``"test"`` split.
        feature_mode: Feature mode for tokenization; falsy values fall back
            to the checkpoint config's ``feature_mode``.
        batch_size: Batch size of the (unshuffled) evaluation loader.
        threshold: Confidence threshold forwarded to ``collect_metrics``.
        requested_device: Device request resolved through ``get_device``.

    Returns:
        The metrics dict from ``collect_metrics``, with an added
        ``"feature_mode"`` entry naming the mode that was evaluated.
    """
    target_device = get_device(requested_device=requested_device)
    model, tokenizer, config = load_checkpoint(model_dir, device=target_device)
    calibration = load_temperature_scaling(model_dir)

    trained_mode = config.feature_mode
    mode = feature_mode or trained_mode
    # The stored temperatures belong to the checkpoint's own feature mode,
    # so they are only applied when evaluating that same mode.
    if mode != trained_mode:
        calibration = None

    raw_splits = build_dataset_dict(None, None, test_file=data_file)
    tokenized_splits = tokenize_dataset_dict(
        raw_splits,
        tokenizer=tokenizer,
        feature_mode=mode,
        max_length=config.max_length,
        recency_max=config.recency_max,
    )
    test_loader = DataLoader(
        tokenized_splits["test"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=RouterCollator(tokenizer),
    )

    result = collect_metrics(
        model,
        test_loader,
        threshold=threshold,
        device=target_device,
        temperatures=calibration,
    )
    result["feature_mode"] = mode
    return result
def main() -> None:
    """Command-line entry point: evaluate a checkpoint and report its metrics.

    The checkpoint is evaluated with ``--feature-mode`` (or its own trained
    mode). With ``--run-ablations``, the ``"overall"`` metrics for every mode
    in ``FEATURE_MODES`` are also recorded under ``metrics["ablations"]``.
    The result is written with ``dump_json`` when ``--output-file`` is given,
    and printed otherwise.
    """
    args = parse_args()
    # Arguments shared by the primary evaluation and every ablation run.
    shared = {
        "model_dir": args.model_dir,
        "data_file": args.data_file,
        "batch_size": args.batch_size,
        "threshold": args.confidence_threshold,
        "requested_device": args.device,
    }
    metrics = evaluate_mode(feature_mode=args.feature_mode, **shared)
    if args.run_ablations:
        evaluated_mode = metrics["feature_mode"]
        ablations = {}
        for mode in FEATURE_MODES:
            if mode == evaluated_mode:
                # Identical inputs to the primary run: reuse its result instead of
                # reloading the checkpoint and re-scoring the whole data file.
                ablations[mode] = metrics["overall"]
            else:
                ablations[mode] = evaluate_mode(feature_mode=mode, **shared)["overall"]
        metrics["ablations"] = ablations
    if args.output_file:
        # Make sure the destination directory exists before writing.
        Path(args.output_file).parent.mkdir(parents=True, exist_ok=True)
        dump_json(args.output_file, metrics)
    else:
        print(metrics)
if __name__ == "__main__":
    # Run the evaluation CLI only when executed as a script, not when imported.
    main()
|