""" tool_call_validator_zh - Inference Reference Implementation 提供給 HF Hub 使用者的最小可用推論程式碼。 包含 Filter 1 (Schema) + Filter 2 (Provenance) 雙層保險。 使用範例(quickstart): from inference import Detector detector = Detector("Qwen/Qwen2.5-3B-Instruct", "GOSHUNCLE/tool_call_validator_zh") result = detector.detect( user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。", tools=[ {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"}, {"name": "calendar_view", "description": "查看使用者的行事曆"}, ], ) print(result) """ from __future__ import annotations import json from typing import Optional import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer SYSTEM_PROMPT = """你是工具選擇守門員(Tool Selection Guardrail)。 你的職責是分析使用者請求,從候選工具清單中選出最適合的工具,或在無合適工具時拒絕匹配。 任務: 1. 閱讀使用者的請求(user_prompt)與候選工具清單(tools,含 name 與 description)。 2. 判斷哪一個 tool 最符合使用者意圖,或所有候選皆不適用。 3. 輸出嚴格 JSON 結果。 輸出格式: { "reasoning": { "intent_summary": "<30-60字:辨識使用者意圖>", "key_signals": "<20-40字:抓出使用者請求中的關鍵詞與語意訊號>", "conclusion": "<30-60字:說明為什麼選 X 或為什麼拒絕匹配>" }, "selected_tool": "<候選工具名稱,或在拒絕匹配時為 null>", "signal": "commit" 或 "abstain", "confidence": "high" 、 "medium" 或 "low" } 判斷原則: 1. selected_tool 必須是候選清單中的 tool name 之一(commit 時)或 null(abstain 時)。 2. signal = "commit":候選中至少有 1 個明確相關工具,能涵蓋使用者意圖。 3. signal = "abstain":候選清單中沒有任何工具能涵蓋使用者核心意圖;即使部分功能沾邊也應拒答。 4. confidence 等級: - high:候選中僅 1 個明確相關(或全部明確不相關),無語意混淆。 - medium:候選中有 1~2 個邊緣相關(混淆 pair),需轉一個彎才能對應。 - low:多個候選都可能適用,理由勉強選 1 個(或極邊緣拒答)。 規則: 1. selected_tool 必須逐字符合候選清單中的 name(含大小寫與底線)。 2. 不要為了避免 abstain 而強選不適用的工具——abstain 是有效輸出。 3. reasoning 用繁體中文,不直接抄 tool description 全文,要重組為意圖陳述。 4. 只回傳 JSON,無其他說明文字。 """ VALID_SIGNAL = {"commit", "abstain"} VALID_CONFIDENCE = {"high", "medium", "low"} REQUIRED_REASONING = {"intent_summary", "key_signals", "conclusion"} class Detector: """LoRA + Filter 1 + Filter 2 完整推論器""" def __init__( self, base_model: str = "Qwen/Qwen2.5-3B-Instruct", adapter: Optional[str] = "GOSHUNCLE/tool_call_validator_zh", max_new_tokens: int = 384, device: Optional[str] = None, ): self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token dtype = torch.float16 if torch.cuda.is_available() else torch.float32 kwargs = {"torch_dtype": dtype, "trust_remote_code": True} if torch.cuda.is_available(): kwargs["device_map"] = "auto" else: kwargs["low_cpu_mem_usage"] = True self.model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs) if adapter: self.model = PeftModel.from_pretrained(self.model, adapter) self.model.eval() self.max_new_tokens = max_new_tokens @staticmethod def _format_user_message(user_prompt: str, tools: list) -> str: tools_block = "\n".join( f"{i+1}. {t['name']}: {t['description']}" for i, t in enumerate(tools) ) return f"使用者請求:\n{user_prompt}\n\n候選工具:\n{tools_block}" @torch.inference_mode() def generate_raw(self, user_prompt: str, tools: list) -> str: messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": self._format_user_message(user_prompt, tools)}, ] prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) outputs = self.model.generate( **inputs, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.pad_token_id, ) gen = outputs[0][inputs.input_ids.shape[1]:] return self.tokenizer.decode(gen, skip_special_tokens=True).strip() # ------------------------------------------------------------------ # Filter 1: Schema validation # ------------------------------------------------------------------ @staticmethod def _parse_json_lenient(text: str) -> Optional[dict]: text = text.strip() start = text.find("{") if start < 0: return None depth = 0 for i in range(start, len(text)): if text[i] == "{": depth += 1 elif text[i] == "}": depth -= 1 if depth == 0: try: return json.loads(text[start:i+1]) except json.JSONDecodeError: return None return None @staticmethod def _filter_schema(parsed: Optional[dict]) -> tuple[dict, bool]: fallback = { "reasoning": { "intent_summary": "[Filter fallback]", "key_signals": "[Filter fallback]", "conclusion": "[Filter fallback] 輸出格式錯誤,安全拒答。", }, "selected_tool": None, "signal": "abstain", "confidence": "low", } if not isinstance(parsed, dict): return fallback, False if not all(k in parsed for k in ("reasoning", "selected_tool", "signal", "confidence")): return fallback, False if parsed["signal"] not in VALID_SIGNAL: return fallback, False if parsed["confidence"] not in VALID_CONFIDENCE: return fallback, False if not isinstance(parsed["reasoning"], dict): return fallback, False if not REQUIRED_REASONING.issubset(parsed["reasoning"].keys()): return fallback, False if parsed["signal"] == "commit" and parsed["selected_tool"] is None: return fallback, False if parsed["signal"] == "abstain": parsed["selected_tool"] = None return parsed, True # ------------------------------------------------------------------ # Filter 2: Provenance check # ------------------------------------------------------------------ @staticmethod def _filter_provenance(parsed: dict, tools: list) -> dict: if parsed["signal"] != "commit": return parsed names = {t["name"] for t in tools} if parsed.get("selected_tool") not in names: parsed = dict(parsed) parsed["signal"] = "abstain" parsed["selected_tool"] = None parsed["confidence"] = "low" parsed["reasoning"] = dict(parsed["reasoning"]) parsed["reasoning"]["conclusion"] = ( "[Filter fallback] 模型輸出的 selected_tool 不在候選清單中,安全拒答。" ) return parsed def detect(self, user_prompt: str, tools: list, apply_filters: bool = True) -> dict: raw = self.generate_raw(user_prompt, tools) parsed = self._parse_json_lenient(raw) if not apply_filters: return parsed if parsed else {"_unparseable": True, "_raw": raw} parsed, _ = self._filter_schema(parsed) parsed = self._filter_provenance(parsed, tools) return parsed if __name__ == "__main__": # Quick demo detector = Detector() result = detector.detect( user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。", tools=[ {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"}, {"name": "calendar_view", "description": "查看使用者的行事曆"}, {"name": "calculator", "description": "進行數值與數學運算"}, ], ) print(json.dumps(result, ensure_ascii=False, indent=2))