"""
Inference script: use a trained model to predict whether text contains sensitive content.
"""

import os
import json
import argparse
from typing import Optional

import torch
from transformers import BertTokenizer, BertForSequenceClassification


class SensitiveWordPredictor:
    """Sensitive-content predictor."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the filter configuration stored alongside the model
        config_path = os.path.join(model_path, "filter_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

        # JSON object keys are strings, so label ids are looked up as str(label)
        self.label_map = self.config["label_map"]
        self.max_length = self.config.get("max_length", 128)

        # Load the model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Predict a single text."""
        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        # Inference only: no gradients are needed
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.softmax(outputs.logits, dim=1)
            pred_label = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_label].item()

        return {
            "text": text,
            "label": pred_label,
            "label_name": self.label_map[str(pred_label)],
            "confidence": round(confidence, 4),
            "is_sensitive": pred_label == 1,
        }

    def predict_batch(self, texts: list[str], batch_size: int = 32) -> list[dict]:
        """Predict a list of texts in batches."""
        results = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]

            encoding = self.tokenizer(
                batch_texts,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            input_ids = encoding["input_ids"].to(self.device)
            attention_mask = encoding["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                probs = torch.softmax(outputs.logits, dim=1)
                pred_labels = torch.argmax(probs, dim=1)

            for j, text in enumerate(batch_texts):
                label = pred_labels[j].item()
                conf = probs[j][label].item()
                results.append({
                    "text": text,
                    "label": label,
                    "label_name": self.label_map[str(label)],
                    "confidence": round(conf, 4),
                    "is_sensitive": label == 1,
                })

        return results


def main():
    parser = argparse.ArgumentParser(description="Sensitive-content prediction")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--text", type=str, help="Text to check (single item)")
    parser.add_argument("--file", type=str, help="Text file to check (one item per line)")
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    predictor = SensitiveWordPredictor(args.model_path, args.device)

    if args.text:
        # Single-text mode
        result = predictor.predict(args.text)
        status = "🔴 Sensitive" if result["is_sensitive"] else "🟢 Normal"
        print(f"\nInput: {result['text']}")
        print(f"Result: {status}")
        print(f"Label: {result['label_name']} (label={result['label']})")
        print(f"Confidence: {result['confidence']:.4f}")

    elif args.file:
        # Batch mode: one text per non-empty line
        with open(args.file, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]

        results = predictor.predict_batch(texts)

        print(f"\n{'='*70}")
        print(f"Batch results ({len(results)} items)")
        print(f"{'='*70}")

        sensitive_count = 0
        for r in results:
            status = "🔴 Sensitive" if r["is_sensitive"] else "🟢 Normal"
            # Append an ellipsis only when the text was actually truncated
            snippet = r["text"][:50] + ("..." if len(r["text"]) > 50 else "")
            print(f"  {status} [{r['confidence']:.4f}] {snippet}")
            if r["is_sensitive"]:
                sensitive_count += 1

        print(f"\nSummary: normal {len(results) - sensitive_count}, sensitive {sensitive_count}")

    else:
        # Interactive mode
        print("Sensitive-content detection system - interactive mode")
        print("Enter text to check; type 'quit' to exit\n")

        while True:
            try:
                text = input("Enter text> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Exit cleanly on Ctrl-D / Ctrl-C instead of printing a traceback
                print("\nGoodbye!")
                break
            if text.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break
            if not text:
                continue

            result = predictor.predict(text)
            status = "🔴 Sensitive" if result["is_sensitive"] else "🟢 Normal"
            print(f"  Result: {status} | Confidence: {result['confidence']:.4f}\n")


if __name__ == "__main__":
    main()