# crackrammer — "Upload folder using huggingface_hub" (commit db60e24, verified)
"""
训练脚本:使用 HuggingFace Trainer 微调 BERT 进行敏感词二分类
使用 BertForSequenceClassification + Trainer API,
支持自动混合精度、梯度累积、学习率调度等。
"""
import os
import json
import argparse
import numpy as np
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
TrainingArguments,
Trainer,
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def compute_metrics(eval_pred):
    """Compute binary-classification metrics for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair; the predicted class for each
            row is the argmax of ``logits`` over the last axis.

    Returns:
        A dict with the keys ``"accuracy"``, ``"precision"``, ``"recall"``
        and ``"f1"``. Precision, recall and F1 use ``average="binary"``,
        i.e. they are reported for the positive class only.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Precision/recall/F1 are ill-defined when an evaluation pass contains no
    # positive predictions (or no positive labels). Report 0 explicitly in
    # that case instead of relying on the library's warn-and-return-0 default,
    # which emits a warning on every such evaluation.
    binary_kwargs = {"average": "binary", "zero_division": 0}

    return {
        "accuracy": float(accuracy_score(labels, preds)),
        "precision": float(precision_score(labels, preds, **binary_kwargs)),
        "recall": float(recall_score(labels, preds, **binary_kwargs)),
        "f1": float(f1_score(labels, preds, **binary_kwargs)),
    }
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="训练敏感词过滤模型")
    parser.add_argument("--model_name", type=str, default="bert-base-chinese")
    parser.add_argument("--train_file", type=str, default="data/train.csv")
    parser.add_argument("--val_file", type=str, default="data/val.csv")
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # Opt-in extras. The defaults reproduce the previously hard-coded
    # behaviour (no accumulation, full precision).
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    return parser.parse_args()


def _format_metric(value):
    """Render a metric with four decimals, or ``"N/A"`` if it is unavailable."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        # Missing key (None) or a non-numeric placeholder.
        return "N/A"


def main():
    """Fine-tune a BERT sequence classifier and save the best checkpoint.

    Steps: parse CLI arguments, load the tokenizer and the CSV train and
    validation files, tokenize the ``text`` column, train with ``Trainer``
    (evaluating and checkpointing every epoch, keeping the best model by F1),
    save the model, tokenizer and a small ``filter_config.json`` into
    ``<output_dir>/best_model``, then print the final evaluation metrics.
    """
    args = _parse_args()

    # Load tokenizer
    print(f"加载 tokenizer: {args.model_name}")
    tokenizer = BertTokenizer.from_pretrained(args.model_name)

    # Load datasets
    print("加载数据集...")
    dataset = load_dataset(
        "csv",
        data_files={"train": args.train_file, "validation": args.val_file},
    )
    print(f"训练集: {len(dataset['train'])} 条")
    print(f"验证集: {len(dataset['validation'])} 条")

    # Tokenize
    def tokenize_fn(examples):
        # Every example is padded/truncated to exactly max_length tokens.
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=args.max_length,
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")

    # Load model
    print(f"加载模型: {args.model_name}")
    model = BertForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
    )

    # Training arguments
    best_model_dir = os.path.join(args.output_dir, "best_model")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Reload the checkpoint with the highest validation F1 after training.
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        fp16=args.fp16,
        bf16=args.bf16,
        save_total_limit=2,
        report_to="none",
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
    )

    # Train
    print(f"\n{'='*60}")
    print("开始训练")
    print(f"Epochs: {args.epochs}, Batch Size: {args.batch_size}, LR: {args.lr}")
    print(f"{'='*60}\n")
    trainer.train()

    # Save the best model together with its tokenizer
    print(f"\n保存最佳模型至: {best_model_dir}")
    trainer.save_model(best_model_dir)
    tokenizer.save_pretrained(best_model_dir)

    # Save the configuration written next to the model files
    config = {
        "model_name": args.model_name,
        "max_length": args.max_length,
        "num_labels": 2,
        "label_map": {"0": "正常", "1": "敏感"},
    }
    config_path = os.path.join(best_model_dir, "filter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

    # Final evaluation
    final_metrics = trainer.evaluate()
    print(f"\n{'='*60}")
    print("训练完成!")
    # A missing metric must print "N/A" rather than crash: applying the
    # ".4f" format spec to a string fallback raises ValueError.
    print(f"F1: {_format_metric(final_metrics.get('eval_f1'))}")
    print(f"Accuracy: {_format_metric(final_metrics.get('eval_accuracy'))}")
    print(f"模型已保存至: {best_model_dir}")
    print(f"{'='*60}")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()