File size: 5,427 Bytes
db60e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
推理脚本:使用训练好的模型预测文本是否包含敏感内容
"""

import os
import json
import argparse

import torch
from transformers import BertTokenizer, BertForSequenceClassification


class SensitiveWordPredictor:
    """Classify text as sensitive or normal with a fine-tuned BERT classifier.

    The predictor loads ``filter_config.json`` together with the tokenizer and
    model stored in the same directory, and exposes single-text and batched
    prediction. Every prediction is a dict with the keys ``text``, ``label``,
    ``label_name``, ``confidence`` and ``is_sensitive``.
    """

    # Class index that is reported as "sensitive" in the ``is_sensitive`` field.
    SENSITIVE_LABEL = 1

    def __init__(self, model_path: str, device: str = None):
        """Load config, tokenizer and model from ``model_path``.

        Args:
            model_path: Directory holding ``filter_config.json`` and the files
                that ``from_pretrained`` reads for the tokenizer and the model.
            device: Torch device string such as ``"cpu"`` or ``"cuda"``. When
                omitted (or empty), CUDA is used if available, else the CPU.

        Raises:
            OSError: If ``filter_config.json`` cannot be opened.
            KeyError: If the config has no ``"label_map"`` entry.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the filter configuration written next to the model weights.
        config_path = os.path.join(model_path, "filter_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

        # label_map is indexed with str(label) below, i.e. keys are strings
        # (as JSON object keys always are).
        self.label_map = self.config["label_map"]
        self.max_length = self.config.get("max_length", 128)

        # Load the model and tokenizer; eval() switches off dropout and similar
        # training-only behaviour.
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def _classify(self, batch_texts: list[str]) -> list[dict]:
        """Run one forward pass over ``batch_texts``; return one dict per text."""
        encoding = self.tokenizer(
            batch_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.softmax(outputs.logits, dim=1)
            labels = torch.argmax(probs, dim=1)
            # Probability of the predicted class for every row.
            confidences = probs.gather(1, labels.unsqueeze(1)).squeeze(1)

        # Convert once per batch instead of calling .item() twice per row.
        return [
            {
                "text": text,
                "label": label,
                "label_name": self.label_map[str(label)],
                "confidence": round(conf, 4),
                "is_sensitive": label == self.SENSITIVE_LABEL,
            }
            for text, label, conf in zip(
                batch_texts, labels.tolist(), confidences.tolist()
            )
        ]

    def predict(self, text: str) -> dict:
        """Predict a single text.

        Args:
            text: The text to classify.

        Returns:
            A dict with ``text``, ``label`` (predicted class index),
            ``label_name`` (looked up in the config's label map),
            ``confidence`` (softmax probability of the predicted class,
            rounded to 4 decimals) and ``is_sensitive``.

        Raises:
            KeyError: If the predicted label is missing from the label map.
        """
        return self._classify([text])[0]

    def predict_batch(self, texts: list[str], batch_size: int = 32) -> list[dict]:
        """Predict many texts, ``batch_size`` at a time.

        Args:
            texts: Texts to classify. A bare ``str`` is treated as one text
                rather than being iterated character by character.
            batch_size: Number of texts per forward pass; must be >= 1.

        Returns:
            One result dict per input text, in input order (same layout as
            ``predict``). An empty input yields an empty list.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            KeyError: If a predicted label is missing from the label map.
        """
        if batch_size < 1:
            # A negative step would make range() empty and silently drop every
            # input; zero would raise an unrelated-looking range() error.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        texts = [texts] if isinstance(texts, str) else list(texts)

        results = []
        for start in range(0, len(texts), batch_size):
            results.extend(self._classify(texts[start : start + batch_size]))
        return results


def _status_tag(is_sensitive: bool) -> str:
    """Return the console marker shown for a sensitive / normal verdict."""
    return "🔴 敏感" if is_sensitive else "🟢 正常"


def _preview(text: str, limit: int = 50) -> str:
    """Return ``text`` cut to ``limit`` characters, adding "..." only if cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def main():
    """Command-line entry point for the sensitive-text detector.

    Three modes, chosen from the arguments:

    * ``--text``: classify one text and print a detailed report.
    * ``--file``: classify every non-blank line of a UTF-8 file and print a
      per-line verdict followed by summary counts.
    * neither: interactive prompt that classifies each entered line until the
      user types ``quit``/``exit``/``q`` or closes the input stream.

    ``--text`` takes precedence when both ``--text`` and ``--file`` are given.
    """
    parser = argparse.ArgumentParser(description="敏感词预测")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--text", type=str, help="要检测的文本(单条)")
    parser.add_argument("--file", type=str, help="要检测的文本文件(每行一条)")
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    predictor = SensitiveWordPredictor(args.model_path, args.device)

    if args.text:
        result = predictor.predict(args.text)
        print(f"\n输入: {result['text']}")
        print(f"结果: {_status_tag(result['is_sensitive'])}")
        print(f"标签: {result['label_name']} (label={result['label']})")
        print(f"置信度: {result['confidence']:.4f}")

    elif args.file:
        # One text per line; blank lines are skipped.
        with open(args.file, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]

        results = predictor.predict_batch(texts)
        print(f"\n{'='*70}")
        print(f"批量检测结果 (共 {len(results)} 条)")
        print(f"{'='*70}")

        for r in results:
            status = _status_tag(r["is_sensitive"])
            # Only mark the text as truncated when it actually was.
            print(f"  {status} [{r['confidence']:.4f}] {_preview(r['text'])}")

        sensitive_count = sum(1 for r in results if r["is_sensitive"])
        print(f"\n统计: 正常 {len(results) - sensitive_count} 条, 敏感 {sensitive_count} 条")

    else:
        print("敏感词检测系统 - 交互模式")
        print("输入文本进行检测,输入 'quit' 退出\n")
        while True:
            try:
                text = input("请输入文本> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C or exhausted piped input: leave the prompt
                # line and exit like an explicit "quit" instead of crashing.
                print()
                print("再见!")
                break
            if text.lower() in ("quit", "exit", "q"):
                print("再见!")
                break
            if not text:
                continue
            result = predictor.predict(text)
            status = _status_tag(result["is_sensitive"])
            print(f"  结果: {status} | 置信度: {result['confidence']:.4f}\n")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()