Spaces:
Running
Running
| import json | |
| from transformers import AutoTokenizer | |
| import torch | |
class StepPreprocessor:
    """Turn agent-trajectory steps into tokenized inputs with per-step labels.

    Every step is flattened into a single string of the form
    ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...`` and handed to the
    tokenizer, which pads/truncates it to ``max_len`` tokens.
    """

    def __init__(self, model_name="microsoft/deberta-v3-base", max_len=512):
        """Load the tokenizer and remember the sequence-length cap.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            max_len: Value passed as ``max_length`` on every tokenizer call.
        """
        # use_fast=False is required for DeBERTa-v3 to avoid spm.model parsing errors on Linux
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.max_len = max_len

    @staticmethod
    def _flatten_query(query):
        """Collapse *query* (a string, or a list of strings / lists) to one string."""
        if not isinstance(query, list):
            return str(query) if query else ""
        parts = []
        for item in query:
            if isinstance(item, list):
                # Only the first element of a nested list is used; an empty
                # nested list contributes nothing instead of raising IndexError.
                if item:
                    parts.append(str(item[0]))
            else:
                parts.append(str(item))
        return " ".join(parts)

    @staticmethod
    def _format_tool_context(step):
        """Render the step's tool calls and tool responses as ``[ACTION] ... [OBS] ...``."""
        tool_calls = step.get("tool_calls") or []
        tool_responses = step.get("tool_responses") or []

        rendered_calls = []
        for call in tool_calls:
            if isinstance(call, dict):
                name = call.get("name", "")
                args = json.dumps(call.get("arguments", {}))
                rendered_calls.append(f"{name} {args} | ")
            else:
                # Malformed entry: keep whatever text it carries rather than crash.
                rendered_calls.append(f"{call} | ")
        # Only surrounding whitespace is stripped, so the last call keeps its
        # trailing "|". This is deliberate: it keeps the rendered text stable.
        actions = "".join(rendered_calls).strip()
        observations = " | ".join(str(resp) for resp in tool_responses).strip()
        return f"[ACTION] {actions} [OBS] {observations}"

    @staticmethod
    def _normalize_step_index(value):
        """Return *value* as an int when it is a numeric string, else unchanged."""
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return value
        return value

    def encode_step(self, query, step):
        """Encode a single step as one concatenated string.

        The text layout is ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...``:
        the "single-channel concatenated" representation where the model
        learns to attend across all information jointly.

        Args:
            query: The original question/task; a string or a (possibly nested)
                list, flattened with single spaces.
            step: Dict for one step. Reads ``content`` (thought),
                ``tool_calls`` (dicts with ``name`` / ``arguments``) and
                ``tool_responses``; missing or ``None`` values count as empty.

        Returns:
            Whatever ``self.tokenizer`` returns for the text, called with
            ``max_length=self.max_len``, ``padding="max_length"``,
            ``truncation=True`` and ``return_tensors="pt"``.
        """
        # Channel 1 — Reasoning trace (the LLM's textual thought / plan)
        thought = step.get("content", "") or ""
        # Channel 2 — Tool context (action + observation)
        tool_context = self._format_tool_context(step)
        # Channel 3 — Query context (the original question / task)
        flat_query = self._flatten_query(query)

        # Tool context is reported as the dominant signal (ablation: -28.8 pt
        # without it).
        # NOTE(review): no middle-truncation of THOUGHT is performed here.
        # truncation=True leaves truncation to the tokenizer; with the usual
        # right-side default that cuts the END of the text, i.e. the tool
        # context, first. TODO confirm whether tool context must be preserved.
        full_text = f"[QUERY] {flat_query} [THOUGHT] {thought} {tool_context}"
        return self.tokenizer(
            full_text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def encode_trajectory(self, sample):
        """Encode every step of one sample and attach its label.

        Args:
            sample: Dict with the steps under ``trajectory`` or ``history``,
                the task under ``question`` or ``query``, plus
                ``is_hallucination`` (bool or the string "true"/"false") and
                ``hallucination_step``.

        Returns:
            A list of dicts, one per step, each containing:
              - encoding: tokenized input as returned by ``encode_step``
              - label: 1 if this step is the hallucination step, else 0
              - step_idx: the step's ``step`` value exactly as found in the JSON

        Raises:
            ValueError: If ``hallucination_step`` is set but not convertible
                to ``int``.
        """
        # Support both 'trajectory' and 'history' JSON keys
        trajectory = sample.get("trajectory") or sample.get("history") or []

        is_hal = sample.get("is_hallucination", False)
        if isinstance(is_hal, str):
            is_hal = is_hal.strip().lower() == "true"

        hal_step = sample.get("hallucination_step")
        if hal_step is not None:
            hal_step = int(hal_step)

        # Same fallback rule as for the trajectory: an empty/None 'question'
        # falls through to 'query'.
        query = sample.get("question") or sample.get("query") or ""

        # Without an annotated step nothing can be positive. This guard stops
        # a missing 'step' key (None) from matching a missing annotation (None).
        has_target = bool(is_hal) and hal_step is not None

        steps = []
        for step in trajectory:
            step_idx = step.get("step")
            # A step is labelled positive only if the trajectory is hallucinated
            # AND this is exactly the annotated root-cause step. The index is
            # normalised so that "2" in the JSON still matches step 2.
            is_root_cause = has_target and self._normalize_step_index(step_idx) == hal_step
            steps.append({
                "encoding": self.encode_step(query, step),
                "label": 1 if is_root_cause else 0,
                "step_idx": step_idx,
            })
        return steps
# ── quick smoke-test ─────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: build the preprocessor, encode the first sample of the
    # training split (if the file is present) and print the resulting shape.
    import os

    print("Testing StepPreprocessor (will download tokenizer if not cached)…")
    pre = StepPreprocessor()

    # The split file is looked up relative to this script, two levels up.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sample_file = os.path.join(script_dir, "..", "..", "data", "splits", "train.json")

    if os.path.exists(sample_file):
        with open(sample_file, "r", encoding="utf-8") as fh:
            samples = json.load(fh)
        if not samples:
            # An empty split would otherwise fail with IndexError on samples[0].
            print(f"No samples found in: {sample_file}")
        else:
            encoded = pre.encode_trajectory(samples[0])
            print(f"OK — encoded {len(encoded)} steps.")
            # A sample without steps yields an empty list; nothing to inspect then.
            if encoded:
                print(f"input_ids shape: {encoded[0]['encoding']['input_ids'].shape}")
    else:
        print(f"File not found: {sample_file}")