File size: 10,808 Bytes
07660e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
"""
Downloads ToolBench conversation data, constructs proxy anomaly labels,
and saves processed splits to data/processed/.

IMPORTANT: ToolBench stores conversations as a dict of two parallel lists:
    {"from": ["system", "user", "gpt", ...], "value": ["...", "...", "...", ...]}
rather than as a list of dicts. This script handles that format (and also
accepts the list-of-dicts form).

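Proxy labeling strategy (implemented in label_trace):
  - Look at the LAST assistant message in each conversation.
  - No assistant message at all → label = 1 (anomalous).
  - The last assistant message contains a failure indicator (case-insensitive
    match against FAILURE_PATTERNS) → label = 1 (anomalous).
  - Zero tool calls (Action: ...) anywhere in the trace → label = 1 (anomalous).
  - Otherwise → label = 0 (normal).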

Source: https://huggingface.co/datasets/tuandunghcmut/toolbench-v1
Paper:  Qin et al., "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs", ICLR 2024.
"""

import argparse
import json
import os
import re
import sys
from pathlib import Path

import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split


# ── failure indicators in the last assistant turn ──────────────────────────
# Regex fragments OR-ed together into FAILURE_RE. label_trace searches for any
# of them (anywhere in the text, case-insensitively) in the final assistant
# message.
FAILURE_PATTERNS = [
    r"i cannot",
    r"i can't",
    r"i'm sorry",
    r"failed",
    r"unable to",
    r"error occurred",
    r"apologize",
    r"unfortunately",
    r"not possible",
    r"no results",
    r"couldn't find",
    r"don't have access",
    r"not available",
    r"give up",
    r"i will stop",
    r"give_up_and_restart",
]

FAILURE_RE = re.compile("|".join(FAILURE_PATTERNS), re.IGNORECASE)

# Regex to detect "Action: <name>" lines inside gpt turns (ReAct format used by
# ToolBench); group 1 is the rest of the line after the colon.
#  - \b prevents matches inside longer words such as "Transaction: 42" or
#    "Interaction: ...", which would otherwise be counted as tool calls.
#  - [ \t]* (instead of \s*) keeps the match on a single line, so an empty
#    "Action:" cannot capture the following "Action Input: {...}" line.
ACTION_RE = re.compile(r"\bAction[ \t]*:[ \t]*(.+)", re.IGNORECASE)


def normalize_conversations(conv) -> list[dict]:
    """
    Convert a ToolBench conversation into a flat list of message dicts.

    ToolBench stores conversations as a dict of parallel lists:
        {"from": ["system", "user", "gpt", ...], "value": ["...", "...", "...", ...]}

    This function converts to:
        [{"from": "system", "value": "..."}, {"from": "user", "value": "..."}, ...]

    Other accepted shapes:
      - a list whose items are
          * message dicts with a str ``"from"`` (kept as-is),
          * parallel-list dicts (expanded via the dict form above),
          * JSON strings (kept if they decode to a dict, otherwise wrapped),
          * anything else (wrapped as ``{"from": "unknown", "value": str(item)}``);
      - a single message dict ``{"from": "<role>", "value": ...}``.

    Args:
        conv: The raw ``conversations`` field of one dataset example.

    Returns:
        A list of message dicts; ``[]`` when ``conv`` has an unrecognised
        shape (e.g. ``None``). In the parallel-list form, ``None`` values
        become ``""`` and extra entries of the longer list are dropped (zip).
    """
    if isinstance(conv, list):
        # already a list — normalize each element
        result = []
        for item in conv:
            if isinstance(item, dict) and "from" in item and "value" in item:
                if isinstance(item["from"], str):
                    # already a single {"from": str, "value": ...} message
                    result.append(item)
                else:
                    # A parallel-list dict nested inside the list: expand it
                    # rather than dropping its messages. The dict branch below
                    # returns [] for shapes it does not recognise.
                    result.extend(normalize_conversations(item))
            elif isinstance(item, str):
                # the message may be JSON-encoded; keep it only if it decodes
                # to a dict, otherwise wrap the raw string
                try:
                    parsed = json.loads(item)
                except (json.JSONDecodeError, TypeError):
                    parsed = None
                if isinstance(parsed, dict):
                    result.append(parsed)
                else:
                    result.append({"from": "unknown", "value": item})
            else:
                result.append({"from": "unknown", "value": str(item)})
        return result

    if isinstance(conv, dict):
        # ToolBench parallel-list format: {"from": [...], "value": [...]}
        froms = conv.get("from", [])
        values = conv.get("value", [])

        if isinstance(froms, list) and isinstance(values, list):
            return [
                {"from": str(f), "value": str(v) if v is not None else ""}
                for f, v in zip(froms, values)
            ]

        # single message dict
        if isinstance(froms, str):
            return [{"from": froms, "value": str(values) if values else ""}]

    # fallback: unrecognised shape
    return []


def extract_actions_from_text(text: str) -> list[str]:
    """
    Return the tool (Action) names called in one gpt turn written in ReAct format.

    Each ACTION_RE capture is whitespace-stripped. Empty names and the
    pseudo-actions 'Finish' (end-of-task marker) and 'None', compared
    case-insensitively, are discarded. The remaining names keep their order
    of appearance.
    """
    ignored = ("finish", "none", "")
    names = (match.strip() for match in ACTION_RE.findall(text))
    return [name for name in names if name.lower() not in ignored]


def parse_conversation(conv) -> dict:
    """
    Split a ToolBench conversation into the pieces used for labels and features.

    Roles are matched literally:
      - "system"                               -> system_prompt (last one wins)
      - "human" / "user"                       -> user_query (first non-empty one)
      - "gpt" / "assistant" / "chatgpt"        -> assistant_turns, plus every
                                                  "Action: <name>" found in it
                                                  appended to tool_calls
      - "function_call" / "tool_call"          -> tool_calls (raw message text)
      - "observation" / "tool_response" / "function" -> observations

    Every message, whatever its role, is also appended to ``turns`` as a
    ``(role, content)`` tuple. A ``None`` value is treated as ``""``.
    """
    summary = {
        "turns": [],
        "tool_calls": [],
        "observations": [],
        "assistant_turns": [],
        "system_prompt": "",
        "user_query": "",
    }

    for message in normalize_conversations(conv):
        role = message.get("from", "unknown")
        text = message.get("value", "")
        text = "" if text is None else text

        summary["turns"].append((role, text))

        if role == "system":
            summary["system_prompt"] = text
        elif role in ("human", "user"):
            # only the first non-empty user message is kept as the query
            if not summary["user_query"]:
                summary["user_query"] = text
        elif role in ("gpt", "assistant", "chatgpt"):
            summary["assistant_turns"].append(text)
            # ToolBench embeds tool calls inside gpt turns as "Action: api_name"
            summary["tool_calls"].extend(extract_actions_from_text(text))
        elif role in ("function_call", "tool_call"):
            summary["tool_calls"].append(text)
        elif role in ("observation", "tool_response", "function"):
            summary["observations"].append(text)

    return summary


def label_trace(parsed: dict) -> int:
    """
    Assign the proxy anomaly label for one parsed trace.

    Returns 1 (anomalous) in any of these cases:
      - the trace has no assistant turn;
      - the last assistant turn matches FAILURE_RE;
      - no tool call was recorded.
    Otherwise returns 0 (normal). The checks run in that order.
    """
    replies = parsed["assistant_turns"]
    if replies:
        final_reply = replies[-1].lower()
        looks_failed = FAILURE_RE.search(final_reply) is not None
        # tool_calls is only consulted when the final reply looks successful
        if not looks_failed and parsed["tool_calls"]:
            return 0
    return 1


def extract_raw_trace_text(conv) -> str:
    """
    Render a conversation as plain text for the DL models.

    Each message becomes one "[ROLE] content" entry: the role is upper-cased
    (a missing role is shown as UNKNOWN) and ``None`` content is rendered as
    empty. Entries are joined with newlines in conversation order.
    """
    def _render(message: dict) -> str:
        speaker = message.get("from", "unknown")
        body = message.get("value", "")
        return f"[{speaker.upper()}] {'' if body is None else body}"

    return "\n".join(_render(m) for m in normalize_conversations(conv))


def _load_toolbench_train():
    """Load the ToolBench ``train`` split, preferring the explicit "default" config."""
    try:
        return load_dataset("tuandunghcmut/toolbench-v1", "default", split="train")
    except Exception:
        # Retry without naming a config; if this also fails, its error propagates.
        return load_dataset("tuandunghcmut/toolbench-v1", split="train")


def _print_conversation_structure(example) -> None:
    """Print a [DEBUG] summary of the shape of one example's ``conversations`` field."""
    conv_raw = example.get("conversations", {})
    print(f"[DEBUG] conversations type: {type(conv_raw)}")
    if isinstance(conv_raw, dict):
        print(f"[DEBUG] conversations keys: {list(conv_raw.keys())}")
        for k, v in conv_raw.items():
            print(f"[DEBUG]   {k}: type={type(v)}, len={len(v) if isinstance(v, list) else 'N/A'}")
            if isinstance(v, list) and len(v) > 0:
                print(f"[DEBUG]   {k}[0]: {str(v[0])[:120]}")
    elif isinstance(conv_raw, list):
        print(f"[DEBUG] conversations is list, len={len(conv_raw)}")
        if len(conv_raw) > 0:
            print(f"[DEBUG]   [0] type={type(conv_raw[0])}, preview={str(conv_raw[0])[:120]}")


def process_dataset(max_samples: int = None) -> pd.DataFrame:
    """
    Load ToolBench, parse each trace, assign proxy labels, and return a DataFrame.

    Args:
        max_samples: If truthy and smaller than the dataset, the dataset is
            shuffled with seed 42 and only this many examples are kept.

    Returns:
        One row per usable trace with columns id, user_query (first 500
        chars), num_turns, num_tool_calls, num_observations,
        num_assistant_turns, raw_trace, conversations_json (list-of-dicts
        JSON) and label.

    Raises:
        ValueError: if no trace could be processed (empty dataset, or every
            example was empty or raised while parsing). Previously this
            surfaced as an IndexError / KeyError deep in the function.
    """
    print("[INFO] Loading ToolBench dataset (default config)...")
    ds = _load_toolbench_train()
    print(f"[INFO] Loaded {len(ds)} raw samples.")

    if max_samples and max_samples < len(ds):
        ds = ds.shuffle(seed=42).select(range(max_samples))
        print(f"[INFO] Subsampled to {max_samples} samples.")

    # debug: peek at first sample structure (ds[0] would raise on an empty set)
    if len(ds) > 0:
        _print_conversation_structure(ds[0])

    records = []
    skipped = 0
    for idx, example in enumerate(ds):
        conv = example.get("conversations", {})
        messages = normalize_conversations(conv)

        if not messages:
            skipped += 1
            continue

        try:
            parsed = parse_conversation(conv)
            label = label_trace(parsed)
            raw_text = extract_raw_trace_text(conv)
        except Exception as e:
            # best-effort: skip malformed samples, but report the first few
            skipped += 1
            if skipped <= 5:
                print(f"[WARN] Skipping sample {idx}: {e}")
            continue

        records.append({
            "id": example.get("id", str(idx)),
            "user_query": parsed["user_query"][:500],
            "num_turns": len(parsed["turns"]),
            "num_tool_calls": len(parsed["tool_calls"]),
            "num_observations": len(parsed["observations"]),
            "num_assistant_turns": len(parsed["assistant_turns"]),
            "raw_trace": raw_text,
            "conversations_json": json.dumps(messages),  # save as list-of-dicts
            "label": label,
        })

    if skipped > 0:
        print(f"[WARN] Skipped {skipped} malformed samples.")

    if not records:
        # pd.DataFrame([]) has no "label" column; fail with a clear message
        raise ValueError(
            f"No usable ToolBench traces were produced ({skipped} skipped); "
            "cannot build the dataset."
        )

    df = pd.DataFrame(records)
    print(f"[INFO] Processed {len(df)} traces.")
    print(f"[INFO] Label distribution:\n{df['label'].value_counts().to_string()}")
    print(f"[INFO] Anomaly rate: {df['label'].mean():.2%}")
    print(f"[INFO] Tool calls stats:\n{df['num_tool_calls'].describe()}")

    return df


def split_and_save(df, output_dir, test_size=0.15, val_size=0.15, seed=42):
    """
    Split ``df`` into train/val/test and write each split as parquet.

    The test split takes ``test_size`` of all rows. Val is then carved out of
    the remainder so that it is ``val_size`` of the original total.

    Splits are stratified on ``label`` when at least two classes are present.
    If sklearn refuses to stratify (e.g. a class with only one member, which
    is easy to hit with a small ``--max_samples``), that split falls back to
    a plain random split with a warning. Without the fallback, the run would
    abort after the whole dataset had already been processed.

    Args:
        df: DataFrame with a ``label`` column.
        output_dir: Directory for train/val/test.parquet (created if missing).
        test_size: Fraction of all rows used for the test split.
        val_size: Fraction of all rows used for the val split.
        seed: ``random_state`` passed to every split.
    """
    os.makedirs(output_dir, exist_ok=True)

    stratify = df["label"].nunique() >= 2
    if not stratify:
        print("[WARN] Only one class found. Using random splits.")

    def _split(frame, size):
        """Split ``frame``, stratifying on ``label`` when enabled and feasible."""
        if stratify:
            try:
                return train_test_split(
                    frame, test_size=size, random_state=seed, stratify=frame["label"]
                )
            except ValueError as err:
                print(f"[WARN] Stratified split not possible ({err}). Using random split.")
        return train_test_split(frame, test_size=size, random_state=seed)

    train_val, test = _split(df, test_size)
    # rescale so val is val_size of the *full* dataset, not of train_val
    relative_val = val_size / (1 - test_size)
    train, val = _split(train_val, relative_val)

    for name, split_df in [("train", train), ("val", val), ("test", test)]:
        path = os.path.join(output_dir, f"{name}.parquet")
        split_df.to_parquet(path, index=False)
        print(f"[INFO] Saved {name}: {len(split_df)} samples → {path}")
        print(f"       Label dist: {dict(split_df['label'].value_counts())}")


def main():
    """Parse CLI arguments, build the labelled dataset, and write the splits."""
    parser = argparse.ArgumentParser(description="Build ToolBench anomaly detection dataset")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--test_size", type=float, default=0.15)
    parser.add_argument("--val_size", type=float, default=0.15)
    parser.add_argument("--output_dir", type=str, default="data/processed")
    parser.add_argument(
        "--seed", type=int, default=42,
        help="random_state for the train/val/test splits (default: 42)",
    )
    args = parser.parse_args()

    # Fail fast, before the slow download/processing step. split_and_save needs
    # both fractions in (0, 1) and test + val < 1 (otherwise the rescaled val
    # fraction reaches 1 and the second split fails).
    if not (
        0 < args.test_size < 1
        and 0 < args.val_size < 1
        and args.test_size + args.val_size < 1
    ):
        parser.error("--test_size and --val_size must each be in (0, 1) and sum to less than 1")

    df = process_dataset(max_samples=args.max_samples)
    split_and_save(df, args.output_dir, args.test_size, args.val_size, seed=args.seed)
    print("[DONE] Dataset ready.")


if __name__ == "__main__":
    main()