agentsight-api / src /data /preprocessor.py
Minato Namikaze
Deploy to Hugging Face Spaces
2aed081
Raw
History Blame Contribute Delete
4.37 kB
import json
from transformers import AutoTokenizer
import torch
class StepPreprocessor:
    """Turn agent-trajectory steps into tokenized, labelled model inputs.

    Every step is flattened into one string of the form
    ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...`` and handed to a
    Hugging Face tokenizer.
    """

    def __init__(self, model_name="microsoft/deberta-v3-base", max_len=512):
        """Load the tokenizer for *model_name*.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            max_len: ``max_length`` used when padding / truncating each step.
        """
        # use_fast=False is required for DeBERTa-v3 to avoid spm.model parsing errors on Linux
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.max_len = max_len

    @staticmethod
    def _flatten_query(query):
        """Collapse a query (string, list, or nested list) into one string."""
        if not isinstance(query, list):
            return str(query) if query else ""
        parts = []
        for item in query:
            if isinstance(item, list):
                if not item:
                    # An empty nested list has no first element to take.
                    continue
                # str() keeps " ".join from failing on non-string entries.
                parts.append(str(item[0]))
            else:
                parts.append(str(item))
        return " ".join(parts)

    @staticmethod
    def _as_int(value):
        """Return *value* as an int if it is an integer-like string, else unchanged."""
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return value
        return value

    def encode_step(self, query, step):
        """Encode a single step as one concatenated string.

        The text layout is
        ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...`` — the
        "single-channel concatenated" representation where the model learns
        to attend across all information jointly.

        Args:
            query: The task / question; a string, a list of strings, or a
                list of lists (only the first element of each nested list is
                used, and empty nested lists are skipped).
            step: Mapping with optional ``content``, ``tool_calls`` and
                ``tool_responses`` entries.

        Returns:
            Whatever ``self.tokenizer`` returns for the concatenated text.
        """
        # Channel 1 — Reasoning trace (the LLM's textual thought / plan)
        thought = step.get("content", "") or ""

        # Channel 2 — Tool context (action + observation)
        tool_calls = step.get("tool_calls") or []
        tool_responses = step.get("tool_responses") or []

        # Each call is rendered as "<name> <json-args> | ". The trailing
        # separator (minus its final space) is deliberately kept so the text
        # fed to the model stays identical to the established format.
        action_parts = []
        for call in tool_calls:
            name = call.get("name", "")
            args = json.dumps(call.get("arguments", {}))
            action_parts.append(f"{name} {args} | ")
        actions_str = "".join(action_parts)

        observations_str = " | ".join(str(resp) for resp in tool_responses)
        tool_context = f"[ACTION] {actions_str.strip()} [OBS] {observations_str.strip()}"

        # Channel 3 — Query context (the original question / task)
        flat_query = self._flatten_query(query)

        # Tool context is the dominant signal (ablation: -28.8 pt without it).
        # NOTE(review): no THOUGHT-specific truncation is implemented here;
        # the tokenizer is simply asked to truncate to max_len. With the usual
        # right-side truncation that drops the tail (the tool context) first
        # on over-long inputs — confirm this is the intended behaviour.
        full_text = f"[QUERY] {flat_query} [THOUGHT] {thought} {tool_context}"
        return self.tokenizer(
            full_text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def encode_trajectory(self, sample):
        """Encode every step of one sample and attach per-step labels.

        Args:
            sample: Mapping holding the steps under ``trajectory`` (or
                ``history``), the task under ``question`` (or ``query``), and
                the annotations ``is_hallucination`` / ``hallucination_step``.

        Returns:
            A list of dicts, one per step, each containing:
              - encoding: tokenized input as returned by ``encode_step``
              - label:    1 if this step is the hallucination step, else 0
              - step_idx: the raw ``step`` value from the JSON

        Raises:
            ValueError, TypeError: if the sample is marked as hallucinated and
                its ``hallucination_step`` cannot be converted to ``int``.
        """
        # Support both 'trajectory' and 'history' JSON keys
        trajectory = sample.get("trajectory") or sample.get("history") or []

        is_hal = sample.get("is_hallucination", False)
        if isinstance(is_hal, str):
            is_hal = is_hal.strip().lower() == "true"

        # The annotated step only matters for hallucinated samples, so a junk
        # value on a clean sample must not abort preprocessing.
        hal_step = None
        if is_hal:
            raw_hal_step = sample.get("hallucination_step")
            if raw_hal_step is not None:
                hal_step = int(raw_hal_step)

        # Fall back to 'query' when 'question' is missing, None or empty.
        query = sample.get("question") or sample.get("query") or ""

        steps = []
        for step in trajectory:
            step_idx = step.get("step")
            # A step is labelled positive only if the trajectory is
            # hallucinated AND this is exactly the annotated root-cause step.
            # Requiring hal_step to be set prevents None == None from marking
            # un-indexed steps as positive; _as_int lets "2" match 2.
            is_root_cause = (
                bool(is_hal)
                and hal_step is not None
                and self._as_int(step_idx) == hal_step
            )
            steps.append({
                "encoding": self.encode_step(query, step),
                "label": 1 if is_root_cause else 0,
                "step_idx": step_idx,
            })
        return steps
# ── quick smoke-test ─────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: build the preprocessor, encode the first training sample
    # (if the split file exists) and report what came out.
    import os

    print("Testing StepPreprocessor (will download tokenizer if not cached)…")
    pre = StepPreprocessor()

    # The split file is looked up relative to this file: ../../data/splits/train.json
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sample_file = os.path.join(script_dir, "..", "..", "data", "splits", "train.json")

    if not os.path.exists(sample_file):
        print(f"File not found: {sample_file}")
    else:
        with open(sample_file, "r", encoding="utf-8") as fh:
            samples = json.load(fh)
        if not samples:
            # An empty split would otherwise raise IndexError on samples[0].
            print(f"No samples in: {sample_file}")
        else:
            encoded = pre.encode_trajectory(samples[0])
            print(f"OK — encoded {len(encoded)} steps.")
            # A sample with no steps yields an empty list; skip the shape
            # report instead of failing on encoded[0].
            if encoded:
                print(f"input_ids shape: {encoded[0]['encoding']['input_ids'].shape}")