# ic_content_firewall_zh / inference.py
# GOSHUNCLE's picture
# Initial release: v1 adapter + inference code
# 7af1b94 verified
"""
IC Design Content Firewall (ZH) — Inference Module
Loads the LoRA adapter from Hugging Face Hub on top of Qwen2.5-1.5B-Instruct,
produces structured JSON for IC-design content moderation.
Output schema:
{
"content_categories": [...], # multi-label, 9 types
"risk_level": "L0/L1/L2/L3",
"sensitive_entities": [{"type", "value", "reason"}, ...],
"reasoning": {"category_reason", "risk_reason"}
}
Two universal filters (Inference safeguards):
Filter 1 - Schema validation : ensures every enum value is legal
Filter 2 - Provenance check : ensures each entity value appears in the input text (guards against hallucination)
"""
import json
import re
from typing import Any, Dict, Optional
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# =============================================================================
# Default model locations (Hugging Face Hub)
# =============================================================================
# Base causal LM; default for Detector's ``base_model_path`` argument.
DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
# LoRA adapter repository; default for Detector's ``adapter_path`` argument.
DEFAULT_ADAPTER = "GOSHUNCLE/ic-firewall-zh"
# Passed to ``model.generate`` as ``max_new_tokens`` for every request.
MAX_NEW_TOKENS = 512
# =============================================================================
# System prompt
# =============================================================================
# Sent verbatim as the "system" message in Detector.generate_raw. It lists the
# same category, risk-level and entity-type names as VALID_CATEGORIES,
# VALID_RISKS and VALID_ENTITY_TYPES below.
# NOTE(review): this text is runtime data fed to the model — presumably it must
# match the prompt used when the adapter was trained; confirm before editing.
SYSTEM_PROMPT = """你是 IC 設計業內容防火牆分析專家。請分析輸入文字,輸出嚴格 JSON。
任務:
1. content_categories:從以下 9 類中選擇相關的(multi-label):
RTL, CUSTOMER, QUOTE, VENDOR, PROCESS, SCHEDULE, TESTING, INTERNAL, PUBLIC
2. risk_level:選擇 L0/L1/L2/L3 之一
3. sensitive_entities:標註所有機敏實體(不取代,僅標註)
每筆含 type、value(必須逐字在原文)、reason(10-20 字)
4. reasoning:含 category_reason 與 risk_reason(各 30-60 字)
實體類型:
CUSTOMER, PROJECT, VENDOR, PRICE, PROCESS_NODE,
MODULE_NAME, IP_BLOCK, YIELD_NUMBER, SPEC_PARAM
輸出格式:
{"content_categories":[...],"risk_level":"L?","sensitive_entities":[{"type":"...","value":"...","reason":"..."}],"reasoning":{"category_reason":"...","risk_reason":"..."}}
規則:
1. value 必須逐字出現在原文
2. 不要保險全選類別,只標 actually 相關的
3. 只回傳 JSON,無其他說明文字"""
# =============================================================================
# Valid enum values
# =============================================================================
# Labels accepted in "content_categories"; filter_schema drops anything else.
VALID_CATEGORIES = {
"RTL", "CUSTOMER", "QUOTE", "VENDOR", "PROCESS",
"SCHEDULE", "TESTING", "INTERNAL", "PUBLIC"
}
# Values accepted in "risk_level"; filter_schema keeps its default otherwise.
VALID_RISKS = {"L0", "L1", "L2", "L3"}
# Values accepted in an entity's "type"; filter_schema skips other entities.
VALID_ENTITY_TYPES = {
"CUSTOMER", "PROJECT", "VENDOR", "PRICE", "PROCESS_NODE",
"MODULE_NAME", "IP_BLOCK", "YIELD_NUMBER", "SPEC_PARAM"
}
# =============================================================================
# JSON parsing helpers
# =============================================================================
def _parse_json_object(candidate: str) -> Optional[Dict[str, Any]]:
    """Parse *candidate* as JSON and return it only if it is a JSON object."""
    try:
        value = json.loads(candidate)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Either way this candidate is unusable.
        return None
    return value if isinstance(value, dict) else None


def extract_json(raw_text: str) -> Optional[Dict[str, Any]]:
    """Extract the JSON object from raw model output.

    Three strategies are tried in order, and the first one that yields a
    JSON object wins:

    1. Parse the whole (stripped) text.
    2. Decode the object that starts at the first ``{``, ignoring whatever
       follows it (explanations, a repeated object, stray braces).
    3. Remove markdown code fences and parse what is left.

    Args:
        raw_text: Text produced by the model.

    Returns:
        The decoded object as a dict, or ``None`` when no strategy succeeds.
        Valid JSON that is not an object (list, string, number, ...) also
        yields ``None``, because callers access the result with dict methods.
    """
    text = raw_text.strip()

    whole = _parse_json_object(text)
    if whole is not None:
        return whole

    start = text.find("{")
    if start >= 0:
        try:
            # raw_decode stops at the end of the first complete value, so
            # trailing text after the object cannot break the parse.
            embedded, _ = json.JSONDecoder().raw_decode(text, start)
        except (ValueError, RecursionError):
            embedded = None
        if isinstance(embedded, dict):
            return embedded

    unfenced = re.sub(r"```(?:json)?", "", text).strip()
    return _parse_json_object(unfenced)
# =============================================================================
# Filter 1 — Schema validation
# =============================================================================
def filter_schema(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the model output so that it conforms to the output schema.

    A fresh dict is built; ``parsed`` is not modified.

    * ``content_categories``: only strings found in ``VALID_CATEGORIES`` are
      kept, de-duplicated in first-seen order; falls back to ``["INTERNAL"]``
      when nothing valid remains.
    * ``risk_level``: kept when it is in ``VALID_RISKS``, otherwise ``"L1"``.
    * ``sensitive_entities``: an entry is kept only when it is a dict whose
      ``type`` is in ``VALID_ENTITY_TYPES`` and whose ``value`` is a non-empty
      string; ``reason`` is coerced to a string.
    * ``reasoning``: ``category_reason`` / ``risk_reason`` are copied when they
      are strings, otherwise left as ``""``.

    Args:
        parsed: Decoded model output. Anything that is not a dict is treated
            as an empty object, so the defaults above are returned.

    Returns:
        A dict with exactly the four schema keys.
    """
    if not isinstance(parsed, dict):
        # Valid JSON that is not an object carries no usable fields.
        parsed = {}

    cleaned: Dict[str, Any] = {
        "content_categories": [],
        "risk_level": "L1",
        "sensitive_entities": [],
        "reasoning": {"category_reason": "", "risk_reason": ""},
    }

    cats = parsed.get("content_categories", [])
    if isinstance(cats, list):
        seen = set()
        for cat in cats:
            if isinstance(cat, str) and cat in VALID_CATEGORIES and cat not in seen:
                cleaned["content_categories"].append(cat)
                seen.add(cat)
    if not cleaned["content_categories"]:
        cleaned["content_categories"] = ["INTERNAL"]

    risk = parsed.get("risk_level", "")
    if isinstance(risk, str) and risk in VALID_RISKS:
        cleaned["risk_level"] = risk

    entities = parsed.get("sensitive_entities", [])
    if isinstance(entities, list):
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            etype = entity.get("type", "")
            value = entity.get("value", "")
            reason = entity.get("reason", "")
            if not (isinstance(etype, str) and etype in VALID_ENTITY_TYPES):
                continue
            if not (isinstance(value, str) and value):
                continue
            cleaned["sensitive_entities"].append({
                "type": etype,
                "value": value,
                # A JSON null reason becomes "" rather than the text "None".
                "reason": "" if reason is None else str(reason),
            })

    reasoning = parsed.get("reasoning", {})
    if isinstance(reasoning, dict):
        for key in ("category_reason", "risk_reason"):
            text = reasoning.get(key, "")
            if isinstance(text, str):
                cleaned["reasoning"][key] = text

    return cleaned
# =============================================================================
# Filter 2 — Provenance check
# =============================================================================
def _normalize(s: str) -> str:
return s.replace(" ", "").replace(" ", "").replace("\t", "")
def value_in_text(value: str, text: str) -> bool:
if not value:
return False
if value in text:
return True
return _normalize(value) in _normalize(text)
def filter_provenance(parsed: Dict[str, Any], input_text: str) -> Dict[str, Any]:
    """Drop sensitive entities whose value does not occur in the input text.

    Args:
        parsed: Result dict holding a ``"sensitive_entities"`` list. The dict
            is updated in place.
        input_text: Text that was analysed.

    Returns:
        The same ``parsed`` dict, with ``"sensitive_entities"`` replaced by
        the entries that passed ``value_in_text``. Entries that are not dicts
        or whose ``value`` is not a string are dropped as well; a missing or
        non-list ``"sensitive_entities"`` becomes an empty list.
    """
    entities = parsed.get("sensitive_entities", [])
    if not isinstance(entities, list):
        entities = []
    parsed["sensitive_entities"] = [
        entity for entity in entities
        if isinstance(entity, dict)
        and isinstance(entity.get("value"), str)
        and value_in_text(entity["value"], input_text)
    ]
    return parsed
# =============================================================================
# Detector
# =============================================================================
class Detector:
    """Runs the base causal LM (optionally with a LoRA adapter) on one text
    and converts the generated output into the structured result dict.
    """

    def __init__(self,
                 base_model_path: str = DEFAULT_BASE_MODEL,
                 adapter_path: Optional[str] = DEFAULT_ADAPTER,
                 use_adapter: bool = True):
        """Load tokenizer, base model and, if requested, the LoRA adapter.

        Args:
            base_model_path: Identifier or path given to ``from_pretrained``
                for both the tokenizer and the base model.
            adapter_path: Identifier or path of the LoRA adapter.
            use_adapter: The adapter is applied only when this is true and
                ``adapter_path`` is non-empty.
        """
        print(f"[Detector] Loading tokenizer from {base_model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        if self.tokenizer.pad_token is None:
            # generate() is given pad_token_id, so make sure one exists.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("[Detector] Loading base model (fp32 CPU)...")
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.float32,
            device_map="cpu",
        )

        if use_adapter and adapter_path:
            print(f"[Detector] Applying LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(self.model, adapter_path)

        self.model.eval()
        print("[Detector] Ready.")

    def generate_raw(self, text: str) -> str:
        """Generate the model's raw reply for *text*.

        The chat template is rendered with ``SYSTEM_PROMPT`` as the system
        message and *text* as the user message; decoding is greedy
        (``do_sample=False``, ``num_beams=1``) and limited to
        ``MAX_NEW_TOKENS`` new tokens.

        Returns:
            The decoded newly generated tokens, without special tokens.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text},
        ]
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        # add_special_tokens=False: the prompt was already rendered by the
        # chat template above.
        inputs = self.tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=False,
        )
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        # The output sequence starts with the prompt; keep only what follows.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    @staticmethod
    def _parse_failed_result(apply_filters: bool) -> Dict[str, Any]:
        """Build the placeholder result returned when no JSON object was found."""
        return {
            "content_categories": ["INTERNAL"] if apply_filters else [],
            "risk_level": "L1" if apply_filters else "",
            "sensitive_entities": [],
            "reasoning": {
                "category_reason": "(JSON parse failed)",
                "risk_reason": "(JSON parse failed)",
            },
            "_parse_failed": True,
        }

    @staticmethod
    def post_process(raw_output: str, input_text: str,
                     apply_filters: bool = True) -> Dict[str, Any]:
        """Turn raw model output into the result dict.

        Args:
            raw_output: Text generated by the model.
            input_text: Text that was analysed (used by the provenance filter).
            apply_filters: When true, run ``filter_schema`` and then
                ``filter_provenance``. When false ("lenient mode"), copy each
                top-level field as-is if it has the expected container type
                and substitute an empty value otherwise.

        Returns:
            The result dict. When the output contains no JSON object, a
            placeholder result carrying ``"_parse_failed": True`` is returned.
        """
        parsed = extract_json(raw_output)
        # Valid JSON that is not an object (e.g. a list) has none of the
        # expected fields, so it is handled like unparseable output.
        if not isinstance(parsed, dict):
            return Detector._parse_failed_result(apply_filters)

        if apply_filters:
            return filter_provenance(filter_schema(parsed), input_text)

        # Lenient mode
        categories = parsed.get("content_categories")
        risk = parsed.get("risk_level")
        entities = parsed.get("sensitive_entities")
        reasoning = parsed.get("reasoning")
        return {
            "content_categories": categories if isinstance(categories, list) else [],
            "risk_level": risk if isinstance(risk, str) else "",
            "sensitive_entities": entities if isinstance(entities, list) else [],
            "reasoning": (
                reasoning if isinstance(reasoning, dict)
                else {"category_reason": "", "risk_reason": ""}
            ),
        }

    def detect(self, text: str, apply_filters: bool = True) -> Dict[str, Any]:
        """Generate for *text* and post-process the output in one call."""
        raw = self.generate_raw(text)
        return self.post_process(raw, text, apply_filters=apply_filters)