Text Classification
Transformers
Safetensors
Chinese
bert
content-moderation
sensitive-word-detection
text-embeddings-inference
Instructions to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="crackrammer/ShieldBERT-Base-Chinese-Sensitive")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
model = AutoModelForSequenceClassification.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| 评估脚本:在测试集上评估模型性能 | |
| """ | |
| import os | |
| import json | |
| import argparse | |
| import numpy as np | |
| from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments | |
| from datasets import load_dataset | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| precision_score, | |
| recall_score, | |
| f1_score, | |
| confusion_matrix, | |
| classification_report, | |
| ) | |
def compute_metrics(eval_pred):
    """Score predictions for a two-class task.

    Args:
        eval_pred: An object that unpacks into ``(logits, labels)``. The
            predicted class is the argmax of ``logits`` over its last axis.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall`` and
        ``f1``. Precision, recall and F1 use scikit-learn's
        ``average="binary"`` mode.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)

    metrics = {"accuracy": accuracy_score(gold, predicted)}
    # The remaining scorers share one call signature, so apply them in turn.
    binary_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for metric_name, scorer in binary_scorers:
        metrics[metric_name] = scorer(gold, predicted, average="binary")
    return metrics
def main():
    """Evaluate a saved sequence-classification model on a CSV test set.

    Command-line options:
        --model_path: Directory holding the model, the tokenizer and
            ``filter_config.json`` (default ``output/best_model``).
        --test_file: CSV file to evaluate; its ``text`` and ``label``
            columns are read (default ``data/test.csv``).
        --batch_size: Per-device evaluation batch size (default 64).

    Prints accuracy, precision, recall, F1, the confusion matrix and a
    classification report, then writes the scalar metrics plus the confusion
    matrix to ``eval_results.json`` next to the model directory.

    Raises:
        FileNotFoundError: If ``filter_config.json`` is not in ``--model_path``.
        KeyError: If the config's ``label_map`` lacks the key "0" or "1".
    """
    parser = argparse.ArgumentParser(description="评估敏感词过滤模型")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--test_file", type=str, default="data/test.csv")
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    # Load the filter configuration stored with the model.
    config_path = os.path.join(args.model_path, "filter_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    max_length = config.get("max_length", 128)
    label_names = [config["label_map"]["0"], config["label_map"]["1"]]
    # Class ids in the same order as label_names. Passing them explicitly to
    # scikit-learn keeps the confusion matrix 2x2 and the report two-row even
    # when the test set or the predictions contain only one class.
    class_ids = [0, 1]

    # Load the model and tokenizer.
    print(f"加载模型: {args.model_path}")
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    model = BertForSequenceClassification.from_pretrained(args.model_path)

    # Load the test set.
    dataset = load_dataset("csv", data_files={"test": args.test_file})

    def tokenize_fn(examples):
        # Pad/truncate every example to the configured max_length.
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    dataset = dataset.rename_column("label", "labels")
    dataset.set_format("torch")
    print(f"测试集: {len(dataset['test'])} 条")

    # Build a Trainer that is used for prediction only.
    training_args = TrainingArguments(
        output_dir="/tmp/eval_output",
        per_device_eval_batch_size=args.batch_size,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args, compute_metrics=compute_metrics)

    # Predict and derive hard class decisions from the model outputs.
    predictions = trainer.predict(dataset["test"])
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids

    acc = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average="binary")
    recall = recall_score(labels, preds, average="binary")
    f1 = f1_score(labels, preds, average="binary")
    cm = confusion_matrix(labels, preds, labels=class_ids)

    print(f"\n{'='*60}")
    print("模型评估结果")
    print(f"{'='*60}")
    print(f"准确率 (Accuracy): {acc:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"F1 值 (F1-Score): {f1:.4f}")
    print("\n--- 混淆矩阵 ---")
    print(f"{'':>12} 预测正常 预测敏感")
    print(f"{'实际正常':>10} {cm[0][0]:>6} {cm[0][1]:>6}")
    print(f"{'实际敏感':>10} {cm[1][0]:>6} {cm[1][1]:>6}")
    print("\n--- 分类报告 ---")
    report = classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=label_names,
        digits=4,
    )
    print(report)

    # Save the results.
    results = {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm.tolist(),
    }
    # Normalise first so a trailing separator ("output/best_model/") resolves
    # to the same parent directory as "output/best_model".
    model_dir = os.path.normpath(args.model_path)
    output_path = os.path.join(os.path.dirname(model_dir), "eval_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"\n评估结果已保存至: {output_path}")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()