FINBERT2_finetune / model_inference_task1.py
Riverise's picture
Upload folder using huggingface_hub
fc9ae4e verified
import os
import torch
import numpy as np
from torch import nn
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig
# 定义标签名称,与任务一致
TAG_COLS = ['Data', 'Action', 'Gain', 'Regu', 'Vague']
PREDICTION_THRESHOLD = 0.5 # 预测阈值
# ----------------------------------------------------
# A. 定义支持多标签分类的 BERT 模型(必须与训练时一致)
# ----------------------------------------------------
class BertForMultiLabelClassification(BertPreTrainedModel):
"""
基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss
"""
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# 加载 BERT 主体
self.bert = BertModel(config)
# 加载训练时的 dropout 比例
classifier_dropout = config.hidden_dropout_prob
self.dropout = nn.Dropout(classifier_dropout)
# 加载训练时的分类器层
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
self.post_init()
# 注意:推理时不需要损失函数,但保持结构完整性
self.loss_fct = nn.BCEWithLogitsLoss()
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# 取 [CLS] token 的隐藏状态 (即 pooler output)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
# 经过分类器层,输出 logits (未经 Sigmoid 的分数)
logits = self.classifier(pooled_output)
# 推理时 labels 为 None,直接返回 logits
return logits
# ----------------------------------------------------
# B. 模型推理函数
# ----------------------------------------------------
def predict_multilabel(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
"""
加载模型检查点,对单个文本进行多标签预测。
Args:
checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。
tokenizer_path: 分词器路径或名称。
text_to_predict: 待预测的输入文本。
Returns:
包含预测标签和概率的字典。
"""
print(f"--- 1. 正在加载模型和分词器: {checkpoint_path} ---")
try:
config = AutoConfig.from_pretrained(checkpoint_path)
# 确保配置中的 num_labels 与实际标签数量匹配
if config.num_labels != len(TAG_COLS):
# 运行时动态修正 num_labels,以防 checkpoint-config.json 里的 num_labels 不匹配
config.num_labels = len(TAG_COLS)
print(f"警告: 检查点配置的 num_labels 已从 {config.num_labels} 修正为 {len(TAG_COLS)}")
# 从检查点加载分词器(假设分词器文件已存在或被复制)
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
# 使用自定义模型类加载模型权重
model = BertForMultiLabelClassification.from_pretrained(
checkpoint_path,
config=config # 传入更新后的 config
)
except Exception as e:
print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件(如 model.safetensors, config.json, vocab.txt): {e}")
return None
model.eval() # 切换到评估模式 (关闭 Dropout等)
# 2. 文本编码
inputs = tokenizer(
text_to_predict,
padding="max_length",
truncation=True,
max_length=512,
return_tensors="pt"
)
# 3. 执行推理
with torch.no_grad():
# 模型返回的是 logits
outputs = model(**inputs)
logits = outputs.cpu().numpy() # 移动到 CPU 并转为 numpy
# 4. 后处理:Sigmoid 和 阈值
# 应用 Sigmoid 转换为概率
probs = 1 / (1 + np.exp(-logits))
# 应用阈值得到二元预测
preds = (probs > PREDICTION_THRESHOLD).astype(int)
# 5. 格式化输出
result = {}
# 遍历每个标签,并记录其预测结果和概率
for i, tag in enumerate(TAG_COLS):
# 结果只针对单个样本(批次大小为 1)
is_predicted = preds[0][i] == 1
probability = probs[0][i]
result[tag] = {
"predicted": is_predicted,
"probability": float(f"{probability:.4f}") # 保留 4 位小数
}
print("--- 5. 预测结果 ---")
# 提取所有预测为 True 的标签
predicted_tags = [tag for tag, info in result.items() if info["predicted"]]
if predicted_tags:
print(f"预测标签类别: {predicted_tags}")
print(f"对应概率:")
for tag in predicted_tags:
print(f" - {tag}: {result[tag]['probability']}")
else:
print("未预测任何标签(所有标签概率均低于 0.5)。")
print(f"所有标签的最高概率: {max(p['probability'] for p in result.values()):.4f}")
# ----------------------------------------------------
# C. 示例运行
# ----------------------------------------------------
if __name__ == "__main__":
# 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配
MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_multilabel_model_finetuned_from_dapt/final"
TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base'
# TOKENIZER = 'bert-base-chinese'
SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"
# 确保检查点目录存在
if not os.path.exists(MODEL_CHECKPOINT):
print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}")
else:
predict_multilabel(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT)