# Ali-Bhai's picture
# Add final DeBERTa span-level tool hallucination detector
# bf58e9a verified
# span_inference.py
# Inference helper for the DeBERTa-v3-small span-token hallucination detector.
from collections import defaultdict
import torch
# Maximum number of tokens in one encoder input (answer window + question +
# tool responses + tool names); longer inputs are truncated by the tokenizer.
MAX_LENGTH = 512

# Sliding-window geometry over the answer, measured in answer tokens. The stride
# is smaller than the window, so consecutive windows overlap.
ANSWER_WINDOW_TOKENS = 384
ANSWER_WINDOW_STRIDE = 256

# Default span-decoding settings, used when predict_record is called without a
# decoding_config. "score_name" records how the per-token score is computed:
# the sum of every non-"O" class probability.
DEFAULT_DECODING_CONFIG = dict(
    threshold=0.5,
    min_span_chars=1,
    min_span_tokens=1,
    merge_gap_chars=1,
    strip_predicted_span_whitespace=True,
    drop_spans_without_alnum=True,
    score_name="sum_non_O_probability",
)

# Fallback id -> label order used when the model config carries no id2label.
# Index 0 ("O") is treated as the not-hallucinated class.
SPAN_LABELS = [
    "O",
    "tool_output_conflict",
    "overgeneration",
    "missing_tool_action_recommendation",
]
def normalize_tool_names(names, max_names=80):
    """Collapse a list of tool names into a single "; "-separated string.

    Non-string entries are skipped. Each name has its whitespace runs collapsed
    to one space. Empty names are dropped, and duplicates are removed
    case-insensitively, with the first spelling kept. Collection stops once
    ``max_names`` names have been kept. Any input that is not a list yields "".
    """
    if not isinstance(names, list):
        return ""
    kept = []
    seen_lower = set()
    for raw in names:
        if not isinstance(raw, str):
            continue
        # split()/join already removes leading and trailing whitespace.
        cleaned = " ".join(raw.split())
        folded = cleaned.lower()
        if not cleaned or folded in seen_lower:
            continue
        kept.append(cleaned)
        seen_lower.add(folded)
        if len(kept) >= max_names:
            break
    return "; ".join(kept)
def has_alnum(text):
    """Return True if *text* contains at least one alphanumeric character."""
    for ch in text:
        if ch.isalnum():
            return True
    return False
def safe_divide(num, den):
    """Return ``num / den`` as a float, or 0.0 when *den* is zero or otherwise falsy."""
    if not den:
        return 0.0
    return float(num) / float(den)
def answer_token_offsets(answer, tokenizer):
    """Tokenize *answer* on its own and return the character range of each token.

    The tokenizer is called without special tokens and without truncation.
    Tokens whose offset range is empty are dropped. ``answer_token_index``
    still records each token's position in the full tokenizer output.

    Returns:
        A list of ``{"answer_token_index": int, "start": int, "end": int}`` dicts.
    """
    encoding = tokenizer(
        answer,
        add_special_tokens=False,
        return_offsets_mapping=True,
        truncation=False,
    )
    token_spans = []
    for position, (lo, hi) in enumerate(encoding["offset_mapping"]):
        lo, hi = int(lo), int(hi)
        if hi <= lo:
            continue
        token_spans.append({"answer_token_index": position, "start": lo, "end": hi})
    return token_spans
def make_answer_windows(answer, tokenizer):
    """Split the answer into overlapping token windows, expressed as character ranges.

    The windows are defined on the answer tokenized alone. Each window covers
    at most ``ANSWER_WINDOW_TOKENS`` answer tokens. Window starts advance by
    ``ANSWER_WINDOW_STRIDE`` tokens, and one extra window ending at the last
    token is added so the tail is always covered. If tokenization yields no
    non-empty tokens, a single window spanning the whole answer string is
    returned.

    Returns:
        A list of ``{"window_index", "answer_window_start", "answer_window_end"}``
        dicts, where the start and end are character offsets into *answer*.
    """
    token_spans = answer_token_offsets(answer, tokenizer)
    if not token_spans:
        return [{
            "window_index": 0,
            "answer_window_start": 0,
            "answer_window_end": len(answer),
        }]

    total = len(token_spans)
    if total > ANSWER_WINDOW_TOKENS:
        tail_start = max(0, total - ANSWER_WINDOW_TOKENS)
        # NOTE(review): starts beyond tail_start give windows fully contained in
        # the tail window (extra model passes over already-covered tokens).
        window_starts = sorted(set(range(0, total, ANSWER_WINDOW_STRIDE)) | {tail_start})
    else:
        window_starts = [0]

    result = []
    for position, first in enumerate(window_starts):
        last = min(first + ANSWER_WINDOW_TOKENS, total) - 1
        result.append({
            "window_index": position,
            "answer_window_start": int(token_spans[first]["start"]),
            "answer_window_end": int(token_spans[last]["end"]),
        })
    return result
def build_window_text(record, window):
    """Assemble the encoder input text for one answer window.

    The answer slice comes first, followed by the question and the tool
    responses. The available tool names are appended only when they are
    non-empty after normalization.

    Returns:
        ``(text, answer_window_text, answer_start, answer_end)``, where the last
        two are the character offsets of the answer slice inside ``text``.
    """
    prefix = "Answer window:\n"
    window_slice = record.get("output", "")[window["answer_window_start"]:window["answer_window_end"]]
    parts = [
        prefix,
        window_slice,
        "\n\nQuestion:\n",
        record.get("query", ""),
        "\n\nTool responses:\n",
        record.get("context", ""),
    ]
    tool_list = normalize_tool_names(record.get("available_tool_names", []))
    if tool_list:
        parts += ["\n\nAvailable tool names:\n", tool_list]
    answer_start = len(prefix)
    return "".join(parts), window_slice, answer_start, answer_start + len(window_slice)
def deduplicate_candidates(candidates):
    """Collapse candidates that cover the same ``(start, end)`` character range.

    Overlapping answer windows can score the same answer token more than once.
    For each character range:

    - The candidate with the highest ``hallucination_probability`` supplies the
      record's fields; on a tie, the earliest one wins.
    - ``type_probabilities`` becomes the per-label maximum over all candidates
      sharing that range, so the merged values do not depend on input order.

    The input candidate dicts, including their nested ``type_probabilities``
    dicts, are never modified.

    Returns:
        The merged candidates, sorted by ``(start, end)``.
    """
    by_key = {}
    for cand in candidates:
        key = (int(cand["start"]), int(cand["end"]))
        old = by_key.get(key)
        if old is None:
            # Copy the nested dict too: a plain dict(cand) would share it with
            # the caller's candidate.
            fresh = dict(cand)
            fresh["type_probabilities"] = dict(cand["type_probabilities"])
            by_key[key] = fresh
            continue
        # Per-label maximum over every candidate seen for this range, applied
        # the same way whichever candidate wins below.
        merged_types = dict(old["type_probabilities"])
        for label, prob in cand["type_probabilities"].items():
            merged_types[label] = max(merged_types.get(label, 0.0), prob)
        if cand["hallucination_probability"] > old["hallucination_probability"]:
            old = dict(cand)
            by_key[key] = old
        old["type_probabilities"] = merged_types
    return sorted(by_key.values(), key=lambda x: (x["start"], x["end"]))
def strip_span_whitespace(answer, start, end):
    """Clamp ``[start, end)`` to the bounds of *answer*, then trim whitespace at both ends.

    Returns:
        The adjusted ``(start, end)``. An all-whitespace range collapses to
        ``start == end``. An inverted range is returned clamped but unchanged.
    """
    size = len(answer)
    lo = min(max(int(start), 0), size)
    hi = min(max(int(end), 0), size)
    while lo < hi and answer[lo].isspace():
        lo += 1
    while hi > lo and answer[hi - 1].isspace():
        hi -= 1
    return lo, hi
def merge_candidates_to_spans(answer, candidates, config):
    """Turn per-token hallucination candidates into character-level predicted spans.

    Merging:
        Candidates are first collapsed per ``(start, end)`` range with
        ``deduplicate_candidates``. Those whose ``hallucination_probability``
        is at least ``threshold``, and that cover at least one character, are
        kept and walked in ``(start, end)`` order. A candidate starting no more
        than ``merge_gap_chars`` characters after the current span's end
        extends that span; otherwise it starts a new one.

    Filtering:
        Each merged span is clamped to *answer*. When
        ``strip_predicted_span_whitespace`` is true, it is also trimmed of
        surrounding whitespace. A span is discarded when it:

        - is shorter than ``min_span_chars`` characters;
        - is built from fewer than ``min_span_tokens`` candidates; or
        - contains no alphanumeric character, when
          ``drop_spans_without_alnum`` is true.

    Args:
        answer: The full answer string that the candidate offsets index into.
        candidates: Dicts with at least ``start``, ``end``,
            ``hallucination_probability`` and ``type_probabilities``
            (label -> probability), as built by ``predict_record``.
        config: Decoding settings. Missing keys fall back to
            ``DEFAULT_DECODING_CONFIG``, and ``None`` means use the defaults.

    Returns:
        A list of span dicts in answer order. Each span has these keys:

        - ``start``, ``end``, ``text`` and ``label_type``;
        - ``hallucination_type``: the label with the highest mean type
          probability, or ``"overgeneration"`` if there is none;
        - ``score``: the maximum token probability;
        - ``mean_score``, ``type_scores`` and ``token_count``.
    """
    if config is None:
        config = DEFAULT_DECODING_CONFIG

    def setting(name):
        return config.get(name, DEFAULT_DECODING_CONFIG[name])

    threshold = float(setting("threshold"))
    min_span_chars = int(setting("min_span_chars"))
    min_span_tokens = int(setting("min_span_tokens"))
    merge_gap_chars = int(setting("merge_gap_chars"))
    strip_whitespace = setting("strip_predicted_span_whitespace")
    drop_without_alnum = setting("drop_spans_without_alnum")

    selected = [
        c for c in deduplicate_candidates(candidates)
        if c["hallucination_probability"] >= threshold and c["end"] > c["start"]
    ]
    selected.sort(key=lambda x: (x["start"], x["end"]))

    # Single pass: extend the running span when the next candidate starts within
    # merge_gap_chars of its end; otherwise close it and start a new one.
    merged = []
    current = None
    for item in selected:
        prob = item["hallucination_probability"]
        if current is not None and item["start"] <= current["end"] + merge_gap_chars:
            current["end"] = max(current["end"], item["end"])
            current["token_count"] += 1
            current["score"] = max(current["score"], prob)
            current["score_sum"] += prob
        else:
            if current is not None:
                merged.append(current)
            current = {
                "start": item["start"],
                "end": item["end"],
                "token_count": 1,
                "score": prob,
                "score_sum": prob,
                "type_score_sums": defaultdict(float),
            }
        for label, type_prob in item["type_probabilities"].items():
            current["type_score_sums"][label] += float(type_prob)
    if current is not None:
        merged.append(current)

    spans = []
    for span in merged:
        if strip_whitespace:
            start, end = strip_span_whitespace(answer, span["start"], span["end"])
        else:
            # Clamp only, matching the bounds handling of strip_span_whitespace.
            start = max(0, min(int(span["start"]), len(answer)))
            end = max(0, min(int(span["end"]), len(answer)))
        if end <= start:
            continue
        if (end - start) < min_span_chars:
            continue
        if span["token_count"] < min_span_tokens:
            continue
        text = answer[start:end]
        if drop_without_alnum and not has_alnum(text):
            continue
        # Mean per-label probability over the tokens merged into this span.
        type_scores = {
            label: safe_divide(value, span["token_count"])
            for label, value in span["type_score_sums"].items()
        }
        pred_type = max(type_scores.items(), key=lambda x: x[1])[0] if type_scores else "overgeneration"
        spans.append({
            "start": int(start),
            "end": int(end),
            "text": text,
            "label_type": "Predicted Hallucination",
            "hallucination_type": pred_type,
            "score": float(span["score"]),
            "mean_score": float(safe_divide(span["score_sum"], span["token_count"])),
            "type_scores": type_scores,
            "token_count": int(span["token_count"]),
        })
    return spans
def _window_candidates(record, window, answer, model, tokenizer, device, id2label):
    """Run the model on one answer window and return candidates for its answer tokens."""
    text, _, local_answer_start, local_answer_end = build_window_text(record, window)
    enc = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
        return_offsets_mapping=True,
        return_tensors="pt",
        add_special_tokens=True,
    )
    offsets = enc.pop("offset_mapping")[0].tolist()
    inputs = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    probs = torch.softmax(logits.float(), dim=-1).detach().cpu()

    window_start = window["answer_window_start"]
    answer_slice_len = local_answer_end - local_answer_start
    candidates = []
    for token_idx, (start, end) in enumerate(offsets):
        start = int(start)
        end = int(end)
        # Skip empty-offset tokens and tokens that do not overlap the answer slice.
        if end <= start or end <= local_answer_start or start >= local_answer_end:
            continue
        local_start = max(0, start - local_answer_start)
        local_end = min(answer_slice_len, end - local_answer_start)
        if local_end <= local_start:
            continue
        # Map slice-relative offsets back to offsets in the full answer.
        global_start = window_start + local_start
        global_end = window_start + local_end
        token_probs = probs[token_idx]
        # NOTE(review): assumes label id 0 is "O" (as in SPAN_LABELS); every other
        # id counts towards the hallucination probability.
        non_o_probs = token_probs[1:]
        best_non_o_id = int(torch.argmax(non_o_probs).item()) + 1
        candidates.append({
            "start": int(global_start),
            "end": int(global_end),
            "text": answer[global_start:global_end],
            "hallucination_probability": float(non_o_probs.sum().item()),
            "pred_label": id2label.get(best_non_o_id, str(best_non_o_id)),
            "type_probabilities": {
                id2label.get(i, str(i)): float(token_probs[i].item())
                for i in range(1, token_probs.shape[-1])
            },
            "o_probability": float(token_probs[0].item()),
            "window_index": window["window_index"],
        })
    return candidates


def predict_record(record, model, tokenizer, device=None, decoding_config=None):
    """Predict hallucinated answer spans for one record.

    Args:
        record: Mapping read via ``record.get`` for these keys:
            ``"output"`` (the answer being checked), ``"query"``,
            ``"context"`` (tool responses) and ``"available_tool_names"``.
        model: Token-classification model. ``model(**inputs).logits[0]`` must
            give one row of class logits per input token. The model is moved to
            *device* and put in eval mode on every call.
        tokenizer: Tokenizer that supports ``return_offsets_mapping``
            (presumably a Hugging Face "fast" tokenizer — confirm).
        device: Torch device. Defaults to ``"cuda"`` when available, otherwise
            ``"cpu"``.
        decoding_config: Overrides for ``DEFAULT_DECODING_CONFIG``.
            Unspecified keys keep their default values.

    Returns:
        A dict with these keys:

        - ``pred_hallucination_type``: the type of the highest-scoring span,
          or ``"clean"``;
        - ``pred_is_hallucinated``, ``predicted_spans`` and
          ``num_token_candidates``;
        - ``decoding_config``: a fresh dict of the settings actually used.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Always build a fresh dict. Returning DEFAULT_DECODING_CONFIG itself would let
    # a caller that edits result["decoding_config"] change the module default.
    decoding_config = {**DEFAULT_DECODING_CONFIG, **(decoding_config or {})}
    model.to(device)
    model.eval()
    answer = record.get("output", "")
    id2label = getattr(model.config, "id2label", None) or {i: label for i, label in enumerate(SPAN_LABELS)}
    id2label = {int(k): v for k, v in id2label.items()}

    candidates = []
    for window in make_answer_windows(answer, tokenizer):
        if window["answer_window_end"] <= window["answer_window_start"]:
            # Empty answer slice (e.g. empty answer): no token can land in it.
            continue
        candidates.extend(
            _window_candidates(record, window, answer, model, tokenizer, device, id2label)
        )

    spans = merge_candidates_to_spans(answer, candidates, decoding_config)
    if spans:
        # Highest max-token score wins; longer span breaks ties.
        best_span = max(spans, key=lambda x: (x.get("score", 0.0), x["end"] - x["start"]))
        row_label = best_span.get("hallucination_type", "overgeneration")
    else:
        row_label = "clean"
    return {
        "pred_hallucination_type": row_label,
        "pred_is_hallucinated": bool(spans),
        "predicted_spans": spans,
        "num_token_candidates": len(candidates),
        "decoding_config": decoding_config,
    }