File size: 7,144 Bytes
d861b4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import os
import json
# 导入 spaCy 和相关的库用于句子分割
import spacy
# 导入 Hugging Face Transformers 和 PyTorch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 导入 unsloth 用于加载模型
from unsloth import FastLanguageModel
from tqdm import tqdm
# pip install spacy
# python -m spacy download en_core_web_sm
# 加载英文语言模型 (用于句子分割)
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
print("未找到 spaCy 英文语言模型。请运行 'python -m spacy download en_core_web_sm' 进行下载。")
nlp = None
def split_article_into_sentences(article_text):
"""
将一篇英文文章分割为句子。
"""
if not nlp:
print("无法进行句子分割,spaCy 语言模型未加载。")
return []
if not isinstance(article_text, str) or not article_text:
return []
doc = nlp(article_text)
sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
return sentences
class ModelManager:
"""管理模型加载和单句推理"""
def __init__(self):
self.model = None
self.tokenizer = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.is_loaded = False
def load_model(self, model_path="finetuned_model"):
"""加载微调后的模型和 tokenizer"""
if self.is_loaded:
print("模型已加载,无需重复加载。")
return
print(f"\n正在加载模型: {model_path}...")
max_seq_length = 2048 # 与训练时保持一致
dtype = None # 让 unsloth 自动选择
load_in_4bit = True # 与训练时保持一致
try:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = model_path,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
# 将模型设置为推理模式
model.eval()
# unsloth 会自动处理设备放置,无需 model.to(device)
self.model = model
self.tokenizer = tokenizer
self.is_loaded = True
print("模型加载成功。")
except Exception as e:
print(f"加载模型失败: {e}")
self.is_loaded = False
# 可以在此处选择退出或抛出异常,以便 Api.py 捕获
raise e
def predict_sentence(self, sentence: str, instruction: str) -> str:
"""使用已加载的模型对单个句子进行推理并返回预测标签 (1 或 0)"""
if not self.is_loaded or self.model is None or self.tokenizer is None:
raise RuntimeError("模型尚未加载,请先调用 load_model")
EOS_TOKEN = self.tokenizer.eos_token
# 构建 prompt
prompt = f"Instruction: {instruction}\nInput: {sentence}\nOutput:"
# 编码输入
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
# 将输入张量移动到模型所在的设备 (unsloth 会自动设置模型设备)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad(): # 推理时不需要计算梯度
# 使用 model.generate 进行推理
outputs = self.model.generate(
**inputs,
max_new_tokens=10, # 预期输出是 '1' 或 '0',所以设置一个较小的 max_new_tokens
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# 解码输出
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取预测标签 (1 或 0)
pred_label = ""
try:
output_start = response.rindex("Output:")
output_text = response[output_start + 7:].strip()
# 只取第一个词作为预测标签
pred_label = output_text.split()[0].strip()
# 确保预测标签是 '1' 或 '0'
if pred_label not in ['1', '0']:
pred_label = '' # 非预期的输出
except (ValueError, IndexError):
pred_label = "" # 提取失败
# 返回提取到的预测标签 (字符串类型)
return pred_label
# 创建一个 ModelManager 实例
model_manager = ModelManager()
# 定义文章处理函数,它将使用 ModelManager 实例进行单句推理
def process_article_for_inference(article_file_path: str):
"""对文章进行句子分割,并使用 ModelManager 进行推理"""
# 确保 spaCy 模型已加载
if not nlp:
print("跳过文章推理,因为 spaCy 语言模型未加载。")
return [] # 返回空列表表示处理失败
# 确保模型已加载
if not model_manager.is_loaded:
# 如果模型未加载,尝试加载。如果加载失败,load_model 会抛出异常。
try:
model_manager.load_model()
except Exception as e:
print(f"模型加载失败,无法进行推理: {e}")
return [] # 返回空列表表示处理失败
# 读取文章文件
print(f"正在读取文章文件: {article_file_path}...")
try:
with open(article_file_path, "r", encoding='utf-8') as f:
article_text = f.read()
except FileNotFoundError:
print(f"错误:未找到文件 {article_file_path}")
return []
except Exception as e:
print(f"读取文件 {article_file_path} 时发生错误: {e}")
return []
# 分割文章为句子
print("正在分割文章为句子...")
sentences = split_article_into_sentences(article_text)
print(f"分割出 {len(sentences)} 个句子。")
if not sentences:
print("没有句子可供推理。")
return []
# 定义推理指令 (与训练时保持一致)
instruction = "判断以下句子是否包含性别歧视。请回复\"1\"表示包含,回复\"0\"表示不包含。只能回复1或0,不要输出其他内容。"
print("\n开始对每个句子进行推理预测...")
results = []
# 使用 ModelManager 实例进行单句推理
for sentence in tqdm(sentences, desc="推理中"):
predicted_label = model_manager.predict_sentence(sentence, instruction)
# 为了返回给前端更友好的格式,可以将 1/0 转换为中文描述
original_label = "包含歧视" if predicted_label == '1' else ("不包含歧视" if predicted_label == '0' else "未知")
results.append({"sentence": sentence, "predicted_label_numeric": predicted_label, "predicted_label_text": original_label})
print("推理完成。")
return results
|