File size: 5,521 Bytes
fc9ae4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import torch
import numpy as np
from torch import nn
from transformers import AutoModelForSequenceClassification, BertTokenizerFast, AutoConfig, pipeline, BertPreTrainedModel, BertModel
# 定义标签名称,与任务一致
BINARY_LABELS = ['Non-Envir', 'Envir']
NUM_LABELS = 2
# ----------------------------------------------------
# A. 定义支持多标签分类的 BERT 模型(必须与训练时一致)
# ----------------------------------------------------
class BertForMultiLabelClassification(BertPreTrainedModel):
"""
基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss
"""
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# 加载 BERT 主体
self.bert = BertModel(config)
# 加载训练时的 dropout 比例
classifier_dropout = config.hidden_dropout_prob
self.dropout = nn.Dropout(classifier_dropout)
# 加载训练时的分类器层
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
self.post_init()
# 注意:推理时不需要损失函数,但保持结构完整性
self.loss_fct = nn.BCEWithLogitsLoss()
def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# 取 [CLS] token 的隐藏状态 (即 pooler output)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
# 经过分类器层,输出 logits (未经 Sigmoid 的分数)
logits = self.classifier(pooled_output)
# 推理时 labels 为 None,直接返回 logits
return logits
# ----------------------------------------------------
# B. 模型推理函数
# ----------------------------------------------------
def predict_binary_classification(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
"""
加载 BERT 二分类模型检查点,对单个文本进行二分类预测。
Args:
checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。
tokenizer_path: 分词器路径或名称。
text_to_predict: 待预测的输入文本。
Returns:
包含预测标签和概率的字典。
"""
print(f"--- 1. 正在加载二分类模型和分词器: {checkpoint_path} ---")
try:
# 1. 加载配置和分词器
config = AutoConfig.from_pretrained(checkpoint_path, num_labels=NUM_LABELS)
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
# 2. 使用标准的 AutoModelForSequenceClassification 加载模型
# 这将自动处理模型加载和分类头维度不匹配的问题
model = AutoModelForSequenceClassification.from_pretrained(
checkpoint_path,
config=config,
ignore_mismatched_sizes=True # 容忍加载时的分类头尺寸不匹配
)
except Exception as e:
print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件: {e}")
return None
model.eval() # 切换到评估模式
# 3. 文本编码
inputs = tokenizer(
text_to_predict,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
# 4. 执行推理
with torch.no_grad():
# 模型返回的是 Logits (维度通常是 [1, 2])
outputs = model(**inputs)
logits = outputs.logits # 获取 Logits
# 应用 Softmax 转换为概率分布
probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
# 确定预测的类别索引 (0 或 1)
predicted_index = np.argmax(probabilities)
# 5. 格式化输出
# 预测的类别名称
predicted_label = BINARY_LABELS[predicted_index]
# 预测类别的概率
predicted_prob = probabilities[predicted_index]
# 打印结果
print("--- 5. 预测结果 ---")
print(f"输入文本: {text_to_predict}")
print(f"预测类别: {predicted_label}")
print(f"对应概率: {predicted_prob:.4f}")
# 返回所有类别的概率
result = {
'prediction': predicted_label,
'probability': float(f"{predicted_prob:.4f}"),
'all_probabilities': {
BINARY_LABELS[i]: float(f"{probabilities[i]:.4f}") for i in range(NUM_LABELS)
}
}
return result
# ----------------------------------------------------
# C. 示例运行
# ----------------------------------------------------
if __name__ == "__main__":
# 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配
MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_bilabel_finetuned_model_from_dapt/final"
TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base'
# TOKENIZER = 'bert-base-chinese'
SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"
# 确保检查点目录存在
if not os.path.exists(MODEL_CHECKPOINT):
print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}")
else:
predict_binary_classification(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT) |