"""IC Design Content Firewall (ZH) — inference module.

Loads the LoRA adapter from the Hugging Face Hub on top of
Qwen2.5-1.5B-Instruct and produces structured JSON for IC-design content
moderation.

Output schema::

    {
        "content_categories": [...],              # multi-label, 9 types
        "risk_level": "L0" | "L1" | "L2" | "L3",
        "sensitive_entities": [{"type", "value", "reason"}, ...],
        "reasoning": {"category_reason", "risk_reason"},
    }

Two model-agnostic safeguards are applied to every output:

* Filter 1 - schema validation: keep only legal enum values.
* Filter 2 - provenance check: keep only entities whose ``value`` actually
  appears in the input text (guards against hallucinated entities).

The parsing and filter helpers are pure Python; only :class:`Detector`
needs ``torch``, ``peft`` and ``transformers``.
"""

import json
import re
from typing import Any, Dict, List, Optional

# The ML stack is only needed by ``Detector``. Guarding the import keeps the
# pure-Python parsing/filter helpers importable (e.g. to post-process stored
# raw outputs) where these packages are not installed; ``Detector`` re-raises.
_ML_IMPORT_ERROR: Optional[ImportError] = None
try:
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError as exc:
    torch = None
    PeftModel = None
    AutoModelForCausalLM = None
    AutoTokenizer = None
    _ML_IMPORT_ERROR = exc

# =============================================================================
# Default model locations (Hugging Face Hub)
# =============================================================================
DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_ADAPTER = "GOSHUNCLE/ic-firewall-zh"
MAX_NEW_TOKENS = 512

# =============================================================================
# System prompt
# =============================================================================
# NOTE(review): the reviewed copy of this file had its whitespace collapsed,
# so the line breaks inside this prompt were reconstructed. The prompt must be
# byte-identical to the one used to fine-tune the LoRA adapter — verify it
# against the training script before deploying.
SYSTEM_PROMPT = """你是 IC 設計業內容防火牆分析專家。請分析輸入文字,輸出嚴格 JSON。

任務:
1. content_categories:從以下 9 類中選擇相關的(multi-label):
RTL, CUSTOMER, QUOTE, VENDOR, PROCESS, SCHEDULE, TESTING, INTERNAL, PUBLIC
2. risk_level:選擇 L0/L1/L2/L3 之一
3. sensitive_entities:標註所有機敏實體(不取代,僅標註)
每筆含 type、value(必須逐字在原文)、reason(10-20 字)
4. reasoning:含 category_reason 與 risk_reason(各 30-60 字)

實體類型:
CUSTOMER, PROJECT, VENDOR, PRICE, PROCESS_NODE, MODULE_NAME, IP_BLOCK, YIELD_NUMBER, SPEC_PARAM

輸出格式:
{"content_categories":[...],"risk_level":"L?","sensitive_entities":[{"type":"...","value":"...","reason":"..."}],"reasoning":{"category_reason":"...","risk_reason":"..."}}

規則:
1. value 必須逐字出現在原文
2. 不要保險全選類別,只標 actually 相關的
3. 只回傳 JSON,無其他說明文字"""

# =============================================================================
# Allowed enum values
# =============================================================================
VALID_CATEGORIES = {
    "RTL", "CUSTOMER", "QUOTE", "VENDOR", "PROCESS",
    "SCHEDULE", "TESTING", "INTERNAL", "PUBLIC",
}
VALID_RISKS = {"L0", "L1", "L2", "L3"}
VALID_ENTITY_TYPES = {
    "CUSTOMER", "PROJECT", "VENDOR", "PRICE", "PROCESS_NODE",
    "MODULE_NAME", "IP_BLOCK", "YIELD_NUMBER", "SPEC_PARAM",
}

# Values substituted when the model output is missing/invalid for a field.
_FALLBACK_CATEGORY = "INTERNAL"
_FALLBACK_RISK = "L1"

# =============================================================================
# JSON parsing helpers
# =============================================================================
_JSON_DECODER = json.JSONDecoder()


def extract_json(raw_text: str) -> Optional[Dict[str, Any]]:
    """Extract the JSON object from a raw model completion.

    Tries, in order:

    1. the whole (stripped) text as JSON;
    2. the first complete JSON value starting at the first ``{`` — this
       tolerates leading chatter, Markdown code fences and trailing text
       (even trailing text that itself contains braces).

    Returns:
        The parsed object, or ``None`` if no JSON *object* could be parsed.
        Non-object JSON (lists, numbers, strings) is rejected because every
        caller indexes the result as a dict.
    """
    text = raw_text.strip()

    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    start = text.find("{")
    if start < 0:
        return None
    try:
        parsed, _end = _JSON_DECODER.raw_decode(text, start)
    except (ValueError, RecursionError):
        return None
    return parsed if isinstance(parsed, dict) else None


# =============================================================================
# Filter 1 — Schema validation
# =============================================================================
def filter_schema(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce a parsed model output into the strict output schema (Filter 1).

    Builds and returns a new dict; ``parsed`` is not modified.

    * ``content_categories``: legal categories only, de-duplicated, original
      order kept; falls back to ``["INTERNAL"]`` when none survive.
    * ``risk_level``: kept if legal, otherwise ``"L1"``.
    * ``sensitive_entities``: entries must be dicts with a legal ``type`` and
      a non-blank string ``value``; exact ``(type, value)`` duplicates are
      dropped. A missing or ``None`` reason becomes ``""``; other non-string
      reasons are converted with ``str``.
    * ``reasoning``: string fields kept; anything else becomes ``""``.
    """
    categories: List[str] = []
    raw_categories = parsed.get("content_categories", [])
    if isinstance(raw_categories, list):
        seen_categories = set()
        for cat in raw_categories:
            if isinstance(cat, str) and cat in VALID_CATEGORIES and cat not in seen_categories:
                categories.append(cat)
                seen_categories.add(cat)
    if not categories:
        categories = [_FALLBACK_CATEGORY]

    risk = parsed.get("risk_level", "")
    risk_level = risk if isinstance(risk, str) and risk in VALID_RISKS else _FALLBACK_RISK

    entities: List[Dict[str, str]] = []
    raw_entities = parsed.get("sensitive_entities", [])
    if isinstance(raw_entities, list):
        seen_entities = set()
        for ent in raw_entities:
            if not isinstance(ent, dict):
                continue
            etype = ent.get("type", "")
            value = ent.get("value", "")
            reason = ent.get("reason", "")
            if not (isinstance(etype, str) and etype in VALID_ENTITY_TYPES):
                continue
            # Blank values carry no information; reject them outright.
            if not (isinstance(value, str) and value.strip()):
                continue
            key = (etype, value)
            if key in seen_entities:
                continue
            seen_entities.add(key)
            entities.append({
                "type": etype,
                "value": value,
                "reason": "" if reason is None else str(reason),
            })

    reasoning = {"category_reason": "", "risk_reason": ""}
    raw_reasoning = parsed.get("reasoning", {})
    if isinstance(raw_reasoning, dict):
        for field in reasoning:
            field_value = raw_reasoning.get(field, "")
            if isinstance(field_value, str):
                reasoning[field] = field_value

    return {
        "content_categories": categories,
        "risk_level": risk_level,
        "sensitive_entities": entities,
        "reasoning": reasoning,
    }


# =============================================================================
# Filter 2 — Provenance check
# =============================================================================
# ``\s`` on str patterns matches all Unicode whitespace, including newlines
# and the full-width ideographic space U+3000 common in Chinese text.
_WHITESPACE_RE = re.compile(r"\s+")


def _normalize(s: str) -> str:
    """Return ``s`` with every whitespace character removed."""
    return _WHITESPACE_RE.sub("", s)


def value_in_text(value: str, text: str) -> bool:
    """Return True if entity ``value`` occurs in ``text``.

    An exact substring match is tried first; otherwise both sides are compared
    with all whitespace removed, so ``"N 3"`` matches ``"N3"`` and a value
    broken by a newline or full-width space still matches. Empty,
    whitespace-only and non-string values never match.
    """
    if not isinstance(value, str) or not value.strip():
        return False
    if value in text:
        return True
    return _normalize(value) in _normalize(text)


def filter_provenance(parsed: Dict[str, Any], input_text: str) -> Dict[str, Any]:
    """Drop entities whose value does not appear in ``input_text`` (Filter 2).

    ``parsed["sensitive_entities"]`` is replaced in place, and ``parsed`` is
    also returned for chaining. Non-dict entries and entries without a string
    ``value`` are dropped too.
    """
    entities = parsed.get("sensitive_entities", [])
    if not isinstance(entities, list):
        entities = []
    parsed["sensitive_entities"] = [
        ent for ent in entities
        if isinstance(ent, dict) and value_in_text(ent.get("value", ""), input_text)
    ]
    return parsed


# =============================================================================
# Post-processing helpers
# =============================================================================
def _parse_failure_result(apply_filters: bool) -> Dict[str, Any]:
    """Result returned when no JSON object could be parsed from the output."""
    return {
        "content_categories": [_FALLBACK_CATEGORY] if apply_filters else [],
        "risk_level": _FALLBACK_RISK if apply_filters else "",
        "sensitive_entities": [],
        "reasoning": {
            "category_reason": "(JSON parse failed)",
            "risk_reason": "(JSON parse failed)",
        },
        "_parse_failed": True,
    }


def _lenient_result(parsed: Dict[str, Any]) -> Dict[str, Any]:
    """Type-check top-level fields only, without validating their contents."""
    def pick(key: str, expected_type: type, default: Any) -> Any:
        field_value = parsed.get(key)
        return field_value if isinstance(field_value, expected_type) else default

    return {
        "content_categories": pick("content_categories", list, []),
        "risk_level": pick("risk_level", str, ""),
        "sensitive_entities": pick("sensitive_entities", list, []),
        "reasoning": pick("reasoning", dict, {"category_reason": "", "risk_reason": ""}),
    }


# =============================================================================
# Detector
# =============================================================================
class Detector:
    """Qwen2.5 (+ optional LoRA adapter) content-firewall detector.

    Runs on CPU in fp32. Public entry points:

    * :meth:`generate_raw`: raw model completion for one input text;
    * :meth:`post_process`: parse and (optionally) filter a raw completion;
    * :meth:`detect`: both of the above.
    """

    def __init__(
        self,
        base_model_path: str = DEFAULT_BASE_MODEL,
        adapter_path: Optional[str] = DEFAULT_ADAPTER,
        use_adapter: bool = True,
    ):
        """Load the tokenizer, the base model and (optionally) the adapter.

        Args:
            base_model_path: HF Hub id or local path of the base model.
            adapter_path: HF Hub id or local path of the LoRA adapter; a falsy
                value skips the adapter.
            use_adapter: set to False to run the bare base model.

        Raises:
            ImportError: if torch, peft or transformers is not installed.
        """
        if _ML_IMPORT_ERROR is not None:
            raise ImportError(
                "Detector requires torch, peft and transformers to be installed"
            ) from _ML_IMPORT_ERROR

        print(f"[Detector] Loading tokenizer from {base_model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
        if self.tokenizer.pad_token is None:
            # generate() is given pad_token_id below; reuse EOS if none exists.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("[Detector] Loading base model (fp32 CPU)...")
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        if use_adapter and adapter_path:
            print(f"[Detector] Applying LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.model.eval()
        print("[Detector] Ready.")

    def generate_raw(self, text: str, max_new_tokens: Optional[int] = None) -> str:
        """Greedy-decode a completion for ``text`` and return only the new text.

        The input is wrapped in the chat template as
        ``[system: SYSTEM_PROMPT, user: text]`` with a generation prompt.

        Args:
            text: content to analyse.
            max_new_tokens: generation budget; ``None`` uses the module-level
                ``MAX_NEW_TOKENS`` (read at call time).

        Returns:
            The decoded completion, with prompt tokens and special tokens
            removed.
        """
        if max_new_tokens is None:
            max_new_tokens = MAX_NEW_TOKENS

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text},
        ]
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        # Special tokens come from the rendered chat template; don't let the
        # tokenizer add a second set.
        inputs = self.tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    @staticmethod
    def post_process(raw_output: str, input_text: str,
                     apply_filters: bool = True) -> Dict[str, Any]:
        """Turn a raw completion into the output schema.

        Args:
            raw_output: text returned by :meth:`generate_raw`.
            input_text: the analysed text (used by the provenance filter).
            apply_filters: True runs Filter 1 (schema) and Filter 2
                (provenance); False only type-checks the top-level fields.

        Returns:
            The result dict. If no JSON object can be parsed, a fallback
            result carrying ``"_parse_failed": True`` is returned.
        """
        parsed = extract_json(raw_output)
        if parsed is None:
            return _parse_failure_result(apply_filters)

        if not apply_filters:
            return _lenient_result(parsed)

        result = filter_schema(parsed)
        return filter_provenance(result, input_text)

    def detect(self, text: str, apply_filters: bool = True) -> Dict[str, Any]:
        """Generate a completion for ``text`` and post-process it."""
        raw = self.generate_raw(text)
        return self.post_process(raw, text, apply_filters=apply_filters)