Text Generation
PEFT
Safetensors
Chinese
English
lora
tool-selection
tool-call
guardrail
chinese
traditional-chinese
fine-tuned
qwen2
conversational
Instructions to use GOSHUNCLE/tool_call_validator_zh with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use GOSHUNCLE/tool_call_validator_zh with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "GOSHUNCLE/tool_call_validator_zh")
- Notebooks
- Google Colab
- Kaggle
File size: 8,868 Bytes
"""
tool_call_validator_zh - Inference Reference Implementation

Minimal working inference code for Hugging Face Hub users.
Includes two layers of safeguards: Filter 1 (Schema) + Filter 2 (Provenance).

Quickstart:
    from inference import Detector
    detector = Detector("Qwen/Qwen2.5-3B-Instruct", "GOSHUNCLE/tool_call_validator_zh")
    result = detector.detect(
        user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。",
        tools=[
            {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
            {"name": "calendar_view", "description": "查看使用者的行事曆"},
        ],
    )
    print(result)
"""
from __future__ import annotations
import json
from typing import Optional
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# System prompt sent as the "system" chat message on every generation call.
# It tells the model to act as a tool-selection guardrail: read the user
# request and the candidate tools (name + description), then answer with a
# single JSON object containing "reasoning" (intent_summary, key_signals,
# conclusion), "selected_tool" (a candidate name or null), "signal"
# ("commit" / "abstain") and "confidence" ("high" / "medium" / "low").
# The text below is runtime data, not documentation: it is passed to the model
# verbatim, so it is intentionally left in Traditional Chinese.
# NOTE(review): presumably this must match the prompt used when the adapter
# was fine-tuned -- confirm before editing any character of it.
SYSTEM_PROMPT = """你是工具選擇守門員(Tool Selection Guardrail)。
你的職責是分析使用者請求,從候選工具清單中選出最適合的工具,或在無合適工具時拒絕匹配。
任務:
1. 閱讀使用者的請求(user_prompt)與候選工具清單(tools,含 name 與 description)。
2. 判斷哪一個 tool 最符合使用者意圖,或所有候選皆不適用。
3. 輸出嚴格 JSON 結果。
輸出格式:
{
"reasoning": {
"intent_summary": "<30-60字:辨識使用者意圖>",
"key_signals": "<20-40字:抓出使用者請求中的關鍵詞與語意訊號>",
"conclusion": "<30-60字:說明為什麼選 X 或為什麼拒絕匹配>"
},
"selected_tool": "<候選工具名稱,或在拒絕匹配時為 null>",
"signal": "commit" 或 "abstain",
"confidence": "high" 、 "medium" 或 "low"
}
判斷原則:
1. selected_tool 必須是候選清單中的 tool name 之一(commit 時)或 null(abstain 時)。
2. signal = "commit":候選中至少有 1 個明確相關工具,能涵蓋使用者意圖。
3. signal = "abstain":候選清單中沒有任何工具能涵蓋使用者核心意圖;即使部分功能沾邊也應拒答。
4. confidence 等級:
- high:候選中僅 1 個明確相關(或全部明確不相關),無語意混淆。
- medium:候選中有 1~2 個邊緣相關(混淆 pair),需轉一個彎才能對應。
- low:多個候選都可能適用,理由勉強選 1 個(或極邊緣拒答)。
規則:
1. selected_tool 必須逐字符合候選清單中的 name(含大小寫與底線)。
2. 不要為了避免 abstain 而強選不適用的工具——abstain 是有效輸出。
3. reasoning 用繁體中文,不直接抄 tool description 全文,要重組為意圖陳述。
4. 只回傳 JSON,無其他說明文字。
"""
# Closed vocabularies used by the schema filter to validate the model's JSON.
# Accepted values of the top-level "signal" field.
VALID_SIGNAL = {
    "abstain",
    "commit",
}
# Accepted values of the top-level "confidence" field.
VALID_CONFIDENCE = {
    "low",
    "medium",
    "high",
}
# Keys that must all be present inside the "reasoning" object.
REQUIRED_REASONING = {
    "conclusion",
    "intent_summary",
    "key_signals",
}
class Detector:
    """Tool-selection guardrail: base model + optional LoRA adapter + two output filters.

    ``detect`` runs the pipeline end to end:

    1. ``generate_raw`` prompts the model with ``SYSTEM_PROMPT`` and the
       formatted request/tool list, and returns the decoded continuation.
    2. ``_parse_json_lenient`` extracts the first JSON object from that text.
    3. Filter 1 (``_filter_schema``) replaces malformed output with a safe
       "abstain" verdict.
    4. Filter 2 (``_filter_provenance``) downgrades a "commit" whose
       ``selected_tool`` is not one of the candidate tool names.
    """

    def __init__(
        self,
        base_model: str = "Qwen/Qwen2.5-3B-Instruct",
        adapter: Optional[str] = "GOSHUNCLE/tool_call_validator_zh",
        max_new_tokens: int = 384,
        device: Optional[str] = None,
    ):
        """Load the tokenizer, the base model and (optionally) the LoRA adapter.

        Args:
            base_model: Hub id or local path of the causal LM to load.
            adapter: Hub id or local path of the PEFT adapter; a falsy value
                skips adapter loading and uses the bare base model.
            max_new_tokens: Generation budget passed to ``generate``.
            device: Explicit target device (e.g. ``"cpu"``, ``"cuda:0"``).
                ``None`` keeps automatic placement: ``device_map="auto"`` in
                float16 when CUDA is available, float32 on CPU otherwise.
        """
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time; only point this at trusted repos.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # generate() is given pad_token_id, so make sure one exists.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # An explicit ``device`` decides the dtype; otherwise fall back to
        # probing for CUDA (the behavior when ``device`` is left as None).
        if device is None:
            on_cuda = torch.cuda.is_available()
        else:
            on_cuda = torch.device(device).type == "cuda"
        dtype = torch.float16 if on_cuda else torch.float32

        kwargs = {"torch_dtype": dtype, "trust_remote_code": True}
        if device is None and on_cuda:
            kwargs["device_map"] = "auto"
        else:
            kwargs["low_cpu_mem_usage"] = True

        self.model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
        if adapter:
            self.model = PeftModel.from_pretrained(self.model, adapter)
        if device is not None:
            # Honour the caller's explicit placement request.
            self.model.to(device)
        self.model.eval()
        self.max_new_tokens = max_new_tokens

    @staticmethod
    def _format_user_message(user_prompt: str, tools: list) -> str:
        """Render the user request and a 1-based numbered tool list as one message.

        Each entry of ``tools`` must be a mapping with ``name`` and
        ``description`` keys (``KeyError`` otherwise).
        """
        tools_block = "\n".join(
            f"{i+1}. {t['name']}: {t['description']}" for i, t in enumerate(tools)
        )
        return f"使用者請求:\n{user_prompt}\n\n候選工具:\n{tools_block}"

    @torch.inference_mode()
    def generate_raw(self, user_prompt: str, tools: list) -> str:
        """Return the model's decoded continuation for one request.

        Builds a system + user chat, applies the tokenizer's chat template,
        generates without sampling (``do_sample=False``), and decodes only the
        newly generated tokens with special tokens stripped.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": self._format_user_message(user_prompt, tools)},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Drop the echoed prompt tokens; keep only the continuation.
        gen = outputs[0][inputs.input_ids.shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

    # ------------------------------------------------------------------
    # Filter 1: Schema validation
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_json_lenient(text: str) -> Optional[dict]:
        """Extract the first JSON object embedded in ``text``.

        Text before the first ``{`` and after the object's end (code fences,
        commentary) is ignored. Returns ``None`` when there is no ``{`` or the
        object starting there is not valid JSON.
        """
        text = text.strip()
        start = text.find("{")
        if start < 0:
            return None
        # raw_decode stops at the end of the first complete JSON value and,
        # unlike counting braces by hand, is not confused by "{" or "}" that
        # appear inside string values.
        try:
            obj, _ = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            return None
        return obj if isinstance(obj, dict) else None

    @staticmethod
    def _filter_schema(parsed: Optional[dict]) -> tuple[dict, bool]:
        """Validate the parsed verdict against the expected output schema.

        Returns:
            ``(verdict, is_valid)``. When validation fails, ``verdict`` is a
            fixed low-confidence "abstain" fallback and ``is_valid`` is False.
            When it passes, ``verdict`` is the parsed object, with
            ``selected_tool`` forced to ``None`` for an "abstain" signal (on a
            shallow copy, so the argument is not modified).
        """
        fallback = {
            "reasoning": {
                "intent_summary": "[Filter fallback]",
                "key_signals": "[Filter fallback]",
                "conclusion": "[Filter fallback] 輸出格式錯誤,安全拒答。",
            },
            "selected_tool": None,
            "signal": "abstain",
            "confidence": "low",
        }
        if not isinstance(parsed, dict):
            return fallback, False
        required_keys = ("reasoning", "selected_tool", "signal", "confidence")
        if any(key not in parsed for key in required_keys):
            return fallback, False

        signal = parsed["signal"]
        confidence = parsed["confidence"]
        # Type-check before the set lookup: an unhashable value (list/dict)
        # would raise TypeError there instead of being rejected.
        if not isinstance(signal, str) or signal not in VALID_SIGNAL:
            return fallback, False
        if not isinstance(confidence, str) or confidence not in VALID_CONFIDENCE:
            return fallback, False

        reasoning = parsed["reasoning"]
        if not isinstance(reasoning, dict):
            return fallback, False
        if not REQUIRED_REASONING.issubset(reasoning.keys()):
            return fallback, False

        if signal == "commit" and parsed["selected_tool"] is None:
            return fallback, False
        if signal == "abstain":
            # An abstain never carries a tool, whatever the model wrote.
            parsed = dict(parsed)
            parsed["selected_tool"] = None
        return parsed, True

    # ------------------------------------------------------------------
    # Filter 2: Provenance check
    # ------------------------------------------------------------------
    @staticmethod
    def _filter_provenance(parsed: dict, tools: list) -> dict:
        """Reject a "commit" whose tool is not one of the candidates.

        Non-"commit" verdicts are returned unchanged. A "commit" naming a tool
        that is absent from ``tools`` (or that is not a string at all) is
        turned into a low-confidence "abstain" on a copy of ``parsed``, with
        the conclusion replaced by a fallback message.
        """
        if parsed["signal"] != "commit":
            return parsed
        names = {t["name"] for t in tools}
        selected = parsed.get("selected_tool")
        # isinstance guard keeps an unhashable value from raising TypeError in
        # the set lookup; such a value can never be a valid tool name anyway.
        if isinstance(selected, str) and selected in names:
            return parsed

        sanitized = dict(parsed)
        sanitized["signal"] = "abstain"
        sanitized["selected_tool"] = None
        sanitized["confidence"] = "low"
        sanitized["reasoning"] = dict(parsed["reasoning"])
        sanitized["reasoning"]["conclusion"] = (
            "[Filter fallback] 模型輸出的 selected_tool 不在候選清單中,安全拒答。"
        )
        return sanitized

    def detect(self, user_prompt: str, tools: list, apply_filters: bool = True) -> dict:
        """Run generation and return the (optionally filtered) verdict.

        Args:
            user_prompt: The end user's request.
            tools: Candidate tools, each a mapping with ``name`` and
                ``description``.
            apply_filters: When False, return the parsed model output as-is,
                or ``{"_unparseable": True, "_raw": <text>}`` if no JSON
                object could be parsed. When True, apply Filter 1 then
                Filter 2.
        """
        raw = self.generate_raw(user_prompt, tools)
        parsed = self._parse_json_lenient(raw)
        if not apply_filters:
            # ``is None`` rather than truthiness: an empty object ``{}`` was
            # parsed successfully and must not be reported as unparseable.
            if parsed is None:
                return {"_unparseable": True, "_raw": raw}
            return parsed
        verdict, _ = self._filter_schema(parsed)
        return self._filter_provenance(verdict, tools)
if __name__ == "__main__":
    # Quick demo: load the default model + adapter, classify one request
    # against three candidate tools, and pretty-print the filtered verdict.
    demo_prompt = "請幫我查一下今天台北的 PM2.5 空氣品質指數。"
    demo_tools = [
        {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
        {"name": "calendar_view", "description": "查看使用者的行事曆"},
        {"name": "calculator", "description": "進行數值與數學運算"},
    ]

    detector = Detector()
    result = detector.detect(user_prompt=demo_prompt, tools=demo_tools)
    # ensure_ascii=False keeps the Chinese reasoning text readable on stdout.
    print(json.dumps(result, ensure_ascii=False, indent=2))
|