Text Classification
Transformers
Safetensors
Chinese
bert
content-moderation
sensitive-word-detection
text-embeddings-inference
Instructions to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="crackrammer/ShieldBERT-Base-Chinese-Sensitive")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
model = AutoModelForSequenceClassification.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| 训练脚本:使用 HuggingFace Trainer 微调 BERT 进行敏感词二分类 | |
| 使用 BertForSequenceClassification + Trainer API, | |
| 支持自动混合精度、梯度累积、学习率调度等。 | |
| """ | |
| import os | |
| import json | |
| import argparse | |
| import numpy as np | |
| import torch | |
| from transformers import ( | |
| BertTokenizer, | |
| BertForSequenceClassification, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| from datasets import load_dataset | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score | |
def compute_metrics(eval_pred):
    """Compute evaluation metrics from an ``(logits, labels)`` pair.

    The predicted class is the argmax over the last axis of the logits.
    Returns a dict with the keys ``accuracy``, ``precision``, ``recall``
    and ``f1``; the last three are computed with ``average="binary"``.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)

    results = {"accuracy": accuracy_score(gold, predicted)}
    binary_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for metric_name, scorer in binary_scorers:
        results[metric_name] = scorer(gold, predicted, average="binary")
    return results
def _format_metric(metrics, key):
    """Return ``metrics[key]`` formatted to four decimals, or ``"N/A"``.

    The fallback covers both a missing key and a value that cannot be
    converted to ``float``; applying ``:.4f`` directly to a string default
    would raise ``ValueError``.
    """
    try:
        return f"{float(metrics.get(key)):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def main(argv=None):
    """Fine-tune a BERT sequence classifier for binary sensitive-text detection.

    Steps, in order: parse CLI options, load the tokenizer, load the train and
    validation CSV files, tokenize the ``text`` column, rename ``label`` to
    ``labels``, load a 2-label classification model, train with ``Trainer``
    (evaluating and checkpointing every epoch, selecting the best model by
    F1), save the model, tokenizer and a ``filter_config.json`` into
    ``<output_dir>/best_model``, then run a final evaluation and print F1 and
    accuracy.

    Args:
        argv: Optional list of command-line argument strings. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves as before.
    """
    parser = argparse.ArgumentParser(description="训练敏感词过滤模型")
    parser.add_argument("--model_name", type=str, default="bert-base-chinese")
    parser.add_argument("--train_file", type=str, default="data/train.csv")
    parser.add_argument("--val_file", type=str, default="data/val.csv")
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    args = parser.parse_args(argv)

    separator = "=" * 60

    # Load the tokenizer.
    print(f"加载 tokenizer: {args.model_name}")
    tokenizer = BertTokenizer.from_pretrained(args.model_name)

    # Load the train/validation splits from CSV files.
    print("加载数据集...")
    dataset = load_dataset(
        "csv",
        data_files={"train": args.train_file, "validation": args.val_file},
    )
    print(f"训练集: {len(dataset['train'])} 条")
    print(f"验证集: {len(dataset['validation'])} 条")

    # Tokenize the "text" column; every example is padded/truncated to
    # exactly max_length tokens.
    def tokenize_fn(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=args.max_length,
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")

    # Load the model with a 2-class classification head.
    print(f"加载模型: {args.model_name}")
    model = BertForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
    )

    # Training configuration: evaluate and checkpoint once per epoch and keep
    # the checkpoint with the highest F1 as the best model.
    best_model_dir = os.path.join(args.output_dir, "best_model")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=50,
        fp16=False,
        bf16=False,
        save_total_limit=2,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        compute_metrics=compute_metrics,
    )

    # Train.
    print(f"\n{separator}")
    print("开始训练")
    print(f"Epochs: {args.epochs}, Batch Size: {args.batch_size}, LR: {args.lr}")
    print(f"{separator}\n")
    trainer.train()

    # Save the model and tokenizer side by side so the directory can be
    # loaded on its own.
    print(f"\n保存最佳模型至: {best_model_dir}")
    trainer.save_model(best_model_dir)
    tokenizer.save_pretrained(best_model_dir)

    # Persist the inference-time settings next to the model weights.
    config = {
        "model_name": args.model_name,
        "max_length": args.max_length,
        "num_labels": 2,
        "label_map": {"0": "正常", "1": "敏感"},
    }
    config_path = os.path.join(best_model_dir, "filter_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

    # Final evaluation on the validation split. Metrics are formatted through
    # _format_metric so a missing key prints "N/A" instead of raising.
    final_metrics = trainer.evaluate()
    print(f"\n{separator}")
    print("训练完成!")
    print(f"F1: {_format_metric(final_metrics, 'eval_f1')}")
    print(f"Accuracy: {_format_metric(final_metrics, 'eval_accuracy')}")
    print(f"模型已保存至: {best_model_dir}")
    print(f"{separator}")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()