Text Classification
Transformers
Safetensors
Chinese
bert
content-moderation
sensitive-word-detection
text-embeddings-inference
Instructions to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="crackrammer/ShieldBERT-Base-Chinese-Sensitive")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
model = AutoModelForSequenceClassification.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
- Notebooks
- Google Colab
- Kaggle
File size: 3,864 Bytes
db60e24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | """
评估脚本:在测试集上评估模型性能
"""
import os
import json
import argparse
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
confusion_matrix,
classification_report,
)
def compute_metrics(eval_pred):
    """Score one evaluation pass; passed to ``Trainer`` as ``compute_metrics``.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall`` and ``f1``. The last
        three are computed with scikit-learn's ``average="binary"`` mode.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest score on the last axis.
    predicted = np.argmax(scores, axis=-1)

    metrics = {"accuracy": accuracy_score(references, predicted)}
    binary_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for key, scorer in binary_scorers:
        metrics[key] = scorer(references, predicted, average="binary")
    return metrics
def main():
    """Evaluate a fine-tuned BERT classifier on a CSV test set and save the metrics.

    Command-line options:
        --model_path: Directory holding the model, the tokenizer and
            ``filter_config.json`` (default ``output/best_model``).
        --test_file: CSV file loaded as the ``test`` split (default
            ``data/test.csv``). Its ``text`` and ``label`` columns are read.
        --batch_size: Per-device evaluation batch size (default 64).

    Prints accuracy, precision, recall, F1, the confusion matrix and a
    classification report, then writes the scalar metrics plus the matrix to
    ``eval_results.json`` in the parent directory of ``--model_path``.

    Raises:
        FileNotFoundError: If ``filter_config.json`` is missing from the
            model directory.
        KeyError: If the config has no ``label_map`` entry for ``"0"`` or
            ``"1"``.
    """
    parser = argparse.ArgumentParser(description="评估敏感词过滤模型")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--test_file", type=str, default="data/test.csv")
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    # Load the evaluation settings stored alongside the model files.
    config_path = os.path.join(args.model_path, "filter_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    max_length = config.get("max_length", 128)
    label_names = [config["label_map"]["0"], config["label_map"]["1"]]
    # Class ids matching label_names. Passing them explicitly to scikit-learn
    # keeps the matrix 2x2 and the report two rows long even when the test
    # set (or the predictions) contain only one of the classes.
    class_ids = [0, 1]

    # Load the model and tokenizer.
    print(f"加载模型: {args.model_path}")
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    model = BertForSequenceClassification.from_pretrained(args.model_path)

    # Load the test set.
    dataset = load_dataset("csv", data_files={"test": args.test_file})

    def tokenize_fn(examples):
        """Tokenize a batch of ``text`` values, padded/truncated to ``max_length``."""
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")
    print(f"测试集: {len(dataset['test'])} 条")

    # Build a Trainer that is used for prediction only.
    training_args = TrainingArguments(
        output_dir="/tmp/eval_output",
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, compute_metrics=compute_metrics)

    # Predict and reduce the logits to class ids.
    predictions = trainer.predict(dataset["test"])
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids

    # zero_division=0 returns the same value as the default "warn" setting
    # when a metric is undefined, but without emitting a warning.
    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="binary", zero_division=0)
    recall = recall_score(labels, preds, average="binary", zero_division=0)
    f1 = f1_score(labels, preds, average="binary", zero_division=0)
    # Without labels=..., a single-class test set yields a 1x1 matrix and the
    # cm[1][...] lookups below would raise IndexError.
    cm = confusion_matrix(labels, preds, labels=class_ids)

    print(f"\n{'='*60}")
    print("模型评估结果")
    print(f"{'='*60}")
    print(f"准确率 (Accuracy): {acc:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"F1 值 (F1-Score): {f1:.4f}")
    print("\n--- 混淆矩阵 ---")
    print(f"{'':>12} 预测正常 预测敏感")
    print(f"{'实际正常':>10} {cm[0][0]:>6} {cm[0][1]:>6}")
    print(f"{'实际敏感':>10} {cm[1][0]:>6} {cm[1][1]:>6}")
    print("\n--- 分类报告 ---")
    report = classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=label_names,
        digits=4,
        zero_division=0,
    )
    print(report)

    # Save the results.
    results = {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm.tolist(),
    }
    # Normalise first so a trailing separator ("output/best_model/") still
    # resolves to the parent directory rather than the model directory itself.
    model_dir = os.path.normpath(args.model_path)
    output_path = os.path.join(os.path.dirname(model_dir), "eval_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\n评估结果已保存至: {output_path}")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
|