# Uploaded by crackrammer via huggingface_hub ("Upload folder using huggingface_hub").
# Commit: db60e24 (verified)
"""
推理脚本:使用训练好的模型预测文本是否包含敏感内容
"""
import os
import json
import argparse
import torch
from transformers import BertTokenizer, BertForSequenceClassification
class SensitiveWordPredictor:
    """Classify texts as sensitive / normal with a fine-tuned BERT classifier.

    The model directory must contain files loadable by
    ``BertTokenizer.from_pretrained`` / ``BertForSequenceClassification.from_pretrained``
    plus a ``filter_config.json`` with a ``label_map`` (label id as a string ->
    display name) and optionally ``max_length`` (defaults to 128).
    """

    def __init__(self, model_path: str, device: str = None):
        """Load the config, tokenizer and model from *model_path*.

        Args:
            model_path: Directory produced by training.
            device: Torch device string (e.g. "cpu", "cuda:0"). When falsy,
                CUDA is used if available, otherwise CPU.

        Raises:
            FileNotFoundError: if ``filter_config.json`` is missing.
            KeyError: if the config has no ``label_map`` entry.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the filter configuration written alongside the model.
        config_path = os.path.join(model_path, "filter_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
        # JSON object keys are always strings, hence lookups via str(label) below.
        self.label_map = self.config["label_map"]
        self.max_length = self.config.get("max_length", 128)

        # Load the model and tokenizer; eval() disables dropout for inference.
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

    def predict(self, text: str) -> dict:
        """Classify a single text.

        Returns:
            A dict with keys ``text``, ``label`` (int), ``label_name``,
            ``confidence`` (softmax probability of the predicted label,
            rounded to 4 decimals) and ``is_sensitive`` (``label == 1``).
        """
        return self._predict_chunk([text])[0]

    def predict_batch(self, texts: list[str], batch_size: int = 32) -> list[dict]:
        """Classify many texts, running the model on *batch_size* texts at a time.

        Returns one result dict per input text (same format as :meth:`predict`),
        in input order. An empty *texts* yields an empty list.

        Raises:
            ValueError: if *batch_size* is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        results = []
        for start in range(0, len(texts), batch_size):
            results.extend(self._predict_chunk(texts[start : start + batch_size]))
        return results

    def _predict_chunk(self, texts: list[str]) -> list[dict]:
        """Run one forward pass over *texts* and build a result dict per text."""
        # padding=True pads only to the longest text in this chunk; padded
        # positions are masked out via attention_mask, so predictions match
        # padding to max_length while doing far less work for short texts.
        encoding = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.softmax(outputs.logits, dim=1)
        # max over the class axis gives both the winning probability and its index;
        # tolist() transfers the whole chunk at once instead of one .item() per row.
        confidences, labels = probs.max(dim=1)

        results = []
        for text, label, conf in zip(texts, labels.tolist(), confidences.tolist()):
            results.append({
                "text": text,
                "label": label,
                "label_name": self.label_map[str(label)],
                "confidence": round(conf, 4),
                # Label id 1 is the sensitive class.
                "is_sensitive": label == 1,
            })
        return results
def main():
    """Command-line entry point for sensitive-content detection.

    The mode is chosen by the arguments:
      * ``--text``: classify one string and print label, name and confidence;
      * ``--file``: classify every non-blank line of a UTF-8 file in batches
        of ``--batch_size`` and print a per-line summary plus totals;
      * neither: interactive loop on stdin until 'quit'/'exit'/'q',
        end of input (Ctrl-D) or Ctrl-C.
    """
    parser = argparse.ArgumentParser(description="敏感词预测")
    parser.add_argument("--model_path", type=str, default="output/best_model")
    parser.add_argument("--text", type=str, help="要检测的文本(单条)")
    parser.add_argument("--file", type=str, help="要检测的文本文件(每行一条)")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=32, help="批量检测时每批的文本数")
    args = parser.parse_args()

    predictor = SensitiveWordPredictor(args.model_path, args.device)

    # `is not None` so an explicit empty string (--text "") is classified
    # rather than silently falling through to file / interactive mode.
    if args.text is not None:
        result = predictor.predict(args.text)
        status = "🔴 敏感" if result["is_sensitive"] else "🟢 正常"
        print(f"\n输入: {result['text']}")
        print(f"结果: {status}")
        print(f"标签: {result['label_name']} (label={result['label']})")
        print(f"置信度: {result['confidence']:.4f}")
    elif args.file:
        with open(args.file, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]
        results = predictor.predict_batch(texts, batch_size=args.batch_size)

        print(f"\n{'='*70}")
        print(f"批量检测结果 (共 {len(results)} 条)")
        print(f"{'='*70}")
        preview_len = 50
        for r in results:
            status = "🔴 敏感" if r["is_sensitive"] else "🟢 正常"
            # Only mark the preview with an ellipsis when the text was actually cut.
            preview = r["text"][:preview_len]
            if len(r["text"]) > preview_len:
                preview += "..."
            print(f" {status} [{r['confidence']:.4f}] {preview}")
        sensitive_count = sum(1 for r in results if r["is_sensitive"])
        print(f"\n统计: 正常 {len(results) - sensitive_count} 条, 敏感 {sensitive_count} 条")
    else:
        print("敏感词检测系统 - 交互模式")
        print("输入文本进行检测,输入 'quit' 退出\n")
        while True:
            try:
                text = input("请输入文本> ").strip()
            except (EOFError, KeyboardInterrupt):
                # End of input (Ctrl-D / exhausted pipe) or Ctrl-C: leave the loop
                # like 'quit' instead of dying with a traceback.
                print()
                print("再见!")
                break
            if text.lower() in ("quit", "exit", "q"):
                print("再见!")
                break
            if not text:
                continue
            result = predictor.predict(text)
            status = "🔴 敏感" if result["is_sensitive"] else "🟢 正常"
            print(f" 结果: {status} | 置信度: {result['confidence']:.4f}\n")


if __name__ == "__main__":
    main()