Token Classification
Transformers
Safetensors
English
deberta-v2
hallucination-detection
span-detection
tool-use
deberta-v3
ragtruth
Instructions to use Ali-Bhai/deberta-tool-hallucination-span-detector with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Ali-Bhai/deberta-tool-hallucination-span-detector with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="Ali-Bhai/deberta-tool-hallucination-span-detector")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("Ali-Bhai/deberta-tool-hallucination-span-detector")
model = AutoModelForTokenClassification.from_pretrained("Ali-Bhai/deberta-tool-hallucination-span-detector")
```
- Notebooks
- Google Colab
- Kaggle
| # span_inference.py | |
| # Inference helper for the DeBERTa-v3-small span-token hallucination detector. | |
| from collections import defaultdict | |
| import torch | |
# Token budget for one encoded input (answer window + question + tool responses).
MAX_LENGTH = 512
# Answer tokens per sliding window, and the step between consecutive window starts.
ANSWER_WINDOW_TOKENS = 384
ANSWER_WINDOW_STRIDE = 256

# Decoding settings used by predict_record() when the caller passes no config.
DEFAULT_DECODING_CONFIG = dict(
    threshold=0.5,  # minimum summed non-O probability for a token to be kept
    min_span_chars=1,
    min_span_tokens=1,
    merge_gap_chars=1,  # largest character gap bridged when merging tokens
    strip_predicted_span_whitespace=True,
    drop_spans_without_alnum=True,
    score_name="sum_non_O_probability",  # descriptive only; not read by the decoder here
)

# Label order of the classification head; index 0 is the "outside" (O) class.
# Used as the id2label fallback when the model config has no mapping.
SPAN_LABELS = ["O"] + [
    "tool_output_conflict",
    "overgeneration",
    "missing_tool_action_recommendation",
]
def normalize_tool_names(names, max_names=80):
    """Collapse a list of tool names into a single "; "-separated string.

    Each string entry has its internal whitespace collapsed to single spaces.
    Empty entries and case-insensitive duplicates are dropped (the first
    spelling seen wins), and at most ``max_names`` names are kept.

    Args:
        names: Tool names. Anything other than a ``list`` yields "".
            Non-string entries are skipped.
        max_names: Maximum number of names to include; values <= 0 yield "".

    Returns:
        The joined names, or "" when none remain.
    """
    if not isinstance(names, list):
        return ""
    out = []
    seen = set()
    for item in names:
        # Check the limit before appending so max_names <= 0 keeps nothing.
        if len(out) >= max_names:
            break
        # Previously a truthy non-string (int, dict, ...) fell through to
        # item.lower() and raised AttributeError; skip such entries instead.
        if not isinstance(item, str):
            continue
        item = " ".join(item.split())
        key = item.lower()
        if item and key not in seen:
            out.append(item)
            seen.add(key)
    return "; ".join(out)
def has_alnum(text):
    """Report whether any character of ``text`` is alphanumeric (str.isalnum)."""
    for char in text:
        if char.isalnum():
            return True
    return False
def safe_divide(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is falsy (0, None, ...)."""
    if not den:
        return 0.0
    return float(num) / float(den)
def answer_token_offsets(answer, tokenizer):
    """Tokenize ``answer`` alone and return the character span of each real token.

    The tokenizer is called without special tokens or truncation and must
    return an "offset_mapping". Zero-width offsets (end <= start) are skipped;
    each kept entry records its index in the tokenizer output together with
    its integer ``start``/``end`` character offsets.
    """
    mapping = tokenizer(
        answer,
        add_special_tokens=False,
        return_offsets_mapping=True,
        truncation=False,
    )["offset_mapping"]
    as_ints = ((idx, int(lo), int(hi)) for idx, (lo, hi) in enumerate(mapping))
    return [
        {"answer_token_index": idx, "start": lo, "end": hi}
        for idx, lo, hi in as_ints
        if hi > lo
    ]
def make_answer_windows(answer, tokenizer):
    """Split ``answer`` into character windows covering at most ANSWER_WINDOW_TOKENS tokens.

    Windows start every ANSWER_WINDOW_STRIDE tokens. One extra window is
    aligned to end exactly at the last token, so the tail is always covered
    by a full-size window. If the answer yields no tokens, a single window
    spanning the whole string is returned.

    Returns:
        List of dicts with "window_index", "answer_window_start" and
        "answer_window_end" (character offsets into ``answer``).
    """
    spans = answer_token_offsets(answer, tokenizer)
    if not spans:
        return [{
            "window_index": 0,
            "answer_window_start": 0,
            "answer_window_end": len(answer),
        }]

    total = len(spans)
    if total > ANSWER_WINDOW_TOKENS:
        # NOTE(review): stride starts beyond the end-aligned start produce
        # windows fully contained in the final window (extra forward passes).
        # Kept as-is to preserve current predictions.
        start_set = set(range(0, total, ANSWER_WINDOW_STRIDE))
        start_set.add(max(0, total - ANSWER_WINDOW_TOKENS))
        token_ranges = [(s, min(s + ANSWER_WINDOW_TOKENS, total)) for s in sorted(start_set)]
    else:
        token_ranges = [(0, total)]

    return [
        {
            "window_index": idx,
            "answer_window_start": int(spans[lo]["start"]),
            "answer_window_end": int(spans[hi - 1]["end"]),
        }
        for idx, (lo, hi) in enumerate(token_ranges)
    ]
def build_window_text(record, window):
    """Assemble the model input text for one answer window.

    The layout is: "Answer window:" + window slice, then the question, the
    tool responses, and (only if any survive normalization) the available
    tool names.

    Returns:
        (text, answer_window_text, answer_start, answer_end), where the last
        two are the character range of the window slice inside ``text``.
    """
    lo = window["answer_window_start"]
    hi = window["answer_window_end"]
    window_slice = record.get("output", "")[lo:hi]
    header = "Answer window:\n"
    pieces = [
        header,
        window_slice,
        "\n\nQuestion:\n",
        record.get("query", ""),
        "\n\nTool responses:\n",
        record.get("context", ""),
    ]
    tools = normalize_tool_names(record.get("available_tool_names", []))
    if tools:
        pieces += ["\n\nAvailable tool names:\n", tools]
    answer_begin = len(header)
    return "".join(pieces), window_slice, answer_begin, answer_begin + len(window_slice)
def deduplicate_candidates(candidates):
    """Collapse token candidates that share the same (start, end) character range.

    Candidates for the same answer range can come from several overlapping
    windows. For each range:

    - the candidate with the highest ``hallucination_probability`` supplies
      the returned fields (the first one seen wins on ties);
    - ``type_probabilities`` is the per-label maximum over all duplicates, so
      it does not depend on input order.

    The input candidates and their nested dicts are never modified.

    Returns:
        New candidate dicts, sorted by (start, end).
    """
    best = {}
    type_max = {}
    for cand in candidates:
        key = (int(cand["start"]), int(cand["end"]))
        current = best.get(key)
        if current is None or cand["hallucination_probability"] > current["hallucination_probability"]:
            best[key] = cand
        # Accumulate into a fresh dict per key. The old code shallow-copied the
        # candidate and then max-merged into the caller's own dict.
        label_max = type_max.setdefault(key, {})
        for label, prob in cand["type_probabilities"].items():
            label_max[label] = prob if label not in label_max else max(label_max[label], prob)

    merged = []
    for key, cand in best.items():
        out = dict(cand)
        out["type_probabilities"] = type_max[key]
        merged.append(out)
    return sorted(merged, key=lambda x: (x["start"], x["end"]))
def strip_span_whitespace(answer, start, end):
    """Clamp [start, end) into ``answer`` and trim whitespace from both edges.

    Returns the adjusted ``(start, end)`` pair. A span made only of whitespace
    collapses to an empty range (start == end).
    """
    size = len(answer)
    lo = min(max(int(start), 0), size)
    hi = min(max(int(end), 0), size)
    # Walk the left edge forward past leading whitespace.
    while lo < hi:
        if not answer[lo].isspace():
            break
        lo += 1
    # Walk the right edge backward past trailing whitespace.
    while hi > lo:
        if not answer[hi - 1].isspace():
            break
        hi -= 1
    return lo, hi
def _new_span_accumulator(item):
    """Start a running span accumulator from one selected token candidate."""
    acc = {
        "start": item["start"],
        "end": item["end"],
        "token_count": 1,
        "score": item["hallucination_probability"],
        "score_sum": item["hallucination_probability"],
        "type_score_sums": defaultdict(float),
    }
    for label, prob in item["type_probabilities"].items():
        acc["type_score_sums"][label] += float(prob)
    return acc


def _extend_span_accumulator(acc, item):
    """Fold one more token candidate into an existing span accumulator (in place)."""
    acc["end"] = max(acc["end"], item["end"])
    acc["token_count"] += 1
    acc["score"] = max(acc["score"], item["hallucination_probability"])
    acc["score_sum"] += item["hallucination_probability"]
    for label, prob in item["type_probabilities"].items():
        acc["type_score_sums"][label] += float(prob)


def merge_candidates_to_spans(answer, candidates, config):
    """Turn per-token candidates into decoded hallucination spans.

    Steps:

    1. Deduplicate candidates by character range.
    2. Keep those with ``hallucination_probability >= threshold`` and a
       non-empty range.
    3. Sort them and merge left to right whenever a candidate starts within
       ``merge_gap_chars`` of the current span's end.
    4. Optionally trim whitespace from each merged span
       (``strip_predicted_span_whitespace``).
    5. Drop spans that are empty, shorter than ``min_span_chars``, built from
       fewer than ``min_span_tokens`` candidates, or (when
       ``drop_spans_without_alnum``) free of alphanumeric characters.

    Args:
        answer: Full answer text that candidate offsets index into.
        candidates: Dicts with "start", "end", "hallucination_probability"
            and "type_probabilities" (label -> probability).
        config: Decoding options. Missing keys fall back to
            DEFAULT_DECODING_CONFIG values; ``None`` uses the defaults.

    Returns:
        List of span dicts, each with:

        - "start", "end", "text";
        - "score": the max token probability;
        - "mean_score": the mean token probability;
        - "type_scores": the per-label mean;
        - "hallucination_type": the argmax of "type_scores", or
          "overgeneration" when there are none;
        - "token_count".
    """
    if config is None:
        config = DEFAULT_DECODING_CONFIG
    threshold = float(config.get("threshold", DEFAULT_DECODING_CONFIG["threshold"]))
    min_span_chars = int(config.get("min_span_chars", DEFAULT_DECODING_CONFIG["min_span_chars"]))
    min_span_tokens = int(config.get("min_span_tokens", DEFAULT_DECODING_CONFIG["min_span_tokens"]))
    merge_gap_chars = int(config.get("merge_gap_chars", DEFAULT_DECODING_CONFIG["merge_gap_chars"]))
    # Previously advertised in the config but ignored (stripping was
    # unconditional). The default keeps that behaviour.
    strip_whitespace = config.get(
        "strip_predicted_span_whitespace",
        DEFAULT_DECODING_CONFIG["strip_predicted_span_whitespace"],
    )
    drop_without_alnum = config.get("drop_spans_without_alnum", True)

    selected = [
        c for c in deduplicate_candidates(candidates)
        if c["hallucination_probability"] >= threshold and c["end"] > c["start"]
    ]
    selected.sort(key=lambda x: (x["start"], x["end"]))

    merged = []
    cur = None
    for item in selected:
        if cur is not None and item["start"] <= cur["end"] + merge_gap_chars:
            _extend_span_accumulator(cur, item)
            continue
        if cur is not None:
            merged.append(cur)
        cur = _new_span_accumulator(item)
    if cur is not None:
        merged.append(cur)

    n_chars = len(answer)
    spans = []
    for span in merged:
        if strip_whitespace:
            start, end = strip_span_whitespace(answer, span["start"], span["end"])
        else:
            start = max(0, min(int(span["start"]), n_chars))
            end = max(0, min(int(span["end"]), n_chars))
        if end <= start:
            continue
        if (end - start) < min_span_chars:
            continue
        if span["token_count"] < min_span_tokens:
            continue
        text = answer[start:end]
        if drop_without_alnum and not has_alnum(text):
            continue
        type_scores = {
            label: safe_divide(value, span["token_count"])
            for label, value in span["type_score_sums"].items()
        }
        pred_type = max(type_scores.items(), key=lambda x: x[1])[0] if type_scores else "overgeneration"
        spans.append({
            "start": int(start),
            "end": int(end),
            "text": text,
            "label_type": "Predicted Hallucination",
            "hallucination_type": pred_type,
            "score": float(span["score"]),
            "mean_score": float(safe_divide(span["score_sum"], span["token_count"])),
            "type_scores": type_scores,
            "token_count": int(span["token_count"]),
        })
    return spans
def predict_record(record, model, tokenizer, device=None, decoding_config=None):
    """Run the span detector over one record and decode hallucinated spans.

    How it works:

    - The answer (``record["output"]``) is split into overlapping token
      windows by make_answer_windows.
    - Each window is encoded with the question, tool responses and tool
      names (build_window_text), truncated to MAX_LENGTH tokens, and scored
      by ``model``.
    - Every answer token becomes a candidate whose ``hallucination_probability``
      is its total non-O probability.
    - Candidates from all windows are merged into character spans by
      merge_candidates_to_spans.

    Args:
        record: Mapping read via ``.get`` for "output" here (and "query",
            "context", "available_tool_names" in build_window_text).
        model: Token-classification model. This call moves it to ``device``
            and switches it to eval mode.
        tokenizer: Tokenizer accepting ``return_offsets_mapping=True``
            (presumably a "fast" tokenizer — confirm).
        device: Torch device; defaults to "cuda" when available, else "cpu".
        decoding_config: Options for merge_candidates_to_spans. Defaults to a
            copy of DEFAULT_DECODING_CONFIG.

    Returns:
        dict with these keys:

        - "pred_hallucination_type": "clean" when no span survives, otherwise
          the type of the highest-scoring span (ties broken by span length);
        - "pred_is_hallucinated";
        - "predicted_spans";
        - "num_token_candidates";
        - "decoding_config": the config used.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if decoding_config is None:
        # Copy, so a caller mutating the returned "decoding_config" cannot
        # silently change the module-wide defaults for later calls.
        decoding_config = dict(DEFAULT_DECODING_CONFIG)
    model.to(device)
    model.eval()

    answer = record.get("output", "")
    candidates = []
    # Fall back to the local label order when the model config has no mapping.
    id2label = getattr(model.config, "id2label", None) or {i: label for i, label in enumerate(SPAN_LABELS)}
    id2label = {int(k): v for k, v in id2label.items()}

    for window in make_answer_windows(answer, tokenizer):
        text, _, local_answer_start, local_answer_end = build_window_text(record, window)
        enc = tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            padding=False,
            return_offsets_mapping=True,
            return_tensors="pt",
            add_special_tokens=True,
        )
        # Offsets are not a model input; take them out before the forward pass.
        offsets = enc.pop("offset_mapping")[0].tolist()
        inputs = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**inputs).logits[0]
        probs = torch.softmax(logits.float(), dim=-1).detach().cpu()

        window_offset = window["answer_window_start"]
        for token_idx, (start, end) in enumerate(offsets):
            start = int(start)
            end = int(end)
            if end <= start:
                continue  # special / zero-width tokens
            # Keep only tokens overlapping the answer-window slice of the text.
            if not (end > local_answer_start and start < local_answer_end):
                continue
            # Clip to the answer slice, then shift to offsets in the full answer.
            local_start = max(0, start - local_answer_start)
            local_end = min(local_answer_end - local_answer_start, end - local_answer_start)
            if local_end <= local_start:
                continue
            global_start = window_offset + local_start
            global_end = window_offset + local_end

            token_probs = probs[token_idx]
            # Index 0 is the O class; everything else counts as hallucination.
            non_o_probs = token_probs[1:]
            hallucination_probability = float(non_o_probs.sum().item())
            best_non_o_id = int(torch.argmax(non_o_probs).item()) + 1
            candidates.append({
                "start": int(global_start),
                "end": int(global_end),
                "text": answer[global_start:global_end],
                "hallucination_probability": hallucination_probability,
                "pred_label": id2label.get(best_non_o_id, str(best_non_o_id)),
                "type_probabilities": {
                    id2label.get(i, str(i)): float(token_probs[i].item())
                    for i in range(1, token_probs.shape[-1])
                },
                "o_probability": float(token_probs[0].item()),
                "window_index": window["window_index"],
            })

    spans = merge_candidates_to_spans(answer, candidates, decoding_config)
    if spans:
        best_span = max(spans, key=lambda x: (x.get("score", 0.0), x["end"] - x["start"]))
        row_label = best_span.get("hallucination_type", "overgeneration")
    else:
        row_label = "clean"
    return {
        "pred_hallucination_type": row_label,
        "pred_is_hallucinated": bool(spans),
        "predicted_spans": spans,
        "num_token_candidates": len(candidates),
        "decoding_config": decoding_config,
    }