tool_call_validator_zh / inference.py
GOSHUNCLE's picture
Upload inference.py
d659bd1 verified
"""
tool_call_validator_zh - Inference Reference Implementation
提供給 HF Hub 使用者的最小可用推論程式碼。
包含 Filter 1 (Schema) + Filter 2 (Provenance) 雙層保險。
使用範例(quickstart):
from inference import Detector
detector = Detector("Qwen/Qwen2.5-3B-Instruct", "GOSHUNCLE/tool_call_validator_zh")
result = detector.detect(
user_prompt="請幫我查一下今天台北的 PM2.5 空氣品質指數。",
tools=[
{"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
{"name": "calendar_view", "description": "查看使用者的行事曆"},
],
)
print(result)
"""
from __future__ import annotations
import json
from typing import Optional
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# System message placed before the user turn in every generation
# (see Detector.generate_raw). It instructs the model to answer with strict
# JSON containing `reasoning` (intent_summary / key_signals / conclusion),
# `selected_tool`, `signal` ("commit" | "abstain") and `confidence`
# ("high" | "medium" | "low").
# NOTE(review): the exact text presumably has to match the prompt the LoRA
# adapter was fine-tuned with - TODO confirm before editing any wording.
SYSTEM_PROMPT = """你是工具選擇守門員(Tool Selection Guardrail)。
你的職責是分析使用者請求,從候選工具清單中選出最適合的工具,或在無合適工具時拒絕匹配。
任務:
1. 閱讀使用者的請求(user_prompt)與候選工具清單(tools,含 name 與 description)。
2. 判斷哪一個 tool 最符合使用者意圖,或所有候選皆不適用。
3. 輸出嚴格 JSON 結果。
輸出格式:
{
"reasoning": {
"intent_summary": "<30-60字:辨識使用者意圖>",
"key_signals": "<20-40字:抓出使用者請求中的關鍵詞與語意訊號>",
"conclusion": "<30-60字:說明為什麼選 X 或為什麼拒絕匹配>"
},
"selected_tool": "<候選工具名稱,或在拒絕匹配時為 null>",
"signal": "commit" 或 "abstain",
"confidence": "high" 、 "medium" 或 "low"
}
判斷原則:
1. selected_tool 必須是候選清單中的 tool name 之一(commit 時)或 null(abstain 時)。
2. signal = "commit":候選中至少有 1 個明確相關工具,能涵蓋使用者意圖。
3. signal = "abstain":候選清單中沒有任何工具能涵蓋使用者核心意圖;即使部分功能沾邊也應拒答。
4. confidence 等級:
- high:候選中僅 1 個明確相關(或全部明確不相關),無語意混淆。
- medium:候選中有 1~2 個邊緣相關(混淆 pair),需轉一個彎才能對應。
- low:多個候選都可能適用,理由勉強選 1 個(或極邊緣拒答)。
規則:
1. selected_tool 必須逐字符合候選清單中的 name(含大小寫與底線)。
2. 不要為了避免 abstain 而強選不適用的工具——abstain 是有效輸出。
3. reasoning 用繁體中文,不直接抄 tool description 全文,要重組為意圖陳述。
4. 只回傳 JSON,無其他說明文字。
"""
# Allowed `signal` / `confidence` values and the keys the `reasoning` object
# must contain; enforced by Detector._filter_schema (Filter 1).
VALID_SIGNAL = {"commit", "abstain"}
VALID_CONFIDENCE = {"high", "medium", "low"}
REQUIRED_REASONING = {"intent_summary", "key_signals", "conclusion"}
class Detector:
    """Full inference pipeline: LoRA-adapted LM + Filter 1 (schema) + Filter 2 (provenance).

    ``detect`` generates a completion, extracts the first JSON object from it
    and, by default, runs it through two guard filters. Malformed output or a
    tool name outside the candidate list degrades to a safe ``"abstain"``
    result instead of raising.
    """

    def __init__(
        self,
        base_model: str = "Qwen/Qwen2.5-3B-Instruct",
        adapter: Optional[str] = "GOSHUNCLE/tool_call_validator_zh",
        max_new_tokens: int = 384,
        device: Optional[str] = None,
    ):
        """Load the tokenizer, the base model and (optionally) the LoRA adapter.

        Args:
            base_model: Hugging Face id or local path of the base causal LM.
            adapter: PEFT adapter id/path applied on top of ``base_model``;
                a falsy value (``None`` / ``""``) runs the bare base model.
            max_new_tokens: Generation budget for each ``generate_raw`` call.
            device: Explicit torch device string (e.g. ``"cpu"``, ``"cuda:0"``).
                ``None`` keeps the automatic placement: ``device_map="auto"``
                when CUDA is available, otherwise CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # generate() is given pad_token_id; fall back to EOS when none is defined.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if device is None:
            use_cuda = torch.cuda.is_available()
        else:
            use_cuda = torch.device(device).type == "cuda"
        # float16 on CUDA, float32 everywhere else.
        dtype = torch.float16 if use_cuda else torch.float32
        kwargs = {"torch_dtype": dtype, "trust_remote_code": True}
        if device is None and use_cuda:
            kwargs["device_map"] = "auto"
        else:
            kwargs["low_cpu_mem_usage"] = True
        self.model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
        if adapter:
            self.model = PeftModel.from_pretrained(self.model, adapter)
        if device is not None:
            # Honour an explicit device request (moves base model and adapter together).
            self.model = self.model.to(device)
        self.model.eval()
        self.max_new_tokens = max_new_tokens

    @staticmethod
    def _format_user_message(user_prompt: str, tools: list) -> str:
        """Render the user turn: the request followed by a 1-based numbered tool list.

        Each entry of *tools* must provide ``"name"`` and ``"description"`` keys.
        """
        tools_block = "\n".join(
            f"{i+1}. {t['name']}: {t['description']}" for i, t in enumerate(tools)
        )
        return f"使用者請求:\n{user_prompt}\n\n候選工具:\n{tools_block}"

    @torch.inference_mode()
    def generate_raw(self, user_prompt: str, tools: list) -> str:
        """Greedily decode a reply to (SYSTEM_PROMPT, user turn) and return only the new text.

        The prompt is built with the tokenizer's chat template; the prompt tokens
        are sliced off the output before decoding, and the result is stripped.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": self._format_user_message(user_prompt, tools)},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        gen = outputs[0][inputs.input_ids.shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

    # ------------------------------------------------------------------
    # Filter 1: Schema validation
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_json_lenient(text: str) -> Optional[dict]:
        """Return the JSON object that starts at the first ``{`` in *text*, or None.

        Leading/trailing prose (e.g. Markdown code fences) is ignored. Decoding
        is delegated to ``json.JSONDecoder.raw_decode``. That decoder understands
        string literals, so braces inside values (such as a ``}`` in the reasoning
        text) do not end the object early. Naive brace counting would. Returns None
        when there is no ``{``, or when the object is truncated or invalid.
        """
        start = text.find("{")
        if start < 0:
            return None
        try:
            obj, _ = json.JSONDecoder().raw_decode(text[start:])
        except json.JSONDecodeError:
            return None
        return obj if isinstance(obj, dict) else None

    @staticmethod
    def _filter_schema(parsed: Optional[dict]) -> tuple[dict, bool]:
        """Filter 1: check *parsed* against the output schema.

        Returns ``(result, ok)``. On any violation, ``result`` is a fresh
        safe-abstain fallback and ``ok`` is False. Violations are:

        * not a dict, or a top-level key is missing;
        * ``signal`` / ``confidence`` is not one of the allowed strings
          (non-string values such as lists are rejected rather than raising);
        * ``reasoning`` is not a dict with all required keys;
        * ``signal == "commit"`` with a null ``selected_tool``.

        Otherwise ``result`` is a shallow copy of *parsed* with
        ``selected_tool`` forced to None for abstentions, and ``ok`` is True.
        The input dict is never mutated.
        """
        fallback = {
            "reasoning": {
                "intent_summary": "[Filter fallback]",
                "key_signals": "[Filter fallback]",
                "conclusion": "[Filter fallback] 輸出格式錯誤,安全拒答。",
            },
            "selected_tool": None,
            "signal": "abstain",
            "confidence": "low",
        }
        if not isinstance(parsed, dict):
            return fallback, False
        if not all(k in parsed for k in ("reasoning", "selected_tool", "signal", "confidence")):
            return fallback, False
        signal = parsed["signal"]
        confidence = parsed["confidence"]
        reasoning = parsed["reasoning"]
        # The isinstance guards come first: set membership on an unhashable value
        # (e.g. a list emitted by the model) would raise TypeError.
        if not isinstance(signal, str) or signal not in VALID_SIGNAL:
            return fallback, False
        if not isinstance(confidence, str) or confidence not in VALID_CONFIDENCE:
            return fallback, False
        if not isinstance(reasoning, dict):
            return fallback, False
        if not REQUIRED_REASONING.issubset(reasoning.keys()):
            return fallback, False
        if signal == "commit" and parsed["selected_tool"] is None:
            return fallback, False
        result = dict(parsed)
        if signal == "abstain":
            result["selected_tool"] = None
        return result, True

    # ------------------------------------------------------------------
    # Filter 2: Provenance check
    # ------------------------------------------------------------------
    @staticmethod
    def _filter_provenance(parsed: dict, tools: list) -> dict:
        """Filter 2: make sure a committed ``selected_tool`` is one of the candidates.

        Non-commit results and valid commits are returned unchanged (same object).
        A commit naming an unknown tool, or carrying an unhashable value, is
        downgraded to a low-confidence abstain. The downgrade is a copy and
        replaces ``reasoning["conclusion"]`` with a fallback note. *parsed* itself
        is not mutated.
        """
        if parsed["signal"] != "commit":
            return parsed
        names = {t["name"] for t in tools}
        selected = parsed.get("selected_tool")
        try:
            known = selected in names
        except TypeError:  # unhashable value such as a list or dict
            known = False
        if known:
            return parsed
        downgraded = dict(parsed)
        downgraded["signal"] = "abstain"
        downgraded["selected_tool"] = None
        downgraded["confidence"] = "low"
        downgraded["reasoning"] = dict(parsed["reasoning"])
        downgraded["reasoning"]["conclusion"] = (
            "[Filter fallback] 模型輸出的 selected_tool 不在候選清單中,安全拒答。"
        )
        return downgraded

    def detect(self, user_prompt: str, tools: list, apply_filters: bool = True) -> dict:
        """Classify *user_prompt* against *tools* and return the decision dict.

        With ``apply_filters=False``, the raw parsed JSON is returned as-is. When
        nothing usable was parsed, the result is
        ``{"_unparseable": True, "_raw": <generated text>}``. Otherwise the
        output goes through Filter 1 (schema) and then Filter 2 (provenance).
        """
        raw = self.generate_raw(user_prompt, tools)
        parsed = self._parse_json_lenient(raw)
        if not apply_filters:
            return parsed if parsed else {"_unparseable": True, "_raw": raw}
        parsed, _ = self._filter_schema(parsed)
        return self._filter_provenance(parsed, tools)
if __name__ == "__main__":
    # Demo: load the default base model + adapter, classify one air-quality
    # request against three candidate tools, and pretty-print the decision.
    demo_tools = [
        {"name": "web_search", "description": "透過搜尋引擎即時取得網路上最新資訊"},
        {"name": "calendar_view", "description": "查看使用者的行事曆"},
        {"name": "calculator", "description": "進行數值與數學運算"},
    ]
    demo_prompt = "請幫我查一下今天台北的 PM2.5 空氣品質指數。"
    verdict = Detector().detect(user_prompt=demo_prompt, tools=demo_tools)
    print(json.dumps(verdict, ensure_ascii=False, indent=2))