# span_inference.py
# Inference helper for the DeBERTa-v3-small span-token hallucination detector.
"""Sliding-window span inference for a token-classification hallucination detector.

Pipeline implemented by :func:`predict_record`:

1. Split the answer (``record["output"]``) into overlapping token windows.
2. For each window, build one text (answer window first, then question, tool
   responses and tool names) and run the token classifier on it.
3. Map every token that overlaps the answer window back to character offsets
   in the full answer and record its per-label probabilities.
4. De-duplicate tokens seen in several windows, then merge adjacent
   above-threshold tokens into character spans.
"""

from collections import defaultdict

import torch

# Maximum number of model tokens per forward pass (the window text is truncated to it).
MAX_LENGTH = 512
# Answer tokens per window, and distance between consecutive window starts.
ANSWER_WINDOW_TOKENS = 384
ANSWER_WINDOW_STRIDE = 256

DEFAULT_DECODING_CONFIG = {
    "threshold": 0.5,
    "min_span_chars": 1,
    "min_span_tokens": 1,
    "merge_gap_chars": 1,
    "strip_predicted_span_whitespace": True,
    "drop_spans_without_alnum": True,
    "score_name": "sum_non_O_probability",
}

# Fallback label set, used when the model config carries no ``id2label``.
# Index 0 is the "outside" label; every other index is a hallucination type.
SPAN_LABELS = [
    "O",
    "tool_output_conflict",
    "overgeneration",
    "missing_tool_action_recommendation",
]


def normalize_tool_names(names, max_names=80):
    """Return a ``"; "``-joined string of cleaned, case-insensitively unique names.

    Args:
        names: List of tool names. Anything that is not a ``list`` yields ``""``;
            non-string items inside the list are skipped.
        max_names: Maximum number of names to keep (first occurrences win).

    Returns:
        The joined names with internal whitespace collapsed, or ``""``.
    """
    if not isinstance(names, list):
        return ""
    kept = []
    seen = set()
    for item in names:
        # Check the cap before appending so ``max_names=0`` really keeps nothing.
        if len(kept) >= max_names:
            break
        if not isinstance(item, str):
            continue
        cleaned = " ".join(item.split())
        key = cleaned.lower()
        if cleaned and key not in seen:
            kept.append(cleaned)
            seen.add(key)
    return "; ".join(kept)


def has_alnum(text):
    """Return True when *text* contains at least one alphanumeric character."""
    return any(ch.isalnum() for ch in text)


def safe_divide(num, den):
    """Return ``num / den`` as a float, or ``0.0`` when *den* is falsy."""
    return float(num) / float(den) if den else 0.0


def answer_token_offsets(answer, tokenizer):
    """Tokenize *answer* alone and return the offsets of its non-empty tokens.

    Args:
        answer: The answer text.
        tokenizer: Callable accepting ``add_special_tokens``,
            ``return_offsets_mapping`` and ``truncation`` keyword arguments and
            returning a mapping with an ``"offset_mapping"`` sequence of
            ``(start, end)`` pairs.

    Returns:
        A list of ``{"answer_token_index", "start", "end"}`` dicts, one per
        token whose span is non-empty (``end > start``).
    """
    enc = tokenizer(
        answer,
        add_special_tokens=False,
        return_offsets_mapping=True,
        truncation=False,
    )
    offsets = []
    for index, (start, end) in enumerate(enc["offset_mapping"]):
        start, end = int(start), int(end)
        if end > start:
            offsets.append({"answer_token_index": index, "start": start, "end": end})
    return offsets


def make_answer_windows(answer, tokenizer):
    """Split *answer* into overlapping windows expressed as character ranges.

    Windows hold at most ``ANSWER_WINDOW_TOKENS`` answer tokens and start every
    ``ANSWER_WINDOW_STRIDE`` tokens; an extra window aligned to the end of the
    answer is added so the tail is always seen with a full-size window.

    Returns:
        A list of ``{"window_index", "answer_window_start", "answer_window_end"}``
        dicts. An answer with no non-empty tokens yields a single window
        covering the whole string.
    """
    offsets = answer_token_offsets(answer, tokenizer)
    if not offsets:
        return [{
            "window_index": 0,
            "answer_window_start": 0,
            "answer_window_end": len(answer),
        }]

    n = len(offsets)
    if n <= ANSWER_WINDOW_TOKENS:
        token_ranges = [(0, n)]
    else:
        starts = list(range(0, n, ANSWER_WINDOW_STRIDE))
        final_start = max(0, n - ANSWER_WINDOW_TOKENS)
        if starts[-1] != final_start:
            starts.append(final_start)
        # NOTE(review): strided starts greater than ``final_start`` produce
        # windows fully contained in the final window. They are kept because
        # duplicates are resolved by taking the higher score, so removing them
        # could change predictions.
        starts = sorted(set(starts))
        token_ranges = [(s, min(s + ANSWER_WINDOW_TOKENS, n)) for s in starts]

    windows = []
    for window_index, (tok_start, tok_end) in enumerate(token_ranges):
        windows.append({
            "window_index": window_index,
            "answer_window_start": int(offsets[tok_start]["start"]),
            "answer_window_end": int(offsets[tok_end - 1]["end"]),
        })
    return windows


def build_window_text(record, window):
    """Build the model input text for one answer window.

    The answer window comes first so that it survives truncation; question,
    tool responses and (if any) tool names follow.

    Args:
        record: Mapping with optional ``"output"``, ``"query"``, ``"context"``
            and ``"available_tool_names"`` entries. Missing or ``None`` text
            fields are treated as empty strings.
        window: A window dict as produced by :func:`make_answer_windows`.

    Returns:
        ``(text, answer_window_text, local_answer_start, local_answer_end)``
        where the two offsets delimit the answer window inside *text*.
    """
    # ``or ""`` also covers keys that are present but explicitly None.
    answer = record.get("output") or ""
    query = record.get("query") or ""
    context = record.get("context") or ""
    tool_names = normalize_tool_names(record.get("available_tool_names", []))

    answer_window_text = answer[window["answer_window_start"]:window["answer_window_end"]]
    prefix = "Answer window:\n"
    parts = [
        prefix,
        answer_window_text,
        "\n\nQuestion:\n",
        query,
        "\n\nTool responses:\n",
        context,
    ]
    if tool_names:
        parts.extend(["\n\nAvailable tool names:\n", tool_names])
    text = "".join(parts)
    return text, answer_window_text, len(prefix), len(prefix) + len(answer_window_text)


def deduplicate_candidates(candidates):
    """Collapse candidates that cover the same ``(start, end)`` character range.

    For each range the candidate with the highest ``hallucination_probability``
    is kept (the first one on ties) and its ``type_probabilities`` become the
    label-wise maximum over all duplicates. The merge is independent of input
    order and never mutates the input candidates.

    Returns:
        A new list of candidate dicts sorted by ``(start, end)``.
    """
    by_key = {}
    for cand in candidates:
        key = (int(cand["start"]), int(cand["end"]))
        # Copy the nested dict too: a shallow copy would share it with the caller.
        incoming = dict(cand)
        incoming["type_probabilities"] = dict(cand["type_probabilities"])

        kept = by_key.get(key)
        if kept is None:
            by_key[key] = incoming
            continue

        if incoming["hallucination_probability"] > kept["hallucination_probability"]:
            winner, loser = incoming, kept
        else:
            winner, loser = kept, incoming
        merged_types = winner["type_probabilities"]
        for label, prob in loser["type_probabilities"].items():
            merged_types[label] = max(merged_types.get(label, 0.0), prob)
        by_key[key] = winner
    return sorted(by_key.values(), key=lambda x: (x["start"], x["end"]))


def _clamp_span(answer, start, end):
    """Clamp ``start``/``end`` into ``[0, len(answer)]`` and return them as ints."""
    limit = len(answer)
    return max(0, min(int(start), limit)), max(0, min(int(end), limit))


def strip_span_whitespace(answer, start, end):
    """Shrink ``[start, end)`` so it neither starts nor ends on whitespace.

    Offsets are first clamped to the bounds of *answer*. A whitespace-only
    range collapses to an empty one (``start == end``).
    """
    start, end = _clamp_span(answer, start, end)
    while start < end and answer[start].isspace():
        start += 1
    while end > start and answer[end - 1].isspace():
        end -= 1
    return start, end


def _absorb_candidate(span, item):
    """Fold one token candidate into a running merged-span accumulator."""
    prob = item["hallucination_probability"]
    span["end"] = max(span["end"], item["end"])
    span["token_count"] += 1
    span["score"] = max(span["score"], prob)
    span["score_sum"] += prob
    for label, type_prob in item["type_probabilities"].items():
        span["type_score_sums"][label] += float(type_prob)


def _open_span(item):
    """Start a merged-span accumulator from its first token candidate."""
    span = {
        "start": item["start"],
        "end": item["end"],
        "token_count": 0,
        "score": item["hallucination_probability"],
        "score_sum": 0.0,
        "type_score_sums": defaultdict(float),
    }
    _absorb_candidate(span, item)
    return span


def merge_candidates_to_spans(answer, candidates, config):
    """Turn per-token candidates into merged, filtered hallucination spans.

    Args:
        answer: Full answer text the candidate offsets refer to.
        candidates: Token candidates with ``start``, ``end``,
            ``hallucination_probability`` and ``type_probabilities``.
        config: Decoding options; missing keys fall back to
            ``DEFAULT_DECODING_CONFIG`` and ``None`` means "all defaults".
            Keys read: ``threshold``, ``min_span_chars``, ``min_span_tokens``,
            ``merge_gap_chars``, ``strip_predicted_span_whitespace``,
            ``drop_spans_without_alnum``.

    Returns:
        A list of span dicts (``start``, ``end``, ``text``, ``label_type``,
        ``hallucination_type``, ``score``, ``mean_score``, ``type_scores``,
        ``token_count``) ordered by position.
    """
    if config is None:
        config = DEFAULT_DECODING_CONFIG
    defaults = DEFAULT_DECODING_CONFIG
    threshold = float(config.get("threshold", defaults["threshold"]))
    min_span_chars = int(config.get("min_span_chars", defaults["min_span_chars"]))
    min_span_tokens = int(config.get("min_span_tokens", defaults["min_span_tokens"]))
    merge_gap_chars = int(config.get("merge_gap_chars", defaults["merge_gap_chars"]))
    strip_whitespace = bool(config.get(
        "strip_predicted_span_whitespace",
        defaults["strip_predicted_span_whitespace"],
    ))
    drop_without_alnum = config.get("drop_spans_without_alnum", True)

    # deduplicate_candidates already returns the candidates sorted by (start, end).
    selected = [
        c for c in deduplicate_candidates(candidates)
        if c["hallucination_probability"] >= threshold and c["end"] > c["start"]
    ]

    merged = []
    current = None
    for item in selected:
        if current is not None and item["start"] <= current["end"] + merge_gap_chars:
            _absorb_candidate(current, item)
            continue
        if current is not None:
            merged.append(current)
        current = _open_span(item)
    if current is not None:
        merged.append(current)

    spans = []
    for span in merged:
        if strip_whitespace:
            start, end = strip_span_whitespace(answer, span["start"], span["end"])
        else:
            # Still clamp so the offsets always index into ``answer``.
            start, end = _clamp_span(answer, span["start"], span["end"])
        if end <= start:
            continue
        if (end - start) < min_span_chars:
            continue
        if span["token_count"] < min_span_tokens:
            continue
        text = answer[start:end]
        if drop_without_alnum and not has_alnum(text):
            continue

        # Per-type score is the mean of that type's probability over the span's tokens.
        type_scores = {
            label: safe_divide(value, span["token_count"])
            for label, value in span["type_score_sums"].items()
        }
        if type_scores:
            pred_type = max(type_scores.items(), key=lambda x: x[1])[0]
        else:
            pred_type = "overgeneration"
        spans.append({
            "start": int(start),
            "end": int(end),
            "text": text,
            "label_type": "Predicted Hallucination",
            "hallucination_type": pred_type,
            "score": float(span["score"]),
            "mean_score": float(safe_divide(span["score_sum"], span["token_count"])),
            "type_scores": type_scores,
            "token_count": int(span["token_count"]),
        })
    return spans


def _score_window(record, window, answer, model, tokenizer, device, id2label):
    """Run the model on one answer window and return its token candidates."""
    text, _answer_window_text, local_answer_start, local_answer_end = build_window_text(
        record, window
    )
    enc = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
        return_offsets_mapping=True,
        return_tensors="pt",
        add_special_tokens=True,
    )
    # The offset mapping is bookkeeping only; it must not be passed to the model.
    offsets = enc.pop("offset_mapping")[0].tolist()
    inputs = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    probs = torch.softmax(logits.float(), dim=-1).detach().cpu()

    window_start = window["answer_window_start"]
    local_answer_len = local_answer_end - local_answer_start
    candidates = []
    for token_idx, (start, end) in enumerate(offsets):
        start, end = int(start), int(end)
        if end <= start:
            continue  # zero-width offsets carry no text
        if not (end > local_answer_start and start < local_answer_end):
            continue  # token lies entirely outside the answer window
        # Clip the token to the answer window, relative to the window start.
        local_start = max(0, start - local_answer_start)
        local_end = min(local_answer_len, end - local_answer_start)
        if local_end <= local_start:
            continue
        global_start = window_start + local_start
        global_end = window_start + local_end

        token_probs = probs[token_idx]
        non_o_probs = token_probs[1:]  # index 0 is the "O" label
        best_non_o_id = int(torch.argmax(non_o_probs).item()) + 1
        candidates.append({
            "start": int(global_start),
            "end": int(global_end),
            "text": answer[global_start:global_end],
            "hallucination_probability": float(non_o_probs.sum().item()),
            "pred_label": id2label.get(best_non_o_id, str(best_non_o_id)),
            "type_probabilities": {
                id2label.get(i, str(i)): float(token_probs[i].item())
                for i in range(1, token_probs.shape[-1])
            },
            "o_probability": float(token_probs[0].item()),
            "window_index": window["window_index"],
        })
    return candidates


def predict_record(record, model, tokenizer, device=None, decoding_config=None):
    """Predict hallucinated spans in ``record["output"]``.

    Args:
        record: Mapping with ``"output"`` (the answer) and optional ``"query"``,
            ``"context"`` and ``"available_tool_names"``.
        model: Token-classification model; called with the tokenizer output and
            expected to expose ``.logits`` and ``.config``. It is moved to
            *device* and put in eval mode.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        device: Target device; defaults to ``"cuda"`` when available, else ``"cpu"``.
        decoding_config: Decoding options (see ``DEFAULT_DECODING_CONFIG``).
            ``None`` uses a private copy of the defaults.

    Returns:
        A dict with ``pred_hallucination_type``, ``pred_is_hallucinated``,
        ``predicted_spans``, ``num_token_candidates`` and ``decoding_config``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if decoding_config is None:
        # Copy so callers mutating the returned config cannot alter the module defaults.
        decoding_config = dict(DEFAULT_DECODING_CONFIG)

    model.to(device)
    model.eval()

    answer = record.get("output") or ""
    id2label = getattr(model.config, "id2label", None) or dict(enumerate(SPAN_LABELS))
    # Label ids may be strings when the config was loaded from JSON.
    id2label = {int(k): v for k, v in id2label.items()}

    candidates = []
    for window in make_answer_windows(answer, tokenizer):
        candidates.extend(
            _score_window(record, window, answer, model, tokenizer, device, id2label)
        )

    spans = merge_candidates_to_spans(answer, candidates, decoding_config)
    if spans:
        # The record-level label is the type of the highest-scoring (then longest) span.
        best_span = max(spans, key=lambda x: (x.get("score", 0.0), x["end"] - x["start"]))
        row_label = best_span.get("hallucination_type", "overgeneration")
    else:
        row_label = "clean"

    return {
        "pred_hallucination_type": row_label,
        "pred_is_hallucinated": bool(spans),
        "predicted_spans": spans,
        "num_token_candidates": len(candidates),
        "decoding_config": decoding_config,
    }