import os import torch import numpy as np from torch import nn from transformers import AutoModelForSequenceClassification, BertTokenizerFast, AutoConfig, pipeline, BertPreTrainedModel, BertModel # 定义标签名称,与任务一致 BINARY_LABELS = ['Non-Envir', 'Envir'] NUM_LABELS = 2 # ---------------------------------------------------- # A. 定义支持多标签分类的 BERT 模型(必须与训练时一致) # ---------------------------------------------------- class BertForMultiLabelClassification(BertPreTrainedModel): """ 基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss """ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels # 加载 BERT 主体 self.bert = BertModel(config) # 加载训练时的 dropout 比例 classifier_dropout = config.hidden_dropout_prob self.dropout = nn.Dropout(classifier_dropout) # 加载训练时的分类器层 self.classifier = nn.Linear(config.hidden_size, self.num_labels) self.post_init() # 注意:推理时不需要损失函数,但保持结构完整性 self.loss_fct = nn.BCEWithLogitsLoss() def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None): outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, ) # 取 [CLS] token 的隐藏状态 (即 pooler output) pooled_output = outputs.pooler_output pooled_output = self.dropout(pooled_output) # 经过分类器层,输出 logits (未经 Sigmoid 的分数) logits = self.classifier(pooled_output) # 推理时 labels 为 None,直接返回 logits return logits # ---------------------------------------------------- # B. 模型推理函数 # ---------------------------------------------------- def predict_binary_classification(checkpoint_path: str, tokenizer_path: str, text_to_predict: str): """ 加载 BERT 二分类模型检查点,对单个文本进行二分类预测。 Args: checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。 tokenizer_path: 分词器路径或名称。 text_to_predict: 待预测的输入文本。 Returns: 包含预测标签和概率的字典。 """ print(f"--- 1. 正在加载二分类模型和分词器: {checkpoint_path} ---") try: # 1. 加载配置和分词器 config = AutoConfig.from_pretrained(checkpoint_path, num_labels=NUM_LABELS) tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path) # 2. 使用标准的 AutoModelForSequenceClassification 加载模型 # 这将自动处理模型加载和分类头维度不匹配的问题 model = AutoModelForSequenceClassification.from_pretrained( checkpoint_path, config=config, ignore_mismatched_sizes=True # 容忍加载时的分类头尺寸不匹配 ) except Exception as e: print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件: {e}") return None model.eval() # 切换到评估模式 # 3. 文本编码 inputs = tokenizer( text_to_predict, padding=True, truncation=True, max_length=512, return_tensors="pt" ) # 4. 执行推理 with torch.no_grad(): # 模型返回的是 Logits (维度通常是 [1, 2]) outputs = model(**inputs) logits = outputs.logits # 获取 Logits # 应用 Softmax 转换为概率分布 probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0] # 确定预测的类别索引 (0 或 1) predicted_index = np.argmax(probabilities) # 5. 格式化输出 # 预测的类别名称 predicted_label = BINARY_LABELS[predicted_index] # 预测类别的概率 predicted_prob = probabilities[predicted_index] # 打印结果 print("--- 5. 预测结果 ---") print(f"输入文本: {text_to_predict}") print(f"预测类别: {predicted_label}") print(f"对应概率: {predicted_prob:.4f}") # 返回所有类别的概率 result = { 'prediction': predicted_label, 'probability': float(f"{predicted_prob:.4f}"), 'all_probabilities': { BINARY_LABELS[i]: float(f"{probabilities[i]:.4f}") for i in range(NUM_LABELS) } } return result # ---------------------------------------------------- # C. 示例运行 # ---------------------------------------------------- if __name__ == "__main__": # 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配 MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_bilabel_finetuned_model_from_dapt/final" TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base' # TOKENIZER = 'bert-base-chinese' SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;" # 确保检查点目录存在 if not os.path.exists(MODEL_CHECKPOINT): print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}") else: predict_binary_classification(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT)