File size: 5,521 Bytes
fc9ae4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import torch
import numpy as np
from torch import nn
from transformers import AutoModelForSequenceClassification, BertTokenizerFast, AutoConfig, pipeline, BertPreTrainedModel, BertModel

# 定义标签名称,与任务一致
BINARY_LABELS = ['Non-Envir', 'Envir'] 
NUM_LABELS = 2

# ----------------------------------------------------
# A. 定义支持多标签分类的 BERT 模型(必须与训练时一致)
# ----------------------------------------------------
class BertForMultiLabelClassification(BertPreTrainedModel):
    """
    基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # 加载 BERT 主体
        self.bert = BertModel(config)
        
        # 加载训练时的 dropout 比例
        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        
        # 加载训练时的分类器层
        self.classifier = nn.Linear(config.hidden_size, self.num_labels) 

        self.post_init()
        # 注意:推理时不需要损失函数,但保持结构完整性
        self.loss_fct = nn.BCEWithLogitsLoss() 

    def forward(self, 
                input_ids=None, 
                attention_mask=None, 
                token_type_ids=None, 
                labels=None):
        
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        
        # 取 [CLS] token 的隐藏状态 (即 pooler output)
        pooled_output = outputs.pooler_output 
        pooled_output = self.dropout(pooled_output)
        
        # 经过分类器层,输出 logits (未经 Sigmoid 的分数)
        logits = self.classifier(pooled_output)

        # 推理时 labels 为 None,直接返回 logits
        return logits


# ----------------------------------------------------
# B. 模型推理函数
# ----------------------------------------------------
def predict_binary_classification(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
    """
    加载 BERT 二分类模型检查点,对单个文本进行二分类预测。

    Args:
        checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。
        tokenizer_path: 分词器路径或名称。
        text_to_predict: 待预测的输入文本。

    Returns:
        包含预测标签和概率的字典。
    """
    print(f"--- 1. 正在加载二分类模型和分词器: {checkpoint_path} ---")
    
    try:
        # 1. 加载配置和分词器
        config = AutoConfig.from_pretrained(checkpoint_path, num_labels=NUM_LABELS)
        tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
        
        # 2. 使用标准的 AutoModelForSequenceClassification 加载模型
        # 这将自动处理模型加载和分类头维度不匹配的问题
        model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_path, 
            config=config,
            ignore_mismatched_sizes=True # 容忍加载时的分类头尺寸不匹配
        )
    except Exception as e:
        print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件: {e}")
        return None

    model.eval() # 切换到评估模式
    
    # 3. 文本编码
    inputs = tokenizer(
        text_to_predict, 
        padding=True, 
        truncation=True, 
        max_length=512, 
        return_tensors="pt"
    )
    
    # 4. 执行推理
    with torch.no_grad():
        # 模型返回的是 Logits (维度通常是 [1, 2])
        outputs = model(**inputs)
        logits = outputs.logits # 获取 Logits
        
        # 应用 Softmax 转换为概率分布
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0] 
        
        # 确定预测的类别索引 (0 或 1)
        predicted_index = np.argmax(probabilities)

    # 5. 格式化输出
    
    # 预测的类别名称
    predicted_label = BINARY_LABELS[predicted_index]
    # 预测类别的概率
    predicted_prob = probabilities[predicted_index]
    
    # 打印结果
    print("--- 5. 预测结果 ---")
    print(f"输入文本: {text_to_predict}")
    print(f"预测类别: {predicted_label}")
    print(f"对应概率: {predicted_prob:.4f}")
    
    # 返回所有类别的概率
    result = {
        'prediction': predicted_label,
        'probability': float(f"{predicted_prob:.4f}"),
        'all_probabilities': {
            BINARY_LABELS[i]: float(f"{probabilities[i]:.4f}") for i in range(NUM_LABELS)
        }
    }
    return result


# ----------------------------------------------------
# C. 示例运行
# ----------------------------------------------------
if __name__ == "__main__":
    # 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配
    MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_bilabel_finetuned_model_from_dapt/final" 
    TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base'
    # TOKENIZER = 'bert-base-chinese'
    SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"
    
    # 确保检查点目录存在
    if not os.path.exists(MODEL_CHECKPOINT):
        print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}")
    else:
        predict_binary_classification(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT)