Text Classification
Transformers
Safetensors
Chinese
bert
content-moderation
sensitive-word-detection
text-embeddings-inference
Instructions to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="crackrammer/ShieldBERT-Base-Chinese-Sensitive")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
model = AutoModelForSequenceClassification.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
```
- Notebooks
- Google Colab
- Kaggle
File size: 5,427 Bytes
Source listing at commit db60e24 (151 lines):
"""
推理脚本:使用训练好的模型预测文本是否包含敏感内容
"""
import os
import json
import argparse
import torch
from transformers import BertTokenizer, BertForSequenceClassification
class SensitiveWordPredictor:
    """Classify text as normal or sensitive with a fine-tuned BERT classifier.

    The model directory must contain the saved tokenizer/model files plus a
    ``filter_config.json`` providing ``label_map`` (string class id -> display
    name) and, optionally, ``max_length`` (defaults to 128).
    """

    def __init__(self, model_path: str, device: "str | None" = None):
        """Load the filter config, tokenizer and model from *model_path*.

        Args:
            model_path: Directory holding the saved model, tokenizer and
                ``filter_config.json``.
            device: Torch device string (e.g. ``"cpu"``, ``"cuda:0"``). When
                falsy, CUDA is used if available, otherwise CPU.

        Raises:
            FileNotFoundError: If ``filter_config.json`` does not exist.
            KeyError: If the config has no ``label_map`` entry.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the filter configuration.
        config_path = os.path.join(model_path, "filter_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        # JSON object keys are always strings, so labels are looked up as str(id).
        self.label_map = self.config["label_map"]
        self.max_length = self.config.get("max_length", 128)

        # Load the tokenizer and model.
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout

    def predict(self, text: str) -> dict:
        """Classify a single text.

        Returns:
            A dict with keys ``text``, ``label`` (int class id),
            ``label_name``, ``confidence`` (softmax probability of the
            predicted class, rounded to 4 decimals) and ``is_sensitive``
            (``True`` when the predicted class id is 1).
        """
        return self._classify([text])[0]

    def predict_batch(self, texts: list[str], batch_size: int = 32) -> list[dict]:
        """Classify many texts, running the model on chunks of *batch_size*.

        Args:
            texts: Texts to classify; output order matches input order.
            batch_size: Maximum number of texts per forward pass; must be >= 1.

        Returns:
            One result dict per input text, in the same format as ``predict``.

        Raises:
            ValueError: If *batch_size* is less than 1.
        """
        if batch_size < 1:
            # range() would raise an opaque error for 0 and silently yield no
            # chunks for negative values, dropping every input.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        results = []
        for start in range(0, len(texts), batch_size):
            results.extend(self._classify(texts[start : start + batch_size]))
        return results

    def _classify(self, batch_texts: list[str]) -> list[dict]:
        """Run one forward pass over *batch_texts* and build the result dicts."""
        encoding = self.tokenizer(
            batch_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=1)
        pred_labels = torch.argmax(probs, dim=1)
        # Probability of each row's predicted class; converting the whole batch
        # with tolist() avoids one device->host sync per row.
        confidences = probs.gather(1, pred_labels.unsqueeze(1)).squeeze(1).tolist()

        return [
            {
                "text": text,
                "label": label,
                "label_name": self.label_map[str(label)],
                "confidence": round(conf, 4),
                # Class id 1 is treated as the sensitive class.
                "is_sensitive": label == 1,
            }
            for text, label, conf in zip(batch_texts, pred_labels.tolist(), confidences)
        ]
def _status(is_sensitive: bool) -> str:
    """Return the status tag printed for a prediction."""
    return "🔴 敏感" if is_sensitive else "🟢 正常"


def _run_single(predictor, text: str) -> None:
    """Classify one text and print a detailed report."""
    result = predictor.predict(text)
    print(f"\n输入: {result['text']}")
    print(f"结果: {_status(result['is_sensitive'])}")
    print(f"标签: {result['label_name']} (label={result['label']})")
    print(f"置信度: {result['confidence']:.4f}")


def _run_file(predictor, path: str, preview_chars: int = 50) -> None:
    """Classify every non-blank line of the UTF-8 file *path* and print a summary."""
    with open(path, "r", encoding="utf-8") as f:
        texts = [line.strip() for line in f if line.strip()]
    results = predictor.predict_batch(texts)

    print(f"\n{'='*70}")
    print(f"批量检测结果 (共 {len(results)} 条)")
    print(f"{'='*70}")
    for r in results:
        text = r["text"]
        # Only mark the preview as truncated when the text was actually cut.
        preview = text[:preview_chars] + ("..." if len(text) > preview_chars else "")
        print(f" {_status(r['is_sensitive'])} [{r['confidence']:.4f}] {preview}")
    sensitive_count = sum(1 for r in results if r["is_sensitive"])
    print(f"\n统计: 正常 {len(results) - sensitive_count} 条, 敏感 {sensitive_count} 条")


def _run_interactive(predictor) -> None:
    """Read texts from stdin until the user types a quit word or input ends."""
    print("敏感词检测系统 - 交互模式")
    print("输入文本进行检测,输入 'quit' 退出\n")
    while True:
        try:
            text = input("请输入文本> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C or end of piped input: exit cleanly instead of
            # printing a traceback.
            print("\n再见!")
            break
        if text.lower() in ("quit", "exit", "q"):
            print("再见!")
            break
        if not text:
            continue
        result = predictor.predict(text)
        print(f" 结果: {_status(result['is_sensitive'])} | 置信度: {result['confidence']:.4f}\n")


def main():
    """Command-line entry point.

    Modes, chosen in this order:
      * ``--text``: classify a single string and print a detailed report;
      * ``--file``: classify each non-blank line of a UTF-8 file;
      * otherwise: interactive prompt loop on stdin.
    """
    parser = argparse.ArgumentParser(description="敏感词预测")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--text", type=str, help="要检测的文本(单条)")
    parser.add_argument("--file", type=str, help="要检测的文本文件(每行一条)")
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    predictor = SensitiveWordPredictor(args.model_path, args.device)

    # Compare with None so an explicit `--text ""` is classified instead of
    # silently falling through to interactive mode.
    if args.text is not None:
        _run_single(predictor, args.text)
    elif args.file:
        _run_file(predictor, args.file)
    else:
        _run_interactive(predictor)


if __name__ == "__main__":
    main()
|