FINBERT2_finetune / model_inference_task2.py
Riverise's picture
Upload folder using huggingface_hub
fc9ae4e verified
import os
import torch
import numpy as np
from torch import nn
from transformers import AutoModelForSequenceClassification, BertTokenizerFast, AutoConfig, pipeline, BertPreTrainedModel, BertModel
# 定义标签名称,与任务一致
BINARY_LABELS = ['Non-Envir', 'Envir']
NUM_LABELS = 2
# ----------------------------------------------------
# A. 定义支持多标签分类的 BERT 模型(必须与训练时一致)
# ----------------------------------------------------
class BertForMultiLabelClassification(BertPreTrainedModel):
"""
基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss
"""
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# 加载 BERT 主体
self.bert = BertModel(config)
# 加载训练时的 dropout 比例
classifier_dropout = config.hidden_dropout_prob
self.dropout = nn.Dropout(classifier_dropout)
# 加载训练时的分类器层
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
self.post_init()
# 注意:推理时不需要损失函数,但保持结构完整性
self.loss_fct = nn.BCEWithLogitsLoss()
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# 取 [CLS] token 的隐藏状态 (即 pooler output)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
# 经过分类器层,输出 logits (未经 Sigmoid 的分数)
logits = self.classifier(pooled_output)
# 推理时 labels 为 None,直接返回 logits
return logits
# ----------------------------------------------------
# B. 模型推理函数
# ----------------------------------------------------
def predict_binary_classification(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
"""
加载 BERT 二分类模型检查点,对单个文本进行二分类预测。
Args:
checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。
tokenizer_path: 分词器路径或名称。
text_to_predict: 待预测的输入文本。
Returns:
包含预测标签和概率的字典。
"""
print(f"--- 1. 正在加载二分类模型和分词器: {checkpoint_path} ---")
try:
# 1. 加载配置和分词器
config = AutoConfig.from_pretrained(checkpoint_path, num_labels=NUM_LABELS)
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
# 2. 使用标准的 AutoModelForSequenceClassification 加载模型
# 这将自动处理模型加载和分类头维度不匹配的问题
model = AutoModelForSequenceClassification.from_pretrained(
checkpoint_path,
config=config,
ignore_mismatched_sizes=True # 容忍加载时的分类头尺寸不匹配
)
except Exception as e:
print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件: {e}")
return None
model.eval() # 切换到评估模式
# 3. 文本编码
inputs = tokenizer(
text_to_predict,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
# 4. 执行推理
with torch.no_grad():
# 模型返回的是 Logits (维度通常是 [1, 2])
outputs = model(**inputs)
logits = outputs.logits # 获取 Logits
# 应用 Softmax 转换为概率分布
probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
# 确定预测的类别索引 (0 或 1)
predicted_index = np.argmax(probabilities)
# 5. 格式化输出
# 预测的类别名称
predicted_label = BINARY_LABELS[predicted_index]
# 预测类别的概率
predicted_prob = probabilities[predicted_index]
# 打印结果
print("--- 5. 预测结果 ---")
print(f"输入文本: {text_to_predict}")
print(f"预测类别: {predicted_label}")
print(f"对应概率: {predicted_prob:.4f}")
# 返回所有类别的概率
result = {
'prediction': predicted_label,
'probability': float(f"{predicted_prob:.4f}"),
'all_probabilities': {
BINARY_LABELS[i]: float(f"{probabilities[i]:.4f}") for i in range(NUM_LABELS)
}
}
return result
# ----------------------------------------------------
# C. 示例运行
# ----------------------------------------------------
if __name__ == "__main__":
# 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配
MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_bilabel_finetuned_model_from_dapt/final"
TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base'
# TOKENIZER = 'bert-base-chinese'
SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"
# 确保检查点目录存在
if not os.path.exists(MODEL_CHECKPOINT):
print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}")
else:
predict_binary_classification(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT)