File size: 3,864 Bytes
db60e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
评估脚本:在测试集上评估模型性能
"""

import os
import json
import argparse

import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)


def compute_metrics(eval_pred):
    """Compute accuracy and binary-averaged precision/recall/F1.

    Args:
        eval_pred: a pair that unpacks to ``(logits, labels)``.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and
        ``f1``, in that order.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predicted = np.argmax(logits, axis=-1)

    metrics = {"accuracy": accuracy_score(labels, predicted)}
    binary_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for metric_name, scorer in binary_scorers:
        metrics[metric_name] = scorer(labels, predicted, average="binary")
    return metrics


def main():
    """Evaluate a saved BERT sequence classifier on a CSV test set.

    Command-line arguments:
        --model_path: directory containing the model, the tokenizer and
            ``filter_config.json`` (default ``output/best_model``).
        --test_file: CSV file with ``text`` and ``label`` columns
            (default ``data/test.csv``).
        --batch_size: per-device evaluation batch size (default 64).

    Prints accuracy, precision, recall, F1, the confusion matrix and a
    classification report, then writes the scalar metrics and the confusion
    matrix to ``eval_results.json`` in the parent directory of
    ``--model_path``.

    Raises:
        FileNotFoundError: if ``filter_config.json`` is missing from the
            model directory.
        KeyError: if the config has no ``label_map`` entry for ``"0"`` or
            ``"1"``.
    """
    parser = argparse.ArgumentParser(description="评估敏感词过滤模型")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--test_file", type=str, default="data/test.csv")
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    # Load the filter configuration stored in the model directory.
    config_path = os.path.join(args.model_path, "filter_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    max_length = config.get("max_length", 128)
    label_names = [config["label_map"]["0"], config["label_map"]["1"]]
    # Pin the class ids so the confusion matrix is always 2x2 and the report
    # always has two rows, even if the test set or the predictions contain
    # only one of the classes.
    class_ids = [0, 1]

    # Load the model and tokenizer.
    print(f"加载模型: {args.model_path}")
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    model = BertForSequenceClassification.from_pretrained(args.model_path)

    # Load the test set.
    dataset = load_dataset("csv", data_files={"test": args.test_file})

    def tokenize_fn(examples):
        """Tokenize a batch of texts, padded/truncated to ``max_length``."""
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")
    print(f"测试集: {len(dataset['test'])} 条")

    # Build a Trainer that is used for inference only.
    training_args = TrainingArguments(
        output_dir="/tmp/eval_output",
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, compute_metrics=compute_metrics)

    # Predict on the test split.
    predictions = trainer.predict(dataset["test"])
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids

    # Cast to plain floats so the JSON output never depends on NumPy scalars.
    acc = float(accuracy_score(labels, preds))
    precision = float(precision_score(labels, preds, average="binary"))
    recall = float(recall_score(labels, preds, average="binary"))
    f1 = float(f1_score(labels, preds, average="binary"))
    cm = confusion_matrix(labels, preds, labels=class_ids)

    print(f"\n{'='*60}")
    print("模型评估结果")
    print(f"{'='*60}")
    print(f"准确率 (Accuracy):  {acc:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"召回率 (Recall):    {recall:.4f}")
    print(f"F1 值 (F1-Score):   {f1:.4f}")

    print("\n--- 混淆矩阵 ---")
    print(f"{'':>12} 预测正常  预测敏感")
    print(f"{'实际正常':>10}  {cm[0][0]:>6}  {cm[0][1]:>6}")
    print(f"{'实际敏感':>10}  {cm[1][0]:>6}  {cm[1][1]:>6}")

    print("\n--- 分类报告 ---")
    report = classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=label_names,
        digits=4,
    )
    print(report)

    # Save the results.
    results = {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm.tolist(),
    }
    # Normalize first so a trailing separator ("output/best_model/") still
    # resolves to the model's parent directory rather than the model itself.
    parent_dir = os.path.dirname(os.path.normpath(args.model_path))
    output_path = os.path.join(parent_dir, "eval_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\n评估结果已保存至: {output_path}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()