agentsight-api / src /training /evaluate.py
Minato Namikaze
Deploy to Hugging Face Spaces
2aed081
Raw
History Blame Contribute Delete
5.51 kB
import torch
import numpy as np
from sklearn.metrics import f1_score, recall_score, precision_score
# ─────────────────────────────────────────────────────────────────────────────
# Official metric definitions — match AgentHallu paper (Equations 6-11)
# ─────────────────────────────────────────────────────────────────────────────
def step_localization_accuracy(samples, step_preds):
    """
    AgentHallu Eq. 11:
    Acc_step = |{i ∈ Hhal : pred_step(i) == true_step(i)}| / |Hhal|

    Denominator is ALL hallucinated trajectories (regardless of whether
    the model's judgment prediction was correct). Do NOT restrict to
    TP-only — that would be a lenient variant not comparable to the paper.

    Args:
        samples: iterable of dict-like samples. ``is_hallucination`` may be a
            bool or the string "true"/"false" (case-insensitive);
            ``hallucination_step`` is the labelled step, converted with
            ``int()`` when present.
        step_preds: iterable of predicted steps aligned with ``samples``;
            ``None`` means "no step predicted". The two iterables are paired
            with ``zip``, so surplus items in the longer one are ignored.

    Returns:
        Fraction of hallucinated samples whose predicted step equals the
        labelled step, or 0.0 when there are no hallucinated samples.

    Raises:
        ValueError / TypeError: if a present ``hallucination_step`` cannot be
            converted to ``int``.
    """
    correct = 0
    total = 0
    for sample, pred_step in zip(samples, step_preds):
        is_hal = sample.get("is_hallucination", False)
        if isinstance(is_hal, str):
            is_hal = is_hal.lower() == "true"
        if not is_hal:
            continue

        # Every hallucinated trajectory counts in the denominator, including
        # the ones the model judged as clean (pred_step is None).
        total += 1

        true_step = sample.get("hallucination_step")
        if true_step is None or pred_step is None:
            # A missing label or a missing prediction can never be a hit.
            # Without this guard ``None == None`` would be scored as a
            # correct localisation and inflate the metric.
            continue

        if pred_step == int(true_step):
            correct += 1

    return correct / total if total else 0.0
def tune_threshold(model, val_samples, preprocessor, thresholds=None):
    """
    Find the decision threshold on the validation set that maximises
    macro-F1.

    Each trajectory is scored by the highest per-step sigmoid probability the
    model produces. A trajectory is predicted as hallucinated when that score
    is strictly greater than the candidate threshold.

    Args:
        model: module called as ``model(input_ids, attention_mask)``; it must
            expose ``model.encoder.config.vocab_size``. It is switched to
            eval mode and left there.
        val_samples: iterable of dict-like validation samples carrying an
            ``is_hallucination`` flag (bool or "true"/"false" string).
        preprocessor: object whose ``encode_trajectory(sample)`` returns a
            list of steps, each with ``["encoding"]["input_ids"]`` and
            ``["encoding"]["attention_mask"]`` tensors.
        thresholds: candidate thresholds; defaults to
            ``np.arange(0.2, 0.8, 0.02)``.

    Returns:
        ``(best_thr, best_f1)``: the threshold with the highest macro-F1 and
        that F1 value. On ties the earliest candidate wins. If ``thresholds``
        is empty the initial values ``(0.5, -1.0)`` are returned.
    """
    import logging  # function-scope import keeps the block self-contained

    if thresholds is None:
        thresholds = np.arange(0.2, 0.8, 0.02)

    log = logging.getLogger(__name__)

    model.eval()
    device = next(model.parameters()).device
    # Loop-invariant, so looked up once instead of once per trajectory.
    vocab_size = model.encoder.config.vocab_size

    # Collect max-prob per trajectory
    max_probs = []
    hal_true = []
    with torch.no_grad():
        for idx, sample in enumerate(val_samples):
            is_hal = sample.get("is_hallucination", False)
            if isinstance(is_hal, str):
                is_hal = is_hal.lower() == "true"
            hal_true.append(int(is_hal))

            try:
                steps = preprocessor.encode_trajectory(sample)
            except Exception:
                # Best-effort: keep going, but leave a trace so that encoding
                # failures are not silently folded into the metric.
                log.warning(
                    "encode_trajectory failed for validation sample %d; "
                    "scoring it as 0.0",
                    idx,
                    exc_info=True,
                )
                steps = []

            if not steps:
                # Nothing to score: the trajectory can never exceed a
                # non-negative threshold.
                max_probs.append(0.0)
                continue

            ids = torch.stack(
                [s["encoding"]["input_ids"].squeeze(0) for s in steps]
            ).to(device)
            mask = torch.stack(
                [s["encoding"]["attention_mask"].squeeze(0) for s in steps]
            ).to(device)
            # Force token ids into the encoder's vocabulary range.
            ids = torch.clamp(ids, 0, vocab_size - 1)

            logits = model(ids, mask)
            # Flatten so 0-dim, (n_steps,) and (n_steps, 1) outputs all give
            # a flat list of floats; one logit per step is assumed.
            probs = torch.sigmoid(logits).reshape(-1).cpu().tolist()
            max_probs.append(max(probs))

    best_f1 = -1.0
    best_thr = 0.5
    for thr in thresholds:
        preds = [1 if p > thr else 0 for p in max_probs]
        f1 = f1_score(hal_true, preds, average="macro", zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thr = float(thr)

    return best_thr, best_f1
def evaluate(model, test_samples, preprocessor, threshold=0.5):
    """
    Full evaluation: judgment (binary) + step localisation.
    Uses the supplied threshold (tune it on val, never on test).

    For every trajectory the model scores each encoded step. The trajectory
    is predicted as hallucinated when its highest step probability is
    strictly greater than ``threshold``, and the predicted step is the
    ``step_idx`` of that highest-scoring step (first one on ties).
    Trajectories that cannot be encoded, or encode to no steps, are
    predicted as not hallucinated with no step.

    Args:
        model: module called as ``model(input_ids, attention_mask)``; it must
            expose ``model.encoder.config.vocab_size``. It is switched to
            eval mode and left there.
        test_samples: sized sequence of dict-like samples carrying
            ``is_hallucination`` and, for hallucinated ones,
            ``hallucination_step``.
        preprocessor: object whose ``encode_trajectory(sample)`` returns a
            list of steps, each with ``["encoding"]["input_ids"]``,
            ``["encoding"]["attention_mask"]`` and ``["step_idx"]``.
        threshold: decision threshold on the max step probability.

    Returns:
        dict with ``step_acc``, ``judgment_f1``, ``judgment_recall``,
        ``judgment_precision`` (all macro-averaged for the judgment task),
        the ``threshold`` used and ``n_samples``.
    """
    import logging  # function-scope import keeps the block self-contained

    log = logging.getLogger(__name__)

    model.eval()
    device = next(model.parameters()).device
    # Loop-invariant, so looked up once instead of once per trajectory.
    vocab_size = model.encoder.config.vocab_size

    hal_preds = []
    hal_true = []
    step_preds = []

    with torch.no_grad():
        for idx, sample in enumerate(test_samples):
            is_hal = sample.get("is_hallucination", False)
            if isinstance(is_hal, str):
                is_hal = is_hal.lower() == "true"
            hal_true.append(int(is_hal))

            try:
                steps = preprocessor.encode_trajectory(sample)
            except Exception:
                # Best-effort: keep going, but leave a trace so that encoding
                # failures are not silently counted as "clean" predictions.
                log.warning(
                    "encode_trajectory failed for test sample %d; "
                    "predicting no hallucination",
                    idx,
                    exc_info=True,
                )
                steps = []

            if not steps:
                hal_preds.append(0)
                step_preds.append(None)
                continue

            ids = torch.stack(
                [s["encoding"]["input_ids"].squeeze(0) for s in steps]
            ).to(device)
            mask = torch.stack(
                [s["encoding"]["attention_mask"].squeeze(0) for s in steps]
            ).to(device)
            # Force token ids into the encoder's vocabulary range.
            ids = torch.clamp(ids, 0, vocab_size - 1)

            logits = model(ids, mask)
            # Flatten so 0-dim, (n_steps,) and (n_steps, 1) outputs all give
            # a flat list of floats; one logit per step is assumed.
            probs = torch.sigmoid(logits).reshape(-1).cpu().tolist()

            max_prob = max(probs)
            max_prob_idx = probs.index(max_prob)  # first step on ties

            pred_is_hal = max_prob > threshold
            hal_preds.append(int(pred_is_hal))

            # A step is only reported when the trajectory is judged
            # hallucinated; otherwise localisation has nothing to point at.
            pred_step = steps[max_prob_idx]["step_idx"] if pred_is_hal else None
            step_preds.append(pred_step)

    step_acc = step_localization_accuracy(test_samples, step_preds)
    macro_f1 = f1_score(hal_true, hal_preds, average="macro", zero_division=0)
    macro_rec = recall_score(hal_true, hal_preds, average="macro", zero_division=0)
    macro_pre = precision_score(hal_true, hal_preds, average="macro", zero_division=0)

    return {
        "step_acc": step_acc,
        "judgment_f1": macro_f1,
        "judgment_recall": macro_rec,
        "judgment_precision": macro_pre,
        "threshold": threshold,
        "n_samples": len(test_samples),
    }