File size: 4,701 Bytes
db60e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
训练脚本:使用 HuggingFace Trainer 微调 BERT 进行敏感词二分类

使用 BertForSequenceClassification + Trainer API,
支持自动混合精度、梯度累积、学习率调度等。
"""

import os
import json
import argparse
import numpy as np
import torch

from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


def compute_metrics(eval_pred):
    """Compute binary-classification metrics for the HuggingFace Trainer.

    Args:
        eval_pred: The ``(predictions, label_ids)`` pair the Trainer passes to
            ``compute_metrics``. ``predictions`` are per-class logits; if the
            model returns extra outputs they arrive as a tuple whose first
            element is the logits.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.
        Precision/recall/F1 use ``average="binary"`` (positive class = label 1).
        When a score is undefined (e.g. the model predicts no positives in an
        evaluation round) it is reported as 0.0 instead of emitting an
        ``UndefinedMetricWarning`` on every evaluation.
    """
    logits, labels = eval_pred
    # Models configured to return extra outputs hand back (logits, ...); keep logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Predicted class = index of the highest logit.
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, average="binary", zero_division=0),
        "recall": recall_score(labels, preds, average="binary", zero_division=0),
        "f1": f1_score(labels, preds, average="binary", zero_division=0),
    }


def _parse_args():
    """Build the command-line parser for training and return the parsed args."""
    parser = argparse.ArgumentParser(description="训练敏感词过滤模型")
    parser.add_argument("--model_name", type=str, default="bert-base-chinese")
    parser.add_argument("--train_file", type=str, default="data/train.csv")
    parser.add_argument("--val_file", type=str, default="data/val.csv")
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="梯度累积步数",
    )
    # fp16 and bf16 cannot both be enabled, so expose them as exclusive flags.
    # Both default to off, matching the previous hard-coded behaviour.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument("--fp16", action="store_true", help="启用 FP16 混合精度")
    precision.add_argument("--bf16", action="store_true", help="启用 BF16 混合精度")
    return parser.parse_args()


def _format_metric(metrics, key):
    """Return ``metrics[key]`` formatted to 4 decimals, or "N/A" if absent/non-numeric."""
    value = metrics.get(key)
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        # None -> TypeError, a string -> ValueError: neither supports ".4f".
        return "N/A"


def _save_filter_config(model_dir, args):
    """Write ``filter_config.json`` (model name, max length, labels) into *model_dir*."""
    config = {
        "model_name": args.model_name,
        "max_length": args.max_length,
        "num_labels": 2,
        "label_map": {"0": "正常", "1": "敏感"},
    }
    with open(os.path.join(model_dir, "filter_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)


def main():
    """Fine-tune a BERT sequence classifier for binary sensitive-text filtering.

    Steps:
        1. Parse CLI arguments.
        2. Load the train/validation CSV files and tokenize their ``text``
           column; the ``label`` column is renamed to ``labels``.
        3. Train with ``Trainer``, evaluating and checkpointing once per epoch.
        4. Save the best checkpoint (by F1), the tokenizer and a
           ``filter_config.json`` to ``<output_dir>/best_model``.
        5. Run a final evaluation and print the results.
    """
    args = _parse_args()

    # Load tokenizer
    print(f"加载 tokenizer: {args.model_name}")
    tokenizer = BertTokenizer.from_pretrained(args.model_name)

    # Load datasets (CSV files; expected to contain "text" and "label" columns)
    print("加载数据集...")
    dataset = load_dataset(
        "csv",
        data_files={"train": args.train_file, "validation": args.val_file},
    )
    print(f"训练集: {len(dataset['train'])} 条")
    print(f"验证集: {len(dataset['validation'])} 条")

    # Tokenize: fixed-length padding/truncation to max_length
    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=args.max_length,
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    # Trainer/BertForSequenceClassification expect the target column to be "labels".
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")

    # Load model with a 2-way classification head
    print(f"加载模型: {args.model_name}")
    model = BertForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
    )

    # Training arguments
    best_model_dir = os.path.join(args.output_dir, "best_model")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Reload the checkpoint with the highest validation F1 once training ends.
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        fp16=args.fp16,
        bf16=args.bf16,
        save_total_limit=2,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
    )

    # Train
    print(f"\n{'='*60}")
    print("开始训练")
    print(f"Epochs: {args.epochs}, Batch Size: {args.batch_size}, LR: {args.lr}")
    print(f"{'='*60}\n")

    trainer.train()

    # Save the best model (already reloaded because load_best_model_at_end=True)
    print(f"\n保存最佳模型至: {best_model_dir}")
    trainer.save_model(best_model_dir)
    tokenizer.save_pretrained(best_model_dir)

    _save_filter_config(best_model_dir, args)

    # Final evaluation on the validation set
    final_metrics = trainer.evaluate()
    print(f"\n{'='*60}")
    print(f"训练完成!")
    print(f"F1: {_format_metric(final_metrics, 'eval_f1')}")
    print(f"Accuracy: {_format_metric(final_metrics, 'eval_accuracy')}")
    print(f"模型已保存至: {best_model_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()