# gender_discrimination_detection_backend / ModelInferenceServe.py
# (Hugging Face page header removed from the scraped copy; module source follows.)
import os
import json
# 导入 spaCy 和相关的库用于句子分割
import spacy
# 导入 Hugging Face Transformers 和 PyTorch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 导入 unsloth 用于加载模型
from unsloth import FastLanguageModel
from tqdm import tqdm
# pip install spacy
# python -m spacy download en_core_web_sm
# Load the English spaCy pipeline used for sentence segmentation.
# Falls back to None when the model package is missing so callers can
# detect the failure instead of crashing at import time.
nlp = None
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("未找到 spaCy 英文语言模型。请运行 'python -m spacy download en_core_web_sm' 进行下载。")
def split_article_into_sentences(article_text):
    """Split an English article into a list of stripped, non-empty sentences.

    Returns an empty list when the spaCy pipeline failed to load or when the
    input is not a non-empty string.
    """
    if not nlp:
        print("无法进行句子分割,spaCy 语言模型未加载。")
        return []
    if not (isinstance(article_text, str) and article_text):
        return []
    parsed = nlp(article_text)
    return [span.text.strip() for span in parsed.sents if span.text.strip()]
class ModelManager:
    """Loads the fine-tuned classifier once and runs single-sentence inference.

    The model is loaded lazily via :meth:`load_model`; :meth:`predict_sentence`
    refuses to run until a load has succeeded.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Kept for reference; unsloth performs device placement on its own.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.is_loaded = False

    def load_model(self, model_path="finetuned_model"):
        """Load the fine-tuned model and tokenizer with unsloth.

        Calling again after a successful load is a no-op. On failure the
        exception is logged and re-raised so the API layer can handle it.

        Args:
            model_path: directory (or hub name) of the fine-tuned model.
        """
        if self.is_loaded:
            print("模型已加载,无需重复加载。")
            return
        print(f"\n正在加载模型: {model_path}...")
        max_seq_length = 2048  # must match the training configuration
        dtype = None           # let unsloth choose automatically
        load_in_4bit = True    # must match the training configuration
        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=max_seq_length,
                dtype=dtype,
                load_in_4bit=load_in_4bit,
            )
            # Switch to inference mode; unsloth already placed the model on
            # the right device, so no model.to(device) is needed.
            model.eval()
            self.model = model
            self.tokenizer = tokenizer
            self.is_loaded = True
            print("模型加载成功。")
        except Exception as e:
            print(f"加载模型失败: {e}")
            self.is_loaded = False
            # Bare raise preserves the original traceback for the caller
            # (previously `raise e`, which re-anchors the traceback here).
            raise

    def predict_sentence(self, sentence: str, instruction: str) -> str:
        """Classify one sentence; return '1', '0', or '' when unparseable.

        Args:
            sentence: the sentence to classify.
            instruction: the task instruction (must match training format).

        Returns:
            '1' or '0' extracted from the generation, or '' on failure.

        Raises:
            RuntimeError: if the model has not been loaded yet.
        """
        if not self.is_loaded or self.model is None or self.tokenizer is None:
            raise RuntimeError("模型尚未加载,请先调用 load_model")
        # Prompt layout must mirror the fine-tuning format exactly.
        prompt = f"Instruction: {instruction}\nInput: {sentence}\nOutput:"
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        # Move input tensors to the model's device (chosen by unsloth).
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():  # no gradients during inference
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=10,  # expected output is only '1' or '0'
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract the label from the text generated after the last "Output:".
        pred_label = ""
        marker = "Output:"
        try:
            output_start = response.rindex(marker)
            # len(marker) replaces the previous hard-coded offset of 7.
            output_text = response[output_start + len(marker):].strip()
            pred_label = output_text.split()[0].strip()
            if pred_label not in ('1', '0'):
                pred_label = ''  # unexpected generation
        except (ValueError, IndexError):
            pred_label = ""  # marker missing or nothing generated
        return pred_label
# Module-level ModelManager singleton shared by the article-inference helper below.
model_manager = ModelManager()
# Article-level pipeline: sentence-splits a file and classifies each sentence
# through the module-level ModelManager singleton.
def process_article_for_inference(article_file_path: str):
    """Read an article file, split it into sentences, and classify each one.

    Returns a list of dicts carrying the sentence, the raw numeric label and
    a human-readable label; an empty list signals that processing failed.
    """
    # Both the spaCy pipeline and the classifier must be available.
    if not nlp:
        print("跳过文章推理,因为 spaCy 语言模型未加载。")
        return []
    if not model_manager.is_loaded:
        # Lazy-load on first use; load_model raises on failure.
        try:
            model_manager.load_model()
        except Exception as e:
            print(f"模型加载失败,无法进行推理: {e}")
            return []
    print(f"正在读取文章文件: {article_file_path}...")
    try:
        with open(article_file_path, "r", encoding='utf-8') as fh:
            article_text = fh.read()
    except FileNotFoundError:
        print(f"错误:未找到文件 {article_file_path}")
        return []
    except Exception as e:
        print(f"读取文件 {article_file_path} 时发生错误: {e}")
        return []
    print("正在分割文章为句子...")
    sentences = split_article_into_sentences(article_text)
    print(f"分割出 {len(sentences)} 个句子。")
    if not sentences:
        print("没有句子可供推理。")
        return []
    # Instruction text must match the one used during fine-tuning.
    instruction = "判断以下句子是否包含性别歧视。请回复\"1\"表示包含,回复\"0\"表示不包含。只能回复1或0,不要输出其他内容。"
    print("\n开始对每个句子进行推理预测...")
    results = []
    for sentence in tqdm(sentences, desc="推理中"):
        numeric = model_manager.predict_sentence(sentence, instruction)
        # Map the raw label onto a reader-friendly description for the frontend.
        if numeric == '1':
            readable = "包含歧视"
        elif numeric == '0':
            readable = "不包含歧视"
        else:
            readable = "未知"
        results.append({
            "sentence": sentence,
            "predicted_label_numeric": numeric,
            "predicted_label_text": readable,
        })
    print("推理完成。")
    return results