File size: 4,368 Bytes
2aed081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
from transformers import AutoTokenizer
import torch


class StepPreprocessor:
    """Turn agent-trajectory steps into tokenized, labelled model inputs.

    Every step is rendered as a single string of the form
    ``[QUERY] ... [THOUGHT] ... [ACTION] ... [OBS] ...`` and passed to the
    tokenizer loaded in ``__init__``.
    """

    def __init__(self, model_name: str = "microsoft/deberta-v3-base", max_len: int = 512) -> None:
        """Load the tokenizer and remember the target sequence length.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained``.
            max_len: ``max_length`` used when padding/truncating each step.
        """
        # use_fast=False is required for DeBERTa-v3 to avoid spm.model parsing errors on Linux
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.max_len = max_len

    @staticmethod
    def _flatten_query(query) -> str:
        """Return *query* as one string; list items are joined with spaces."""
        if not isinstance(query, list):
            return str(query) if query else ""

        parts = []
        for item in query:
            if isinstance(item, list):
                # A nested entry contributes only its first element. An empty
                # nested list contributes nothing instead of raising IndexError.
                if item:
                    parts.append(str(item[0]))
            else:
                parts.append(str(item))
        return " ".join(parts)

    @staticmethod
    def _as_int(value):
        """Return ``int(value)``, or ``None`` when *value* is not integer-like."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    def encode_step(self, query, step):
        """Encode a single step as one concatenated string.

        The rendered text is
        ``[QUERY] <query> [THOUGHT] <content> [ACTION] <calls> [OBS] <responses>``
        (the "single-channel concatenated" representation, so the model can
        attend across all parts jointly).

        Args:
            query: The task/question. A list is flattened with spaces; a
                nested list contributes its first element only.
            step: Mapping read via the keys ``"content"``, ``"tool_calls"``
                (each entry read via ``"name"`` and ``"arguments"``) and
                ``"tool_responses"``. Missing or null values become empty.

        Returns:
            Whatever ``self.tokenizer`` returns when called with
            ``padding="max_length"``, ``truncation=True``,
            ``max_length=self.max_len`` and ``return_tensors="pt"``.
        """
        # Channel 1 — Reasoning trace (the LLM's textual thought / plan)
        thought = step.get("content", "") or ""

        # Channel 2 — Tool context (action + observation)
        tool_calls = step.get("tool_calls") or []
        tool_responses = step.get("tool_responses") or []

        # Each call is rendered as "<name> <json args> | "; after strip() the
        # last call keeps a trailing "|". The format is kept as-is so the text
        # fed to the model does not change.
        actions_str = "".join(
            f"{tc.get('name', '')} {json.dumps(tc.get('arguments', {}))} | "
            for tc in tool_calls
        )
        observations_str = " | ".join(str(tr) for tr in tool_responses)

        tool_context = f"[ACTION] {actions_str.strip()} [OBS] {observations_str.strip()}"

        # Channel 3 — Query context (the original question / task)
        flat_query = self._flatten_query(query)

        full_text = f"[QUERY] {flat_query} [THOUGHT] {thought} {tool_context}"

        # NOTE(review): the original comment here described a strategy of
        # keeping the query short, preserving tool context (cited as the
        # dominant signal, "ablation: -28.8 pt without it") and truncating
        # THOUGHT in the middle. That strategy is NOT implemented: the text is
        # tokenized as one sequence with truncation=True, so which end is cut
        # depends on the tokenizer's truncation side (presumably the right,
        # i.e. the [ACTION]/[OBS] tail, unless configured otherwise — confirm).
        # TODO: budget tokens per channel if tool context must survive.
        encoded = self.tokenizer(
            full_text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return encoded

    def encode_trajectory(self, sample):
        """Encode every step of *sample* and attach a binary label.

        Args:
            sample: Mapping read via ``"trajectory"`` (or ``"history"``),
                ``"is_hallucination"`` (bool or the string ``"true"``, any
                case), ``"hallucination_step"`` and ``"question"`` (or
                ``"query"``).

        Returns:
            A list of dicts, one per step, each containing:
              - encoding: the tokenizer output from ``encode_step``
              - label:    1 if this step is the hallucination step, else 0
              - step_idx: the step's ``"step"`` value, unmodified

        Raises:
            ValueError, TypeError: if the sample is flagged as hallucinated
                and its ``"hallucination_step"`` cannot be converted by
                ``int()``.
        """
        # Support both 'trajectory' and 'history' JSON keys
        trajectory = sample.get("trajectory") or sample.get("history") or []

        is_hal = sample.get("is_hallucination", False)
        if isinstance(is_hal, str):
            is_hal = is_hal.strip().lower() == "true"

        # The annotated step only matters for hallucinated samples, so it is
        # parsed only then; placeholder values on clean samples cannot crash.
        hal_step = None
        if is_hal:
            raw_hal_step = sample.get("hallucination_step")
            if raw_hal_step is not None:
                hal_step = int(raw_hal_step)

        # Fall back to 'query' when 'question' is missing, null or empty,
        # mirroring the trajectory/history fallback above.
        query = sample.get("question") or sample.get("query") or ""

        steps = []
        for step in trajectory:
            step_idx = step.get("step")

            # A step is labelled positive only if the trajectory is hallucinated
            # AND this is exactly the annotated root-cause step. Both sides are
            # compared as ints so "2" matches 2, and a missing annotation never
            # matches a missing step index.
            is_root_cause = hal_step is not None and self._as_int(step_idx) == hal_step

            steps.append({
                "encoding": self.encode_step(query, step),
                "label": 1 if is_root_cause else 0,
                "step_idx": step_idx,
            })

        return steps


# ── quick smoke-test ─────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: load the tokenizer, encode the first training sample found
    # relative to this script, and print the resulting tensor shape.
    import os

    print("Testing StepPreprocessor (will download tokenizer if not cached)…")
    pre = StepPreprocessor()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    sample_file = os.path.join(script_dir, "..", "..", "data", "splits", "train.json")

    if not os.path.exists(sample_file):
        print(f"File not found: {sample_file}")
    else:
        with open(sample_file, "r", encoding="utf-8") as fh:
            samples = json.load(fh)

        if not samples:
            # An empty split would otherwise fail on samples[0].
            print(f"No samples found in: {sample_file}")
        else:
            sample = samples[0]
            encoded = pre.encode_trajectory(sample)
            print(f"OK — encoded {len(encoded)} steps.")
            if encoded:
                print(f"input_ids shape: {encoded[0]['encoding']['input_ids'].shape}")
            else:
                # A sample without steps has no encoding to inspect.
                print("First sample has no steps; nothing to show.")