# x-guard / model.py
import os
import torch
import threading
import re
from typing import List, Dict, Any, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
def resolve_model_path(model_id: str) -> str:
    """
    Resolve a model path: return it directly if it is a local directory,
    otherwise download the model from ModelScope.

    Args:
        model_id: model identifier (ModelScope model_id) or a local directory path

    Returns:
        Local directory path of the model
    """
    if os.path.isdir(model_id):
        print(f"Using local model: {model_id}")
        return model_id
    print(f"Downloading model from ModelScope: {model_id} ...")
    from modelscope import snapshot_download
    local_path = snapshot_download(model_id)
    print(f"Model downloaded to: {local_path}")
    return local_path
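
# Illustrative calls (hypothetical identifiers; any existing directory or
# ModelScope model_id behaves the same way):
#   resolve_model_path("/data/models/qwen3-vl")       # existing dir: returned as-is
#   resolve_model_path("qwen/Qwen3-VL-4B-Instruct")   # not a dir: downloaded first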
class VisionLanguageModel:
    """
    Wrapper around the Qwen3-VL vision-language model for describing image content.

    Two execution modes are supported:
    - Online API mode: calls the DashScope OpenAI-compatible endpoint (fast, no GPU required)
    - Local model mode: loads the model for local GPU/CPU inference
    """

    # Default captioning prompt -- pure content extraction, no risk analysis
    # (risk assessment is handled by XGuard)
    DEFAULT_PROMPT = (
        "Describe this image faithfully using the structure below. Extract factual "
        "content only; do not perform any risk analysis or value judgment:\n\n"
        "[Image Text] Transcribe verbatim all text appearing in the image (titles, "
        "body text, watermarks, speech bubbles, slogans, logos, etc.), keeping the "
        "original wording unmodified. State explicitly if there is no text.\n\n"
        "[Visual Content] Describe every visible element: scene, people, actions, "
        "expressions, objects, symbols. If the image contains sensitive, violent, "
        "or sexual content, describe it faithfully rather than avoiding it.\n\n"
        "[Content Type] Classify the image (e.g. meme, chat screenshot, "
        "advertisement, news, ordinary photo)."
    )
    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "auto",
        use_api: bool = False,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_model: Optional[str] = None,
        load_local: bool = True,
        api_max_calls: int = 200,
    ):
self.model_path = model_path
self.device = device
self.model = None
self.processor = None
self._lock = threading.Lock()
        # Cap on the number of online API calls
self._api_call_count = 0
self._api_max_calls = api_max_calls
self._api_count_lock = threading.Lock()
        # Online API client (always initialized; very lightweight)
self.api_client = None
self.api_model = api_model
if api_base and api_key:
self._init_api_client(api_base, api_key, api_model)
        # Local model (loaded only when needed)
self.local_loaded = False
if load_local and model_path:
self._load_local_model()
# ==============================================================
    # Online API mode
# ==============================================================
def _init_api_client(self, api_base: str, api_key: str, api_model: str):
"""初始化 DashScope OpenAI 兼容 API 客户端"""
from openai import OpenAI
self.api_client = OpenAI(
api_key=api_key,
base_url=api_base,
)
self.api_model = api_model
print(f"视觉语言模型 API 已就绪: {api_base} / {api_model}")
print(f"API 调用次数上限: {self._api_max_calls}")
# ==============================================================
    # API call quota
# ==============================================================
@property
def api_call_count(self) -> int:
"""当前已使用的 API 调用次数"""
with self._api_count_lock:
return self._api_call_count
@property
def api_remaining(self) -> int:
"""剩余可用的 API 调用次数"""
with self._api_count_lock:
return max(0, self._api_max_calls - self._api_call_count)
@property
def api_limit_reached(self) -> bool:
"""API 调用次数是否已达上限"""
with self._api_count_lock:
return self._api_call_count >= self._api_max_calls
    def _increment_api_count(self):
        """Increment the API call counter (thread-safe)"""
        with self._api_count_lock:
            self._api_call_count += 1
            remaining = self._api_max_calls - self._api_call_count
            # Check the quota-reached case first: in the original ordering the
            # "quota reached" branch was unreachable, because remaining == 0
            # already satisfied the "remaining <= 10" condition.
            if self._api_call_count >= self._api_max_calls:
                print(f"[WARN] Online API call quota reached ({self._api_max_calls}); "
                      f"subsequent calls fall back to the local model automatically")
            elif remaining <= 10:
                print(f"[WARN] Online API calls remaining: {remaining}/{self._api_max_calls}")
@staticmethod
def _image_to_data_url(image_path: str) -> str:
"""将本地图片文件转换为 base64 data URL"""
import base64
with open(image_path, "rb") as f:
data = base64.b64encode(f.read()).decode()
ext = os.path.splitext(image_path)[1].lower()
mime_map = {
".jpg": "image/jpeg", ".jpeg": "image/jpeg",
".png": "image/png", ".gif": "image/gif",
".webp": "image/webp", ".bmp": "image/bmp",
}
mime = mime_map.get(ext, "image/png")
return f"data:{mime};base64,{data}"
    def _describe_image_api(self, image_path: str, prompt: str) -> str:
        """Generate an image description via the online API"""
        if self.api_client is None:
            raise RuntimeError("Online API is not configured; check the vl_api_base / vl_api_key settings")
data_url = self._image_to_data_url(image_path)
response = self.api_client.chat.completions.create(
model=self.api_model,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{"type": "text", "text": prompt},
],
}
],
max_tokens=512,
)
return response.choices[0].message.content
# ==============================================================
    # Local model mode
# ==============================================================
    def _load_local_model(self):
        """Load the local Qwen3-VL model"""
        from transformers import Qwen3VLForConditionalGeneration
        local_path = resolve_model_path(self.model_path)
        print(f"Loading local vision-language model: {local_path}...")
        self.processor = self._load_processor(local_path)
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True,
        ).eval()
        self.local_loaded = True
        print("Local vision-language model loaded.")
    def _load_processor(self, local_path: str):
        """
        Load the processor, with a multi-level fallback mechanism.

        In some transformers versions VIDEO_PROCESSOR_MAPPING_NAMES is not
        initialized correctly, which makes AutoProcessor.from_pretrained raise
        a TypeError; the fallbacks below work around that.
        """
        # Attempt 1: standard AutoProcessor loading
        try:
            from transformers import AutoProcessor
            return AutoProcessor.from_pretrained(
                local_path,
                trust_remote_code=True,
            )
        except TypeError as e:
            if "NoneType" in str(e):
                print(f"AutoProcessor hit a video-processor compatibility issue: {e}")
            else:
                raise
        # Attempt 2: patch VIDEO_PROCESSOR_MAPPING_NAMES and retry
        try:
            from transformers.models.auto import video_processing_auto
            if video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES is None:
                video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
                print("Patched VIDEO_PROCESSOR_MAPPING_NAMES initialization; retrying...")
            from transformers import AutoProcessor
            return AutoProcessor.from_pretrained(
                local_path,
                trust_remote_code=True,
            )
        except Exception as e:
            print(f"Retry after patching still failed: {e}")
        # Attempt 3: assemble the processor manually (image support only, no video)
        print("Fallback: assembling the processor manually...")
        from transformers import AutoTokenizer, AutoImageProcessor
        tokenizer = AutoTokenizer.from_pretrained(
            local_path, trust_remote_code=True
        )
        image_processor = AutoImageProcessor.from_pretrained(
            local_path, trust_remote_code=True
        )
        try:
            from transformers import Qwen3VLProcessor
            processor = Qwen3VLProcessor(
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
            print("Manual processor assembly succeeded.")
            return processor
        except Exception as e:
            # Exception already covers ImportError, so a single clause suffices
            raise RuntimeError(
                f"Failed to load processor: {e}\n"
                "Try: pip install -U transformers torchvision qwen-vl-utils"
            ) from e
    def _describe_image_local(self, image_path: str, prompt: str) -> str:
        """Generate an image description with the local model"""
        if not self.local_loaded:
            raise RuntimeError(
                "Local vision model is not loaded. Restart with XGUARD_VL_USE_API=false, "
                "or switch to online API mode."
            )
with self._lock:
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path},
{"type": "text", "text": prompt},
],
}
]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs,
max_new_tokens=512,
do_sample=False,
)
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return output_text[0]
# ==============================================================
    # Unified public interface
# ==============================================================
    def _ensure_local_model(self):
        """Ensure the local model is loaded (lazy load for when the API quota runs out)"""
        if self.local_loaded:
            return
        if not self.model_path:
            raise RuntimeError(
                "Online API call quota is exhausted and no local model path is configured "
                "(XGUARD_VL_MODEL_PATH), so falling back to the local model is not possible. "
                "Configure a local model or restart the service to reset the API counter."
            )
        print("[Fallback] API quota exhausted; loading the local vision-language model...")
        self._load_local_model()
        print("[Fallback] Local vision-language model loaded.")
    def describe_image(self, image_path: str, prompt: Optional[str] = None,
                       use_api: Optional[bool] = None) -> str:
        """
        Generate an image description (unified interface).

        Args:
            image_path: path to the image file
            prompt: custom description prompt; the default prompt is used if empty
            use_api: whether to use the online API; if None, decided by whether
                api_client is available

        Returns:
            Text description of the image

        Note:
            When use_api=True but the API call quota has been reached, the call
            automatically falls back to the local model. Check
            self.api_limit_reached if you need to know whether a fallback occurred.
        """
        if not prompt:
            prompt = self.DEFAULT_PROMPT
        # Decide which mode to use
        if use_api is None:
            use_api = self.api_client is not None
        # Quota check: fall back to the local model automatically once the limit is reached
        if use_api and self.api_limit_reached:
            print(
                f"[API throttled] Online API call quota reached "
                f"({self._api_call_count}/{self._api_max_calls}); falling back to the local model"
            )
            self._ensure_local_model()
            use_api = False
        if use_api:
            self._increment_api_count()
            return self._describe_image_api(image_path, prompt)
        else:
            return self._describe_image_local(image_path, prompt)
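
# Minimal usage sketch for VisionLanguageModel (paths, keys, and model names
# below are hypothetical and would normally come from deployment config):
#
#   vlm = VisionLanguageModel(
#       model_path="/models/qwen3-vl",   # hypothetical local checkpoint dir
#       api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
#       api_key=os.environ.get("DASHSCOPE_API_KEY", ""),
#       api_model="qwen-vl-max",         # hypothetical DashScope model name
#       load_local=False,                # defer local loading; the quota
#   )                                    # fallback loads it on demand
#   caption = vlm.describe_image("photo.jpg")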
class XGuardModel:
    """
    Wrapper around the YuFeng-XGuard safety-detection model.

    The inference logic is aligned exactly with the official implementation:
    - apply_chat_template supports the policy / reason_first parameters
    - risk categories are matched by decoded text against id2risk directly
      (instead of going through token ids)
    - in reason_first mode, the score position of the risk token is located correctly
    """
def __init__(self, model_path: str, device: str = "auto"):
self.model_path = model_path
self.device = device
self.model = None
self.tokenizer = None
self.id2risk = None
self._lock = threading.Lock()
self._load_model()
    def _load_model(self):
        """Load the model and tokenizer, and extract the id2risk mapping"""
        local_path = resolve_model_path(self.model_path)
        print(f"Loading safety-detection model: {local_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            local_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True
        ).eval()
        # Read the id2risk mapping from the tokenizer config.
        # id2risk format: {'sec': 'Safe-Safe', 'pc': 'Crimes and Illegal Activities-Pornographic Contraband', ...}
        # Keys are short text markers (e.g. 'sec', 'pc'); values are full risk-category names.
        self.id2risk = self.tokenizer.init_kwargs.get('id2risk', {})
        print(f"id2risk entries: {len(self.id2risk)}")
        if self.id2risk:
            print(f"Sample mappings: {list(self.id2risk.items())[:5]}")
    def infer(self, messages: List[Dict[str, str]], policy=None,
              max_new_tokens: int = 1, reason_first: bool = False) -> Dict[str, Any]:
        """
        Official inference interface, aligned exactly with the XGuard reference logic.

        Args:
            messages: list of chat messages
            policy: optional dynamic policy for customizing safety-detection rules at runtime
            max_new_tokens: maximum number of tokens to generate
            reason_first: whether to generate the attribution analysis before the risk token

        Returns:
            {
                'response': str,                      # fully decoded text
                'token_score': {text: prob, ...},     # top-k token scores at the risk-token position
                'risk_score': {risk_name: prob, ...}  # scores for categories matched in id2risk
            }
        """
        with self._lock:
            # Render the input with the chat template (including the policy and reason_first parameters)
            rendered_query = self.tokenizer.apply_chat_template(
                messages,
                policy=policy,
                reason_first=reason_first,
                tokenize=False
            )
model_inputs = self.tokenizer(
[rendered_query], return_tensors="pt"
).to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**model_inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
output_scores=True,
return_dict_in_generate=True
)
batch_idx = 0
input_length = model_inputs['input_ids'].shape[1]
            # Decode the response text
output_ids = outputs["sequences"].tolist()[batch_idx][input_length:]
response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            # ---- Parse the top-k scores at each generated position (official logic) ----
generated_tokens = outputs.sequences[:, input_length:]
scores = torch.stack(outputs.scores, dim=1)
scores = scores.softmax(dim=-1)
scores_topk_value, scores_topk_index = scores.topk(k=10, dim=-1)
generated_tokens_with_probs = []
for generated_token, score_topk_value, score_topk_index in zip(
generated_tokens, scores_topk_value, scores_topk_index
):
generated_tokens_with_prob = []
for token, topk_value, topk_index in zip(
generated_token, score_topk_value, score_topk_index
):
token = int(token.cpu())
if token == self.tokenizer.pad_token_id:
continue
res_topk_score = {}
for ii, (value, index) in enumerate(zip(topk_value, topk_index)):
if ii == 0 or value.cpu().numpy() > 1e-4:
text = self.tokenizer.decode(index.cpu().numpy())
res_topk_score[text] = {
"id": str(int(index.cpu().numpy())),
"prob": round(float(value.cpu().numpy()), 4),
}
generated_tokens_with_prob.append(res_topk_score)
generated_tokens_with_probs.append(generated_tokens_with_prob)
            # Determine the token-position index for the risk score.
            # reason_first=False: the risk token is at the first position (idx=0)
            # reason_first=True: the risk token is at the second-to-last position
            #                    (after the reasoning, before EOS)
score_idx = (
max(len(generated_tokens_with_probs[batch_idx]) - 2, 0)
if reason_first else 0
)
            # Extract token scores and risk scores (official approach: match decoded text directly against id2risk)
token_score = {
k: v['prob']
for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
}
risk_score = {
self.id2risk[k]: v['prob']
for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
if k in self.id2risk
}
return {
'response': response,
'token_score': token_score,
'risk_score': risk_score,
}
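
    # Illustrative result shape for a single-token classification call
    # (values are made up; the real keys depend on the checkpoint's id2risk):
    #
    #   guard.infer([{"role": "user", "content": "..."}])
    #   -> {'response': 'sec',
    #       'token_score': {'sec': 0.9712, 'pc': 0.0113},
    #       'risk_score': {'Safe-Safe': 0.9712,
    #                      'Crimes and Illegal Activities-Pornographic Contraband': 0.0113}}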
    def parse_explanation(self, response: str) -> Optional[str]:
        """
        Parse the attribution-analysis portion of a response.

        With reason_first=False, XGuard's output format is:
            [risk-classification token][attribution-analysis text]
        The risk token is a short string key from id2risk (e.g. 'sec', 'pc'),
        and the text that follows is the natural-language attribution analysis.
        """
        if not response or not response.strip():
            return None
        # Approach 1: support the <explanation>...</explanation> tag format
        match = re.search(r'<explanation>(.*?)</explanation>', response, re.DOTALL)
        if match:
            return match.group(1).strip()
        text = response.strip()
        # Approach 2: strip the leading risk-classification token and return the rest
        # id2risk keys are short strings (e.g. 'sec', 'pc') that prefix the model output
        if self.id2risk:
            for key in sorted(self.id2risk.keys(), key=len, reverse=True):
                if text.startswith(key):
                    remainder = text[len(key):].strip()
                    if remainder:
                        return remainder
                    break  # token matched but nothing follows: no attribution was generated
        # Approach 3: the response is clearly longer than a single risk token
        # (usually 2-4 characters), so return it as the attribution directly
        if len(text) > 8:
            return text
        return None
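
    # Example (illustrative): with 'sec' in id2risk, a response such as
    #   "sec The image shows an ordinary landscape photo."
    # parses to "The image shows an ordinary landscape photo.", while a bare
    # "sec" yields None (token matched, but no attribution was generated).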
    def analyze(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]],
                enable_reasoning: bool = False, policy=None) -> Dict[str, Any]:
        """
        High-level analysis interface; wraps the inference result into a structured format.

        Args:
            messages: list of chat messages
            tools: tool information (already merged into messages; currently unused)
            enable_reasoning: whether to enable attribution analysis (generates more tokens)
            policy: optional dynamic policy
        """
        # With attribution enabled, generate more tokens for the full explanation
        max_new_tokens = 512 if enable_reasoning else 1
infer_result = self.infer(
messages,
policy=policy,
max_new_tokens=max_new_tokens,
reason_first=False
)
risk_scores = infer_result.get("risk_score", {})
response = infer_result.get("response", "")
        # ================================================================
        # Risk decision -- based on XGuard's argmax + confidence-tier framework
        #
        # Theoretical basis (arxiv 2601.15588):
        # XGuard's training objective max_θ log P(y_cls | X; θ) guarantees
        # that the softmax argmax of the first token is the model's predicted
        # risk category, and that the probability is a calibrated confidence.
        #
        # Decision flow:
        #   Layer 1 -- argmax: the highest-probability category is the model's answer
        #   Layer 2 -- confidence gate: a safe verdict requires >= 0.5 (majority confidence)
        #   Layer 3 -- risk tiers: high/medium/low by top_risk_prob
        # ================================================================
        SAFE_CATEGORY = "Safe-Safe"
        safe_prob = risk_scores.get(SAFE_CATEGORY, 0.0)
        # Extract the non-safe risk entries, sorted by score in descending order
        risk_items = {k: v for k, v in risk_scores.items() if k != SAFE_CATEGORY}
        sorted_risks = sorted(risk_items.items(), key=lambda x: x[1], reverse=True)
        top_risk_name = sorted_risks[0][0] if sorted_risks else ""
        top_risk_prob = sorted_risks[0][1] if sorted_risks else 0.0
        # Layers 1 + 2: argmax decision + confidence gate
        if safe_prob >= top_risk_prob and safe_prob >= 0.5:
            # argmax = Safe-Safe with majority confidence -> verdict: safe
            is_safe = 1
            risk_level = "safe"
        elif safe_prob >= top_risk_prob:
            # argmax = Safe-Safe but confidence is below 0.5:
            # the model leans safe without being sure enough,
            # so conservatively mark as low risk
            is_safe = 0
            risk_level = "low"
        else:
            # argmax = some risk category (top_risk_prob > safe_prob)
            # Layer 3: tier by risk confidence
            is_safe = 0
            if top_risk_prob >= 0.5:
                risk_level = "high"
            elif top_risk_prob >= 0.3:
                risk_level = "medium"
            else:
                risk_level = "low"
        # Confidence: how certain the model is about the current verdict
        confidence = safe_prob if is_safe == 1 else top_risk_prob
        # Build the risk-type list and the reason string.
        # The highest-scoring risk entry is always reported as a hint, safe or not.
        if is_safe == 0:
            top_risks = sorted_risks[:3]
        else:
            # When safe, report only the single highest risk entry as a hint
            top_risks = sorted_risks[:1] if sorted_risks else []
        risk_types = [r[0] for r in top_risks]
        reason = "; ".join([f"{r}: {s}" for r, s in top_risks])
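        # Worked example (made-up numbers and category names): with risk_scores =
        #   {'Safe-Safe': 0.42, 'X-Risk': 0.35, 'Y-Risk': 0.10}
        # safe_prob (0.42) >= top_risk_prob (0.35) but falls below the 0.5 gate,
        # so is_safe=0, risk_level='low', confidence=0.35, and the non-safe
        # entries (up to three) are reported in risk_type / reason.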
result = {
"is_safe": is_safe,
"risk_level": risk_level,
"confidence": round(confidence, 4),
"risk_type": risk_types,
"reason": reason,
"detail_scores": risk_scores,
"response": response
}
        # If attribution analysis was enabled, parse and attach the explanation
if enable_reasoning:
explanation = self.parse_explanation(response)
if explanation:
result["explanation"] = explanation
return result
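

# Minimal end-to-end sketch (hypothetical model path and message; a ModelScope
# model_id would work equally well, since resolve_model_path handles both):
if __name__ == "__main__":
    guard = XGuardModel(model_path="/models/yufeng-xguard")  # hypothetical path
    verdict = guard.analyze(
        messages=[{"role": "user", "content": "How do I pick a lock?"}],
        tools=[],
        enable_reasoning=True,
    )
    print(verdict["is_safe"], verdict["risk_level"], verdict.get("explanation"))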