Text Generation
PEFT
Safetensors
Chinese
English
lora
tool-selection
tool-call
guardrail
chinese
traditional-chinese
fine-tuned
qwen2
conversational
Instructions to use GOSHUNCLE/tool_call_validator_zh with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use GOSHUNCLE/tool_call_validator_zh with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "GOSHUNCLE/tool_call_validator_zh")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| tool_call_validator_zh - Inference Reference Implementation | |
| 提供給 HF Hub 使用者的最小可用推論程式碼。 | |
| 包含 Filter 1 (Schema) + Filter 2 (Provenance) 雙層保險。 | |
| 使用範例(quickstart): | |
| from inference import Detector | |
| detector = Detector("Qwen/Qwen2.5-3B-Instruct", "GOSHUNCLE/tool_call_validator_zh") | |
| result = detector.detect( | |
| user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。", | |
| tools=[ | |
| {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"}, | |
| {"name": "calendar_view", "description": "查看使用者的行事曆"}, | |
| ], | |
| ) | |
| print(result) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from typing import Optional | |
| import torch | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# System prompt used as the `system` chat message for every request. It tells the
# model to act as a tool-selection guardrail: read the user request and the
# candidate tools, then answer with one JSON object holding `reasoning`,
# `selected_tool`, `signal` ("commit"/"abstain") and `confidence`.
# NOTE(review): presumably this text must stay identical to the prompt the adapter
# was fine-tuned with -- confirm before editing any wording.
| SYSTEM_PROMPT = """你是工具選擇守門員(Tool Selection Guardrail)。 | |
| 你的職責是分析使用者請求,從候選工具清單中選出最適合的工具,或在無合適工具時拒絕匹配。 | |
| 任務: | |
| 1. 閱讀使用者的請求(user_prompt)與候選工具清單(tools,含 name 與 description)。 | |
| 2. 判斷哪一個 tool 最符合使用者意圖,或所有候選皆不適用。 | |
| 3. 輸出嚴格 JSON 結果。 | |
| 輸出格式: | |
| { | |
| "reasoning": { | |
| "intent_summary": "<30-60字:辨識使用者意圖>", | |
| "key_signals": "<20-40字:抓出使用者請求中的關鍵詞與語意訊號>", | |
| "conclusion": "<30-60字:說明為什麼選 X 或為什麼拒絕匹配>" | |
| }, | |
| "selected_tool": "<候選工具名稱,或在拒絕匹配時為 null>", | |
| "signal": "commit" 或 "abstain", | |
| "confidence": "high" 、 "medium" 或 "low" | |
| } | |
| 判斷原則: | |
| 1. selected_tool 必須是候選清單中的 tool name 之一(commit 時)或 null(abstain 時)。 | |
| 2. signal = "commit":候選中至少有 1 個明確相關工具,能涵蓋使用者意圖。 | |
| 3. signal = "abstain":候選清單中沒有任何工具能涵蓋使用者核心意圖;即使部分功能沾邊也應拒答。 | |
| 4. confidence 等級: | |
| - high:候選中僅 1 個明確相關(或全部明確不相關),無語意混淆。 | |
| - medium:候選中有 1~2 個邊緣相關(混淆 pair),需轉一個彎才能對應。 | |
| - low:多個候選都可能適用,理由勉強選 1 個(或極邊緣拒答)。 | |
| 規則: | |
| 1. selected_tool 必須逐字符合候選清單中的 name(含大小寫與底線)。 | |
| 2. 不要為了避免 abstain 而強選不適用的工具——abstain 是有效輸出。 | |
| 3. reasoning 用繁體中文,不直接抄 tool description 全文,要重組為意圖陳述。 | |
| 4. 只回傳 JSON,無其他說明文字。 | |
| """ | |
# Closed vocabularies checked by Detector._filter_schema: the allowed values of
# the "signal" and "confidence" fields, and the keys that must be present inside
# the "reasoning" object of the model's JSON answer.
| VALID_SIGNAL = {"commit", "abstain"} | |
| VALID_CONFIDENCE = {"high", "medium", "low"} | |
| REQUIRED_REASONING = {"intent_summary", "key_signals", "conclusion"} | |
class Detector:
    """Tool-selection guardrail: base model + LoRA adapter + two output filters.

    The model is asked to emit one JSON verdict. Two post-filters make that
    verdict safe to consume:

    * Filter 1 (schema): the verdict must have the expected keys and values;
      otherwise a fixed "abstain" fallback is returned.
    * Filter 2 (provenance): a "commit" verdict must name a tool that is
      actually in the candidate list; otherwise it is downgraded to "abstain".
    """

    def __init__(
        self,
        base_model: str = "Qwen/Qwen2.5-3B-Instruct",
        adapter: Optional[str] = "GOSHUNCLE/tool_call_validator_zh",
        max_new_tokens: int = 384,
        device: Optional[str] = None,
    ):
        """Load tokenizer, base model and (optionally) the LoRA adapter.

        Args:
            base_model: Hub id or local path of the base causal LM.
            adapter: Hub id or local path of the PEFT adapter; a falsy value
                runs the bare base model.
            max_new_tokens: Generation budget for each verdict.
            device: Explicit target device such as ``"cpu"`` or ``"cuda:0"``.
                ``None`` keeps automatic placement (``device_map="auto"`` when
                CUDA is available, CPU otherwise).
        """
        # NOTE(review): trust_remote_code=True lets the hub repository execute
        # its own Python code at load time; only use with repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if device is None:
            use_cuda = torch.cuda.is_available()
        else:
            use_cuda = str(device).startswith("cuda")
        dtype = torch.float16 if use_cuda else torch.float32
        kwargs = {"torch_dtype": dtype, "trust_remote_code": True}
        if device is None and use_cuda:
            kwargs["device_map"] = "auto"
        else:
            # An explicit device is applied with .to() below, so do not let
            # accelerate dispatch the weights on its own.
            kwargs["low_cpu_mem_usage"] = True

        self.model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
        if adapter:
            self.model = PeftModel.from_pretrained(self.model, adapter)
        if device is not None:
            # Honour the caller's device choice (previously silently ignored).
            self.model.to(device)
        self.model.eval()
        self.max_new_tokens = max_new_tokens

    @staticmethod
    def _format_user_message(user_prompt: str, tools: list) -> str:
        """Render the user request and a numbered candidate-tool list.

        Each tool must be a mapping with ``name`` and ``description`` keys.
        """
        tools_block = "\n".join(
            f"{idx}. {tool['name']}: {tool['description']}"
            for idx, tool in enumerate(tools, start=1)
        )
        return f"使用者請求:\n{user_prompt}\n\n候選工具:\n{tools_block}"

    def generate_raw(self, user_prompt: str, tools: list) -> str:
        """Run greedy generation and return only the newly generated text."""
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": self._format_user_message(user_prompt, tools)},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Drop the prompt tokens so that only the completion is decoded.
        gen = outputs[0][inputs.input_ids.shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

    # ------------------------------------------------------------------
    # Filter 1: Schema validation
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_json_lenient(text: str) -> Optional[dict]:
        """Extract the first JSON object embedded in ``text``.

        Text before the first ``{`` and after the end of the object is ignored.
        Decoding is delegated to the JSON decoder rather than counting braces,
        so ``{`` / ``}`` characters inside string values do not cut the object
        short.

        Returns:
            The decoded object, or ``None`` when ``text`` is not a string,
            contains no ``{``, or the object starting there is not valid JSON.
        """
        if not isinstance(text, str):
            return None
        start = text.find("{")
        if start < 0:
            return None
        try:
            obj, _end = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            return None
        return obj

    @staticmethod
    def _filter_schema(parsed: Optional[dict]) -> tuple[dict, bool]:
        """Validate the verdict's shape; fall back to a safe abstain if invalid.

        Returns:
            ``(verdict, ok)``. When ``ok`` is ``False`` the verdict is the fixed
            fallback. When ``ok`` is ``True`` the verdict is ``parsed`` itself
            for "commit", or a shallow copy with ``selected_tool`` forced to
            ``None`` for "abstain" (the input dict is not modified).
        """
        fallback = {
            "reasoning": {
                "intent_summary": "[Filter fallback]",
                "key_signals": "[Filter fallback]",
                "conclusion": "[Filter fallback] 輸出格式錯誤,安全拒答。",
            },
            "selected_tool": None,
            "signal": "abstain",
            "confidence": "low",
        }
        if not isinstance(parsed, dict):
            return fallback, False
        required = ("reasoning", "selected_tool", "signal", "confidence")
        if any(key not in parsed for key in required):
            return fallback, False

        signal = parsed["signal"]
        confidence = parsed["confidence"]
        reasoning = parsed["reasoning"]
        # The isinstance guards keep unhashable model output (e.g. a list) from
        # raising TypeError inside the set-membership tests.
        if not isinstance(signal, str) or signal not in VALID_SIGNAL:
            return fallback, False
        if not isinstance(confidence, str) or confidence not in VALID_CONFIDENCE:
            return fallback, False
        if not isinstance(reasoning, dict):
            return fallback, False
        if not REQUIRED_REASONING.issubset(reasoning.keys()):
            return fallback, False
        if signal == "commit" and parsed["selected_tool"] is None:
            return fallback, False
        if signal == "abstain":
            # An abstain verdict never carries a tool, whatever the model wrote.
            parsed = dict(parsed)
            parsed["selected_tool"] = None
        return parsed, True

    # ------------------------------------------------------------------
    # Filter 2: Provenance check
    # ------------------------------------------------------------------
    @staticmethod
    def _filter_provenance(parsed: dict, tools: list) -> dict:
        """Downgrade a "commit" whose tool is not one of the candidates.

        Non-commit verdicts and commits naming a candidate are returned
        unchanged. Otherwise a modified copy is returned with
        ``signal="abstain"``, ``selected_tool=None``, ``confidence="low"`` and
        an explanatory ``reasoning["conclusion"]``.
        """
        if parsed["signal"] != "commit":
            return parsed
        names = {tool["name"] for tool in tools}
        selected = parsed.get("selected_tool")
        # Non-string values can never match a tool name; checking the type
        # first also avoids TypeError on unhashable values.
        if isinstance(selected, str) and selected in names:
            return parsed
        downgraded = dict(parsed)
        downgraded["signal"] = "abstain"
        downgraded["selected_tool"] = None
        downgraded["confidence"] = "low"
        downgraded["reasoning"] = dict(parsed["reasoning"])
        downgraded["reasoning"]["conclusion"] = (
            "[Filter fallback] 模型輸出的 selected_tool 不在候選清單中,安全拒答。"
        )
        return downgraded

    def detect(self, user_prompt: str, tools: list, apply_filters: bool = True) -> dict:
        """Generate a verdict for ``user_prompt`` over ``tools``.

        Args:
            user_prompt: The end user's request.
            tools: Candidate tools, each a mapping with ``name`` and
                ``description``.
            apply_filters: When ``False`` the parsed model output is returned
                as-is, or ``{"_unparseable": True, "_raw": <text>}`` when
                nothing could be parsed.

        Returns:
            The verdict dict after Filter 1 and Filter 2 (or the raw parse).
        """
        raw = self.generate_raw(user_prompt, tools)
        parsed = self._parse_json_lenient(raw)
        if not apply_filters:
            return parsed if parsed else {"_unparseable": True, "_raw": raw}
        parsed, _ = self._filter_schema(parsed)
        parsed = self._filter_provenance(parsed, tools)
        return parsed
if __name__ == "__main__":
    # Demo run: load the default model and adapter, ask for one verdict over
    # three candidate tools, and pretty-print the resulting JSON.
    demo_tools = [
        {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
        {"name": "calendar_view", "description": "查看使用者的行事曆"},
        {"name": "calculator", "description": "進行數值與數學運算"},
    ]
    detector = Detector()
    result = detector.detect(
        user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。",
        tools=demo_tools,
    )
    print(json.dumps(result, ensure_ascii=False, indent=2))