File size: 8,868 Bytes
d659bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
tool_call_validator_zh - Inference Reference Implementation

提供給 HF Hub 使用者的最小可用推論程式碼。
包含 Filter 1 (Schema) + Filter 2 (Provenance) 雙層保險。

使用範例(quickstart):

  from inference import Detector
  detector = Detector("Qwen/Qwen2.5-3B-Instruct", "GOSHUNCLE/tool_call_validator_zh")
  result = detector.detect(
      user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。",
      tools=[
          {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
          {"name": "calendar_view", "description": "查看使用者的行事曆"},
      ],
  )
  print(result)
"""
from __future__ import annotations

import json
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


SYSTEM_PROMPT = """你是工具選擇守門員(Tool Selection Guardrail)。
你的職責是分析使用者請求,從候選工具清單中選出最適合的工具,或在無合適工具時拒絕匹配。

任務:
1. 閱讀使用者的請求(user_prompt)與候選工具清單(tools,含 name 與 description)。
2. 判斷哪一個 tool 最符合使用者意圖,或所有候選皆不適用。
3. 輸出嚴格 JSON 結果。

輸出格式:
{
  "reasoning": {
    "intent_summary": "<30-60字:辨識使用者意圖>",
    "key_signals": "<20-40字:抓出使用者請求中的關鍵詞與語意訊號>",
    "conclusion": "<30-60字:說明為什麼選 X 或為什麼拒絕匹配>"
  },
  "selected_tool": "<候選工具名稱,或在拒絕匹配時為 null>",
  "signal": "commit" 或 "abstain",
  "confidence": "high" 、 "medium" 或 "low"
}

判斷原則:
1. selected_tool 必須是候選清單中的 tool name 之一(commit 時)或 null(abstain 時)。
2. signal = "commit":候選中至少有 1 個明確相關工具,能涵蓋使用者意圖。
3. signal = "abstain":候選清單中沒有任何工具能涵蓋使用者核心意圖;即使部分功能沾邊也應拒答。
4. confidence 等級:
   - high:候選中僅 1 個明確相關(或全部明確不相關),無語意混淆。
   - medium:候選中有 1~2 個邊緣相關(混淆 pair),需轉一個彎才能對應。
   - low:多個候選都可能適用,理由勉強選 1 個(或極邊緣拒答)。

規則:
1. selected_tool 必須逐字符合候選清單中的 name(含大小寫與底線)。
2. 不要為了避免 abstain 而強選不適用的工具——abstain 是有效輸出。
3. reasoning 用繁體中文,不直接抄 tool description 全文,要重組為意圖陳述。
4. 只回傳 JSON,無其他說明文字。
"""


VALID_SIGNAL = {"commit", "abstain"}
VALID_CONFIDENCE = {"high", "medium", "low"}
REQUIRED_REASONING = {"intent_summary", "key_signals", "conclusion"}


class Detector:
    """LoRA + Filter 1 + Filter 2 完整推論器"""

    def __init__(
        self,
        base_model: str = "Qwen/Qwen2.5-3B-Instruct",
        adapter: Optional[str] = "GOSHUNCLE/tool_call_validator_zh",
        max_new_tokens: int = 384,
        device: Optional[str] = None,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        kwargs = {"torch_dtype": dtype, "trust_remote_code": True}
        if torch.cuda.is_available():
            kwargs["device_map"] = "auto"
        else:
            kwargs["low_cpu_mem_usage"] = True

        self.model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
        if adapter:
            self.model = PeftModel.from_pretrained(self.model, adapter)
        self.model.eval()
        self.max_new_tokens = max_new_tokens

    @staticmethod
    def _format_user_message(user_prompt: str, tools: list) -> str:
        tools_block = "\n".join(
            f"{i+1}. {t['name']}: {t['description']}" for i, t in enumerate(tools)
        )
        return f"使用者請求:\n{user_prompt}\n\n候選工具:\n{tools_block}"

    @torch.inference_mode()
    def generate_raw(self, user_prompt: str, tools: list) -> str:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": self._format_user_message(user_prompt, tools)},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        gen = outputs[0][inputs.input_ids.shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

    # ------------------------------------------------------------------
    # Filter 1: Schema validation
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_json_lenient(text: str) -> Optional[dict]:
        text = text.strip()
        start = text.find("{")
        if start < 0:
            return None
        depth = 0
        for i in range(start, len(text)):
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(text[start:i+1])
                    except json.JSONDecodeError:
                        return None
        return None

    @staticmethod
    def _filter_schema(parsed: Optional[dict]) -> tuple[dict, bool]:
        fallback = {
            "reasoning": {
                "intent_summary": "[Filter fallback]",
                "key_signals": "[Filter fallback]",
                "conclusion": "[Filter fallback] 輸出格式錯誤,安全拒答。",
            },
            "selected_tool": None,
            "signal": "abstain",
            "confidence": "low",
        }
        if not isinstance(parsed, dict):
            return fallback, False
        if not all(k in parsed for k in ("reasoning", "selected_tool", "signal", "confidence")):
            return fallback, False
        if parsed["signal"] not in VALID_SIGNAL:
            return fallback, False
        if parsed["confidence"] not in VALID_CONFIDENCE:
            return fallback, False
        if not isinstance(parsed["reasoning"], dict):
            return fallback, False
        if not REQUIRED_REASONING.issubset(parsed["reasoning"].keys()):
            return fallback, False
        if parsed["signal"] == "commit" and parsed["selected_tool"] is None:
            return fallback, False
        if parsed["signal"] == "abstain":
            parsed["selected_tool"] = None
        return parsed, True

    # ------------------------------------------------------------------
    # Filter 2: Provenance check
    # ------------------------------------------------------------------
    @staticmethod
    def _filter_provenance(parsed: dict, tools: list) -> dict:
        if parsed["signal"] != "commit":
            return parsed
        names = {t["name"] for t in tools}
        if parsed.get("selected_tool") not in names:
            parsed = dict(parsed)
            parsed["signal"] = "abstain"
            parsed["selected_tool"] = None
            parsed["confidence"] = "low"
            parsed["reasoning"] = dict(parsed["reasoning"])
            parsed["reasoning"]["conclusion"] = (
                "[Filter fallback] 模型輸出的 selected_tool 不在候選清單中,安全拒答。"
            )
        return parsed

    def detect(self, user_prompt: str, tools: list, apply_filters: bool = True) -> dict:
        raw = self.generate_raw(user_prompt, tools)
        parsed = self._parse_json_lenient(raw)
        if not apply_filters:
            return parsed if parsed else {"_unparseable": True, "_raw": raw}
        parsed, _ = self._filter_schema(parsed)
        parsed = self._filter_provenance(parsed, tools)
        return parsed


if __name__ == "__main__":
    # Quick demo
    detector = Detector()
    result = detector.detect(
        user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。",
        tools=[
            {"name": "web_search",    "description": "透過搜尋引擎即時取得網路上最新資訊"},
            {"name": "calendar_view", "description": "查看使用者的行事曆"},
            {"name": "calculator",    "description": "進行數值與數學運算"},
        ],
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))