# Scraped page header: "Spaces: Running", file size 4,368 bytes, commit 2aed081.
import json
from transformers import AutoTokenizer
import torch
class StepPreprocessor:
    """Turn agent-trajectory steps into fixed-length token encodings.

    Each step is rendered as one concatenated string
    ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...`` and encoded to
    exactly ``max_len`` tokens (padded, or truncated with a priority scheme
    that protects the tool context).
    """

    def __init__(self, model_name="microsoft/deberta-v3-base", max_len=512, max_query_len=128):
        """Load the tokenizer.

        Args:
            model_name: model id or local path passed to
                ``AutoTokenizer.from_pretrained``.
            max_len: total length (including special tokens) of every encoding
                produced by ``encode_step``.
            max_query_len: number of query tokens guaranteed to survive when a
                step must be truncated; the query may keep more tokens only if
                the tool context and thought leave room.
        """
        # use_fast=False is required for DeBERTa-v3 to avoid spm.model parsing errors on Linux
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.max_len = max_len
        self.max_query_len = max_query_len

    # ── text rendering helpers ────────────────────────────────────────────────

    @staticmethod
    def _flatten_query(query):
        """Render the task query as one string.

        A list query is space-joined; elements that are themselves lists
        contribute only their first item, and empty inner lists are skipped.
        """
        if isinstance(query, list):
            parts = []
            for q in query:
                if isinstance(q, list):
                    if q:
                        parts.append(str(q[0]))
                else:
                    parts.append(str(q))
            return " ".join(parts)
        return str(query) if query else ""

    @staticmethod
    def _format_tool_context(step):
        """Render a step's tool calls and responses as ``[ACTION] ... [OBS] ...``."""
        tool_calls = step.get("tool_calls") or []
        tool_responses = step.get("tool_responses") or []
        action_parts = []
        for tc in tool_calls:
            name = tc.get("name", "")
            args = json.dumps(tc.get("arguments", {}))
            action_parts.append(f"{name} {args} | ")
        # strip() leaves a trailing "|" after the last call; it is kept so the
        # rendered text matches the format existing encodings were built from.
        actions_str = "".join(action_parts)
        observations_str = " | ".join(str(tr) for tr in tool_responses)
        return f"[ACTION] {actions_str.strip()} [OBS] {observations_str.strip()}"

    # ── tokenization helpers ──────────────────────────────────────────────────

    def _token_ids(self, text):
        """Tokenize ``text`` into ids without special tokens or length warnings."""
        return list(self.tokenizer(text, add_special_tokens=False, verbose=False)["input_ids"])

    @staticmethod
    def _truncate_middle(ids, keep):
        """Keep at most ``keep`` ids: the leading and trailing halves, dropping the middle."""
        if keep <= 0:
            return []
        if len(ids) <= keep:
            return list(ids)
        head = (keep + 1) // 2
        tail = keep - head
        # ids[-0:] would return the whole list, so only slice a tail when tail > 0.
        return list(ids[:head]) + (list(ids[-tail:]) if tail else [])

    def _fit_to_budget(self, flat_query, thought, tool_context, budget):
        """Build ids of length <= ``budget`` (when feasible), prioritising tool context.

        Allocation order:
          1. reserve a short query prefix (up to ``max_query_len`` tokens);
          2. give the tool context as much as it needs (dominant signal);
          3. fill the remainder with the thought, cut in the middle;
          4. hand any space still left back to the query.
        """
        query_ids = self._token_ids(f"[QUERY] {flat_query}")
        marker_ids = self._token_ids("[THOUGHT]")
        thought_ids = self._token_ids(thought)
        tool_ids = self._token_ids(tool_context)

        available = max(budget - len(marker_ids), 0)
        query_reserved = min(len(query_ids), self.max_query_len, available)
        tool_keep = min(len(tool_ids), available - query_reserved)
        thought_keep = min(len(thought_ids), available - query_reserved - tool_keep)
        query_keep = min(len(query_ids), available - tool_keep - thought_keep)

        return (
            query_ids[:query_keep]
            + marker_ids
            + self._truncate_middle(thought_ids, thought_keep)
            + tool_ids[:tool_keep]
        )

    # ── public API ────────────────────────────────────────────────────────────

    def encode_step(self, query, step):
        """Encode a single step as ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...``.

        This is the "single-channel concatenated" representation where the
        model learns to attend across all information jointly.

        Truncation strategy: when the concatenated text does not fit in
        ``max_len`` tokens, the tool context (action + observation) is kept
        first because it is the dominant signal (ablation: -28.8 pt without
        it), the query is kept short (at least ``max_query_len`` tokens), and
        the THOUGHT is truncated in the middle (its head and tail survive).
        Plain right-truncation would instead drop the tool context first,
        since it comes last in the text.

        Args:
            query: task query; a string or a list (see ``_flatten_query``).
            step: step dict; reads ``content``, ``tool_calls`` and
                ``tool_responses``.

        Returns:
            The tokenizer's ``BatchEncoding`` with PyTorch tensors of shape
            ``(1, max_len)``, padded to ``max_len``.
        """
        # Channel 1 — Reasoning trace (the LLM's textual thought / plan)
        thought = step.get("content", "") or ""
        # Channel 2 — Tool context (action + observation)
        tool_context = self._format_tool_context(step)
        # Channel 3 — Query context (the original question / task)
        flat_query = self._flatten_query(query)

        full_text = f"[QUERY] {flat_query} [THOUGHT] {thought} {tool_context}"
        budget = self.max_len - self.tokenizer.num_special_tokens_to_add(pair=False)
        ids = self._token_ids(full_text)
        if len(ids) > budget:
            ids = self._fit_to_budget(flat_query, str(thought), tool_context, budget)

        return self.tokenizer.prepare_for_model(
            ids,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,  # safety net for budgets smaller than the fixed markers
            return_tensors="pt",
            prepend_batch_axis=True,  # same (1, max_len) shape as tokenizer(text, ...)
        )

    @staticmethod
    def _to_step_index(value):
        """Return a numeric-string step id as ``int``; any other value unchanged."""
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return value
        return value

    def encode_trajectory(self, sample):
        """
        Returns a list of dicts, one per step, each containing:
          - encoding: tokenized input (see ``encode_step``)
          - label: 1 if this step is the hallucination step, else 0
          - step_idx: the step index from the JSON (numeric strings become ints)

        Raises:
            ValueError: if ``hallucination_step`` is present, non-blank and not
                convertible with ``int()``.
        """
        # Support both 'trajectory' and 'history' JSON keys
        trajectory = sample.get("trajectory") or sample.get("history") or []

        is_hal = sample.get("is_hallucination", False)
        if isinstance(is_hal, str):
            is_hal = is_hal.strip().lower() == "true"

        hal_step = sample.get("hallucination_step")
        if isinstance(hal_step, str) and not hal_step.strip():
            hal_step = None  # a blank annotation means "no annotated step"
        if hal_step is not None:
            hal_step = int(hal_step)

        query = sample.get("question", sample.get("query", ""))
        steps = []
        for step in trajectory:
            # Coerce so that "3" in the step list matches an int hallucination_step of 3.
            step_idx = self._to_step_index(step.get("step"))
            # A step is labelled positive only if the trajectory is hallucinated
            # AND a root-cause step is annotated AND this is exactly that step
            # (otherwise a missing step index would match a missing annotation).
            is_root_cause = is_hal and hal_step is not None and step_idx == hal_step
            steps.append({
                "encoding": self.encode_step(query, step),
                "label": 1 if is_root_cause else 0,
                "step_idx": step_idx,
            })
        return steps
# ── quick smoke-test ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    import os

    print("Testing StepPreprocessor (will download tokenizer if not cached)…")
    pre = StepPreprocessor()
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Expected at <two directories above this file>/data/splits/train.json
    sample_file = os.path.join(script_dir, "..", "..", "data", "splits", "train.json")
    if not os.path.exists(sample_file):
        print(f"File not found: {sample_file}")
    else:
        with open(sample_file, "r", encoding="utf-8") as fh:
            samples = json.load(fh)
        if not samples:
            # Guard: samples[0] would raise IndexError on an empty split.
            print(f"No samples in {sample_file}")
        else:
            encoded = pre.encode_trajectory(samples[0])
            print(f"OK — encoded {len(encoded)} steps.")
            if encoded:
                # Guard: a sample with no steps yields an empty list.
                print(f"input_ids shape: {encoded[0]['encoding']['input_ids'].shape}")