Text Classification
Transformers
Safetensors
Chinese
bert
content-moderation
sensitive-word-detection
text-embeddings-inference
Instructions to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use crackrammer/ShieldBERT-Base-Chinese-Sensitive with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="crackrammer/ShieldBERT-Base-Chinese-Sensitive")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
model = AutoModelForSequenceClassification.from_pretrained("crackrammer/ShieldBERT-Base-Chinese-Sensitive")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| 推理脚本:使用训练好的模型预测文本是否包含敏感内容 | |
| """ | |
| import os | |
| import json | |
| import argparse | |
| import torch | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
class SensitiveWordPredictor:
    """Classify text as sensitive or normal with a fine-tuned BERT model.

    The model directory must contain files loadable through
    ``from_pretrained`` plus a ``filter_config.json`` that provides
    ``label_map`` (label index as a string -> label name) and, optionally,
    ``max_length`` (defaults to 128).
    """

    def __init__(self, model_path: str, device: str = None):
        """Load the config, tokenizer and model found in *model_path*.

        Args:
            model_path: Directory holding the saved model, tokenizer and
                ``filter_config.json``.
            device: Torch device string such as ``"cpu"`` or ``"cuda:0"``.
                When falsy, CUDA is used if available, otherwise the CPU.

        Raises:
            FileNotFoundError: ``filter_config.json`` does not exist.
            KeyError: The config has no ``label_map`` entry.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the inference configuration saved next to the model.
        config_path = os.path.join(model_path, "filter_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        self.label_map = self.config["label_map"]
        self.max_length = self.config.get("max_length", 128)

        # Load the model and tokenizer, then switch to inference mode.
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Predict a single text.

        Delegates to :meth:`predict_batch` so both entry points share one
        tokenise / forward / format pipeline.

        Args:
            text: The text to classify.

        Returns:
            A dict with the keys ``text``, ``label``, ``label_name``,
            ``confidence`` (rounded to 4 decimals) and ``is_sensitive``.
        """
        return self.predict_batch([text])[0]

    def predict_batch(self, texts: list[str], batch_size: int = 32) -> list[dict]:
        """Predict a list of texts, running the model in mini-batches.

        Args:
            texts: Texts to classify; an empty list yields an empty result.
            batch_size: Number of texts sent through the model at once.

        Returns:
            One result dict per input text, in input order. Each dict has the
            keys ``text``, ``label``, ``label_name``, ``confidence`` (rounded
            to 4 decimals) and ``is_sensitive`` (true when the label is 1).

        Raises:
            ValueError: *batch_size* is smaller than 1.
        """
        # A negative step would make range() empty and silently drop every
        # input; zero would surface as an unrelated range() error.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        results = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start : start + batch_size]
            encoding = self.tokenizer(
                batch_texts,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            input_ids = encoding["input_ids"].to(self.device)
            attention_mask = encoding["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                probs = torch.softmax(outputs.logits, dim=1)
                pred_labels = torch.argmax(probs, dim=1)
                # Probability of the predicted class for every row.
                confidences = probs.gather(1, pred_labels.unsqueeze(1)).squeeze(1)

            # Copy each tensor to the host once per batch instead of calling
            # .item() twice per sample.
            labels = pred_labels.tolist()
            confs = confidences.tolist()
            for text, label, conf in zip(batch_texts, labels, confs):
                results.append({
                    "text": text,
                    "label": label,
                    "label_name": self.label_map[str(label)],
                    "confidence": round(conf, 4),
                    "is_sensitive": label == 1,
                })
        return results
def _status_label(is_sensitive: bool) -> str:
    """Return the status text printed for a sensitive / normal prediction."""
    return "🔴 敏感" if is_sensitive else "🟢 正常"


def main():
    """Command-line entry point for the sensitive-content detector.

    Three modes, chosen by the arguments given:

    * ``--text``: classify one text and print a detailed result.
    * ``--file``: classify every non-blank line of a UTF-8 text file and
      print one line per text plus a summary.
    * neither: interactive loop reading texts from standard input until
      ``quit`` / ``exit`` / ``q`` is entered or the input stream ends.

    ``--text`` takes precedence when both ``--text`` and ``--file`` are given.
    """
    parser = argparse.ArgumentParser(description="敏感词预测")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--text", type=str, help="要检测的文本(单条)")
    parser.add_argument("--file", type=str, help="要检测的文本文件(每行一条)")
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    predictor = SensitiveWordPredictor(args.model_path, args.device)

    if args.text:
        result = predictor.predict(args.text)
        status = _status_label(result["is_sensitive"])
        print(f"\n输入: {result['text']}")
        print(f"结果: {status}")
        print(f"标签: {result['label_name']} (label={result['label']})")
        print(f"置信度: {result['confidence']:.4f}")
    elif args.file:
        with open(args.file, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]
        results = predictor.predict_batch(texts)

        print(f"\n{'='*70}")
        print(f"批量检测结果 (共 {len(results)} 条)")
        print(f"{'='*70}")
        for r in results:
            status = _status_label(r["is_sensitive"])
            # Only mark the preview with an ellipsis when text was cut off.
            preview = r["text"] if len(r["text"]) <= 50 else r["text"][:50] + "..."
            print(f" {status} [{r['confidence']:.4f}] {preview}")
        sensitive_count = sum(1 for r in results if r["is_sensitive"])
        print(f"\n统计: 正常 {len(results) - sensitive_count} 条, 敏感 {sensitive_count} 条")
    else:
        print("敏感词检测系统 - 交互模式")
        print("输入文本进行检测,输入 'quit' 退出\n")
        while True:
            try:
                text = input("请输入文本> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C or exhausted piped input: leave the loop
                # cleanly instead of ending with a traceback.
                print("\n再见!")
                break
            if text.lower() in ("quit", "exit", "q"):
                print("再见!")
                break
            if not text:
                continue
            result = predictor.predict(text)
            status = _status_label(result["is_sensitive"])
            print(f" 结果: {status} | 置信度: {result['confidence']:.4f}\n")
| # Run the command-line interface only when this file is executed as a script, not when it is imported. | |
| if __name__ == "__main__": | |
| main() | |