import os import torch import numpy as np from torch import nn from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig # 定义标签名称,与任务一致 TAG_COLS = ['Data', 'Action', 'Gain', 'Regu', 'Vague'] PREDICTION_THRESHOLD = 0.5 # 预测阈值 # ---------------------------------------------------- # A. 定义支持多标签分类的 BERT 模型(必须与训练时一致) # ---------------------------------------------------- class BertForMultiLabelClassification(BertPreTrainedModel): """ 基于 BERT 的多标签分类模型,使用 BCEWithLogitsLoss """ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels # 加载 BERT 主体 self.bert = BertModel(config) # 加载训练时的 dropout 比例 classifier_dropout = config.hidden_dropout_prob self.dropout = nn.Dropout(classifier_dropout) # 加载训练时的分类器层 self.classifier = nn.Linear(config.hidden_size, self.num_labels) self.post_init() # 注意:推理时不需要损失函数,但保持结构完整性 self.loss_fct = nn.BCEWithLogitsLoss() def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None): outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, ) # 取 [CLS] token 的隐藏状态 (即 pooler output) pooled_output = outputs.pooler_output pooled_output = self.dropout(pooled_output) # 经过分类器层,输出 logits (未经 Sigmoid 的分数) logits = self.classifier(pooled_output) # 推理时 labels 为 None,直接返回 logits return logits # ---------------------------------------------------- # B. 模型推理函数 # ---------------------------------------------------- def predict_multilabel(checkpoint_path: str, tokenizer_path: str, text_to_predict: str): """ 加载模型检查点,对单个文本进行多标签预测。 Args: checkpoint_path: BERT 模型检查点目录(包含 config.json, model.safetensors)。 tokenizer_path: 分词器路径或名称。 text_to_predict: 待预测的输入文本。 Returns: 包含预测标签和概率的字典。 """ print(f"--- 1. 正在加载模型和分词器: {checkpoint_path} ---") try: config = AutoConfig.from_pretrained(checkpoint_path) # 确保配置中的 num_labels 与实际标签数量匹配 if config.num_labels != len(TAG_COLS): # 运行时动态修正 num_labels,以防 checkpoint-config.json 里的 num_labels 不匹配 config.num_labels = len(TAG_COLS) print(f"警告: 检查点配置的 num_labels 已从 {config.num_labels} 修正为 {len(TAG_COLS)}") # 从检查点加载分词器(假设分词器文件已存在或被复制) tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path) # 使用自定义模型类加载模型权重 model = BertForMultiLabelClassification.from_pretrained( checkpoint_path, config=config # 传入更新后的 config ) except Exception as e: print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件(如 model.safetensors, config.json, vocab.txt): {e}") return None model.eval() # 切换到评估模式 (关闭 Dropout等) # 2. 文本编码 inputs = tokenizer( text_to_predict, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ) # 3. 执行推理 with torch.no_grad(): # 模型返回的是 logits outputs = model(**inputs) logits = outputs.cpu().numpy() # 移动到 CPU 并转为 numpy # 4. 后处理:Sigmoid 和 阈值 # 应用 Sigmoid 转换为概率 probs = 1 / (1 + np.exp(-logits)) # 应用阈值得到二元预测 preds = (probs > PREDICTION_THRESHOLD).astype(int) # 5. 格式化输出 result = {} # 遍历每个标签,并记录其预测结果和概率 for i, tag in enumerate(TAG_COLS): # 结果只针对单个样本(批次大小为 1) is_predicted = preds[0][i] == 1 probability = probs[0][i] result[tag] = { "predicted": is_predicted, "probability": float(f"{probability:.4f}") # 保留 4 位小数 } print("--- 5. 预测结果 ---") # 提取所有预测为 True 的标签 predicted_tags = [tag for tag, info in result.items() if info["predicted"]] if predicted_tags: print(f"预测标签类别: {predicted_tags}") print(f"对应概率:") for tag in predicted_tags: print(f" - {tag}: {result[tag]['probability']}") else: print("未预测任何标签(所有标签概率均低于 0.5)。") print(f"所有标签的最高概率: {max(p['probability'] for p in result.values()):.4f}") # ---------------------------------------------------- # C. 示例运行 # ---------------------------------------------------- if __name__ == "__main__": # 以下三个参数是需要替换的,TOKENIZER需要与MODEL匹配 MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_multilabel_model_finetuned_from_dapt/final" TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base' # TOKENIZER = 'bert-base-chinese' SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;" # 确保检查点目录存在 if not os.path.exists(MODEL_CHECKPOINT): print(f"错误:模型检查点目录不存在: {MODEL_CHECKPOINT}") else: predict_multilabel(MODEL_CHECKPOINT,TOKENIZER, SAMPLE_TEXT)