| """ |
| Downloads ToolBench conversation data, constructs proxy anomaly labels, |
| and saves processed splits to data/processed/. |
| |
| IMPORTANT: ToolBench stores conversations as a dict of two parallel lists: |
| {"from": ["system", "user", "gpt", ...], "value": ["...", "...", "...", ...]} |
| NOT as a list of dicts. This script handles that format. |
| |
| Proxy labeling strategy: |
| - Look at the LAST assistant message in each conversation. |
- If it contains failure indicators → label = 1 (anomalous).
- Zero tool calls (Action: ...) in the trace → label = 1 (anomalous).
- Otherwise → label = 0 (normal).
| |
| Source: https://huggingface.co/datasets/tuandunghcmut/toolbench-v1 |
| Paper: Qin et al., "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs", ICLR 2024. |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import pandas as pd |
| from datasets import load_dataset |
| from sklearn.model_selection import train_test_split |
|
|
|
|
| |
# Case-insensitive failure indicators. Matched (via FAILURE_RE) against the
# LAST assistant turn only — see label_trace(). Any hit marks the trace as
# anomalous under the proxy-labeling scheme described in the module docstring.
FAILURE_PATTERNS = [
    r"i cannot",
    r"i can't",
    r"i'm sorry",
    r"failed",
    r"unable to",
    r"error occurred",
    r"apologize",
    r"unfortunately",
    r"not possible",
    r"no results",
    r"couldn't find",
    r"don't have access",
    r"not available",
    r"give up",
    r"i will stop",
    r"give_up_and_restart",
]

# All patterns compiled into a single alternation, once, at import time.
FAILURE_RE = re.compile("|".join(FAILURE_PATTERNS), re.IGNORECASE)

# ReAct-style tool invocation line, e.g. "Action: get_weather_for_city".
# NOTE(review): the pattern is unanchored, so text like "interaction: x"
# would also match — confirm against real traces before tightening.
ACTION_RE = re.compile(r"Action\s*:\s*(.+)", re.IGNORECASE)
|
|
|
def normalize_conversations(conv) -> list[dict]:
    """
    Convert ToolBench conversation data into a flat list of message dicts.

    ToolBench stores conversations as two parallel lists:
        {"from": ["system", "user", "gpt", ...], "value": ["...", "...", ...]}

    The normalized output is the conventional flat form:
        [{"from": "system", "value": "..."}, {"from": "user", "value": "..."}, ...]

    Unrecognized input shapes yield an empty list.
    """
    # Case 1: already a list of messages (possibly with stray strings mixed in).
    if isinstance(conv, list):
        messages = []
        for entry in conv:
            if isinstance(entry, dict) and "from" in entry and "value" in entry:
                # Keep well-formed message dicts; dicts whose role field is
                # not a string are silently dropped.
                if isinstance(entry["from"], str):
                    messages.append(entry)
            elif isinstance(entry, str):
                # A stray string may itself be a JSON-encoded message dict.
                try:
                    decoded = json.loads(entry)
                except (json.JSONDecodeError, TypeError):
                    decoded = None
                if isinstance(decoded, dict):
                    messages.append(decoded)
                else:
                    messages.append({"from": "unknown", "value": entry})
            else:
                messages.append({"from": "unknown", "value": str(entry)})
        return messages

    # Case 2: the parallel-lists dict format (the common ToolBench layout).
    if isinstance(conv, dict):
        roles = conv.get("from", [])
        texts = conv.get("value", [])
        if isinstance(roles, list) and isinstance(texts, list):
            # zip() truncates to the shorter list if the two are mismatched.
            return [
                {"from": str(role), "value": "" if text is None else str(text)}
                for role, text in zip(roles, texts)
            ]
        # Degenerate single-message dict: {"from": "user", "value": "..."}.
        if isinstance(roles, str):
            return [{"from": roles, "value": str(texts) if texts else ""}]

    # Anything else (None, int, ...) is unusable.
    return []
|
|
|
|
def extract_actions_from_text(text: str) -> list[str]:
    """
    Extract tool-call (Action) names from an assistant turn in ReAct format.

    The terminal 'Finish' action and placeholder 'none' are excluded, since
    they mark end-of-task rather than a real tool invocation.
    """
    return [
        name
        for name in (raw.strip() for raw in ACTION_RE.findall(text))
        if name.lower() not in ("finish", "none", "")
    ]
|
|
|
|
def parse_conversation(conv) -> dict:
    """
    Parse a ToolBench conversation into a structured trace dict.

    Returns a dict with keys: turns, tool_calls, observations,
    assistant_turns, system_prompt, user_query.
    """
    trace = {
        "turns": [],
        "tool_calls": [],
        "observations": [],
        "assistant_turns": [],
        "system_prompt": "",
        "user_query": "",
    }

    for msg in normalize_conversations(conv):
        role = msg.get("from", "unknown")
        content = msg.get("value", "")
        if content is None:
            content = ""

        trace["turns"].append((role, content))

        if role == "system":
            # Later system turns overwrite earlier ones.
            trace["system_prompt"] = content
        elif role in ("human", "user") and not trace["user_query"]:
            # Only the first user turn is treated as the task query.
            trace["user_query"] = content
        elif role in ("gpt", "assistant", "chatgpt"):
            trace["assistant_turns"].append(content)
            # Assistant turns embed tool calls in ReAct "Action: ..." lines.
            trace["tool_calls"].extend(extract_actions_from_text(content))
        elif role in ("function_call", "tool_call"):
            trace["tool_calls"].append(content)
        elif role in ("observation", "tool_response", "function"):
            trace["observations"].append(content)

    return trace
|
|
|
|
def label_trace(parsed: dict) -> int:
    """
    Assign a proxy anomaly label to a parsed trace.

    Returns 1 (anomalous) when any failure signal is present:
      - no assistant turns at all,
      - failure phrasing in the final assistant turn,
      - zero extracted tool calls.
    Otherwise returns 0 (normal).
    """
    assistant_turns = parsed["assistant_turns"]
    if not assistant_turns:
        return 1

    # Only the final assistant message is inspected for failure wording.
    # (Lowercasing is redundant with re.IGNORECASE but kept for clarity.)
    if FAILURE_RE.search(assistant_turns[-1].lower()):
        return 1

    # A trace that never invoked a tool is also treated as anomalous.
    return 1 if not parsed["tool_calls"] else 0
|
|
|
|
def extract_raw_trace_text(conv) -> str:
    """Flatten a conversation into a single text string for DL models.

    Each message becomes a "[ROLE] content" segment; segments are
    newline-joined in conversation order.
    """
    segments = []
    for msg in normalize_conversations(conv):
        role = msg.get("from", "unknown")
        content = msg.get("value", "")
        content = "" if content is None else content
        segments.append(f"[{role.upper()}] {content}")
    return "\n".join(segments)
|
|
|
|
def process_dataset(max_samples: int | None = None) -> pd.DataFrame:
    """
    Load ToolBench, parse traces, assign proxy labels, return a DataFrame.

    Parameters
    ----------
    max_samples : int | None
        If given, randomly subsample the raw dataset to this many examples
        (fixed seed for reproducibility).

    Returns
    -------
    pd.DataFrame
        One row per successfully parsed trace, including the proxy ``label``.

    Raises
    ------
    RuntimeError
        If no sample could be parsed. (Previously this path crashed with an
        opaque KeyError when indexing ``df["label"]`` on an empty frame.)
    """
    print("[INFO] Loading ToolBench dataset (default config)...")

    # Some dataset revisions expose a named "default" config, others do not;
    # fall back to the config-less call on any failure.
    try:
        ds = load_dataset("tuandunghcmut/toolbench-v1", "default", split="train")
    except Exception:
        ds = load_dataset("tuandunghcmut/toolbench-v1", split="train")

    print(f"[INFO] Loaded {len(ds)} raw samples.")

    if max_samples and max_samples < len(ds):
        ds = ds.shuffle(seed=42).select(range(max_samples))
        print(f"[INFO] Subsampled to {max_samples} samples.")

    _debug_first_example(ds)

    records = []
    skipped = 0
    for idx, example in enumerate(ds):
        conv = example.get("conversations", {})
        messages = normalize_conversations(conv)

        # Traces that normalize to nothing are unusable.
        if not messages:
            skipped += 1
            continue

        try:
            parsed = parse_conversation(conv)
            label = label_trace(parsed)
            raw_text = extract_raw_trace_text(conv)
        except Exception as e:
            skipped += 1
            # Only report the first few failures to keep logs readable.
            if skipped <= 5:
                print(f"[WARN] Skipping sample {idx}: {e}")
            continue

        records.append({
            "id": example.get("id", str(idx)),
            "user_query": parsed["user_query"][:500],  # cap query length
            "num_turns": len(parsed["turns"]),
            "num_tool_calls": len(parsed["tool_calls"]),
            "num_observations": len(parsed["observations"]),
            "num_assistant_turns": len(parsed["assistant_turns"]),
            "raw_trace": raw_text,
            "conversations_json": json.dumps(messages),
            "label": label,
        })

    if skipped > 0:
        print(f"[WARN] Skipped {skipped} malformed samples.")

    df = pd.DataFrame(records)
    # Fail loudly instead of crashing on df["label"] below when nothing parsed.
    if df.empty:
        raise RuntimeError(
            "No ToolBench samples could be parsed; check the dataset format."
        )

    print(f"[INFO] Processed {len(df)} traces.")
    print(f"[INFO] Label distribution:\n{df['label'].value_counts().to_string()}")
    print(f"[INFO] Anomaly rate: {df['label'].mean():.2%}")
    print(f"[INFO] Tool calls stats:\n{df['num_tool_calls'].describe()}")

    return df


def _debug_first_example(ds) -> None:
    """Print the structure of the first raw example (debugging aid only)."""
    first = ds[0]
    conv_raw = first.get("conversations", {})
    print(f"[DEBUG] conversations type: {type(conv_raw)}")
    if isinstance(conv_raw, dict):
        print(f"[DEBUG] conversations keys: {list(conv_raw.keys())}")
        for k, v in conv_raw.items():
            print(f"[DEBUG] {k}: type={type(v)}, len={len(v) if isinstance(v, list) else 'N/A'}")
            if isinstance(v, list) and len(v) > 0:
                print(f"[DEBUG] {k}[0]: {str(v[0])[:120]}")
    elif isinstance(conv_raw, list):
        print(f"[DEBUG] conversations is list, len={len(conv_raw)}")
        if len(conv_raw) > 0:
            print(f"[DEBUG] [0] type={type(conv_raw[0])}, preview={str(conv_raw[0])[:120]}")
|
|
|
|
def split_and_save(df, output_dir, test_size=0.15, val_size=0.15, seed=42):
    """
    Stratified train/val/test split, saved as parquet files.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a "label" column.
    output_dir : str
        Directory for train/val/test parquet files (created if missing).
    test_size, val_size : float
        Fractions of the FULL dataset for the test and validation splits.
    seed : int
        RNG seed for reproducible splits.

    Falls back to unstratified random splits when only one class is present
    (sklearn's ``stratify`` would raise otherwise).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Stratification needs at least two classes; otherwise split randomly.
    stratifiable = df["label"].nunique() >= 2
    if not stratifiable:
        print("[WARN] Only one class found. Using random splits.")

    train_val, test = train_test_split(
        df,
        test_size=test_size,
        random_state=seed,
        stratify=df["label"] if stratifiable else None,
    )
    # val_size is a fraction of the WHOLE dataset, but the second split only
    # sees the remaining (1 - test_size); rescale so the final share is right.
    relative_val = val_size / (1 - test_size)
    train, val = train_test_split(
        train_val,
        test_size=relative_val,
        random_state=seed,
        stratify=train_val["label"] if stratifiable else None,
    )

    for name, split_df in [("train", train), ("val", val), ("test", test)]:
        path = os.path.join(output_dir, f"{name}.parquet")
        split_df.to_parquet(path, index=False)
        # Fixed mojibake in the log message: the arrow was garbled to "β".
        print(f"[INFO] Saved {name}: {len(split_df)} samples -> {path}")
        print(f"    Label dist: {dict(split_df['label'].value_counts())}")
|
|
|
|
def main():
    """CLI entry point: build the labeled dataset, split it, and save to disk."""
    parser = argparse.ArgumentParser(description="Build ToolBench anomaly detection dataset")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--test_size", type=float, default=0.15)
    parser.add_argument("--val_size", type=float, default=0.15)
    parser.add_argument("--output_dir", type=str, default="data/processed")
    args = parser.parse_args()

    frame = process_dataset(max_samples=args.max_samples)
    split_and_save(frame, args.output_dir, args.test_size, args.val_size)
    print("[DONE] Dataset ready.")
|
|
|
|
| if __name__ == "__main__": |
| main() |