|
|
import os
|
|
|
import json
|
|
|
|
|
|
import spacy
|
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
import torch
|
|
|
|
|
|
from unsloth import FastLanguageModel
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the small English spaCy pipeline used for sentence segmentation.
# On failure, nlp is set to None and every caller must check it before use.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # spaCy raises OSError when the model package is not installed locally.
    print("未找到 spaCy 英文语言模型。请运行 'python -m spacy download en_core_web_sm' 进行下载。")
    nlp = None
|
|
|
|
|
|
def split_article_into_sentences(article_text):
    """Split an English article into a list of stripped, non-empty sentences.

    Returns an empty list when the spaCy pipeline failed to load or when
    the input is missing or not a non-empty string.
    """
    # Segmentation is impossible without the spaCy pipeline.
    if not nlp:
        print("无法进行句子分割,spaCy 语言模型未加载。")
        return []

    # Reject non-string and empty inputs up front.
    if not isinstance(article_text, str) or not article_text:
        return []

    parsed = nlp(article_text)
    collected = []
    for span in parsed.sents:
        cleaned = span.text.strip()
        if cleaned:
            collected.append(cleaned)
    return collected
|
|
|
|
|
|
class ModelManager:
    """Manage lazy loading of the fine-tuned model and single-sentence inference.

    Attributes:
        model: the loaded causal LM, or None before load_model() succeeds.
        tokenizer: the matching tokenizer, or None before loading.
        device: preferred torch device (cuda if available, else cpu).
        is_loaded: True only after a successful load_model() call.
    """

    def __init__(self):
        # Model and tokenizer are populated lazily by load_model().
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.is_loaded = False

    def load_model(self, model_path="finetuned_model"):
        """Load the fine-tuned model and tokenizer via unsloth's FastLanguageModel.

        Idempotent: returns immediately if a model is already loaded.
        Re-raises the underlying exception on failure (after resetting
        is_loaded), so callers can decide how to react.
        """
        if self.is_loaded:
            print("模型已加载,无需重复加载。")
            return

        print(f"\n正在加载模型: {model_path}...")
        max_seq_length = 2048
        dtype = None          # let unsloth auto-detect an appropriate dtype
        load_in_4bit = True   # 4-bit quantization to reduce VRAM usage

        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=max_seq_length,
                dtype=dtype,
                load_in_4bit=load_in_4bit,
            )
            # Inference only: disable dropout etc.
            model.eval()

            self.model = model
            self.tokenizer = tokenizer
            self.is_loaded = True
            print("模型加载成功。")
        except Exception as e:
            print(f"加载模型失败: {e}")
            self.is_loaded = False
            # Bare raise preserves the original traceback (was `raise e`).
            raise

    def predict_sentence(self, sentence: str, instruction: str) -> str:
        """Run inference on a single sentence and return the predicted label.

        Returns '1' or '0' parsed from the model output, or '' if no valid
        label could be extracted.

        Raises:
            RuntimeError: if load_model() has not been called successfully.
        """
        if not self.is_loaded or self.model is None or self.tokenizer is None:
            raise RuntimeError("模型尚未加载,请先调用 load_model")

        prompt = f"Instruction: {instruction}\nInput: {sentence}\nOutput:"

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        # Move tensors onto the same device the model lives on.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=10,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the first whitespace-separated token after the LAST
        # "Output:" marker; accept only the literal labels '1' and '0'.
        marker = "Output:"
        pred_label = ""
        try:
            output_start = response.rindex(marker)
            # len(marker) replaces the former magic offset 7.
            output_text = response[output_start + len(marker):].strip()

            pred_label = output_text.split()[0].strip()

            if pred_label not in ['1', '0']:
                pred_label = ''
        except (ValueError, IndexError):
            # ValueError: marker absent; IndexError: nothing after marker.
            pred_label = ""

        return pred_label
|
|
|
|
|
|
|
|
|
model_manager = ModelManager()
|
|
|
|
|
|
|
|
|
def process_article_for_inference(article_file_path: str):
    """Read an article file, split it into sentences, and classify each one.

    Returns a list of dicts with keys "sentence", "predicted_label_numeric"
    and "predicted_label_text"; returns [] on any setup or I/O failure.
    """
    # Sentence segmentation requires the spaCy pipeline.
    if not nlp:
        print("跳过文章推理,因为 spaCy 语言模型未加载。")
        return []

    # Ensure the model is available, loading it on first use.
    if not model_manager.is_loaded:
        try:
            model_manager.load_model()
        except Exception as e:
            print(f"模型加载失败,无法进行推理: {e}")
            return []

    print(f"正在读取文章文件: {article_file_path}...")
    try:
        with open(article_file_path, "r", encoding='utf-8') as f:
            article_text = f.read()
    except FileNotFoundError:
        print(f"错误:未找到文件 {article_file_path}")
        return []
    except Exception as e:
        print(f"读取文件 {article_file_path} 时发生错误: {e}")
        return []

    print("正在分割文章为句子...")
    sentences = split_article_into_sentences(article_text)
    print(f"分割出 {len(sentences)} 个句子。")

    if not sentences:
        print("没有句子可供推理。")
        return []

    instruction = "判断以下句子是否包含性别歧视。请回复\"1\"表示包含,回复\"0\"表示不包含。只能回复1或0,不要输出其他内容。"

    print("\n开始对每个句子进行推理预测...")
    # Map the numeric label to its human-readable text; anything else
    # (including an empty parse) falls through to "未知".
    label_text = {'1': "包含歧视", '0': "不包含歧视"}
    results = []

    for sentence in tqdm(sentences, desc="推理中"):
        predicted_label = model_manager.predict_sentence(sentence, instruction)
        results.append({
            "sentence": sentence,
            "predicted_label_numeric": predicted_label,
            "predicted_label_text": label_text.get(predicted_label, "未知"),
        })

    print("推理完成。")
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|