|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
from torch import nn |
|
|
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, AutoConfig |
|
|
|
|
|
|
|
|
# Label names for the five multi-label classification targets, in the
# column order the classifier head was trained with (index i of the
# logits corresponds to TAG_COLS[i]).
TAG_COLS = ['Data', 'Action', 'Gain', 'Regu', 'Vague']

# Sigmoid-probability cutoff above which a label is considered predicted.
PREDICTION_THRESHOLD = 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertForMultiLabelClassification(BertPreTrainedModel):
    """BERT-based multi-label classifier using BCEWithLogitsLoss.

    A linear head on top of the pooled `[CLS]` output produces one
    independent logit per label; labels are scored with per-label
    sigmoid + BCE rather than softmax.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)

        # NOTE(review): uses hidden_dropout_prob directly; the config may also
        # define classifier_dropout — confirm which the training run used.
        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        # One logit per label (multi-label, so no softmax over labels).
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        # Initialize weights per the HF convention (also ties/loads as needed).
        self.post_init()

        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None):
        """Run the encoder and classifier head.

        Args:
            input_ids: Token id tensor.
            attention_mask: Attention mask tensor.
            token_type_ids: Segment id tensor.
            labels: Optional float tensor of shape (batch, num_labels) with
                0/1 multi-hot targets.

        Returns:
            `logits` when `labels` is None (inference), otherwise the tuple
            `(loss, logits)` where `loss` is the BCE-with-logits loss.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)

        if labels is not None:
            # BUG FIX: `labels` was previously accepted but ignored and
            # `self.loss_fct` was never used, so training could not compute a
            # loss. BCEWithLogitsLoss requires float targets.
            loss = self.loss_fct(logits, labels.float())
            return loss, logits

        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_multilabel(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
    """Load a model checkpoint and run multi-label prediction on one text.

    Args:
        checkpoint_path: BERT checkpoint directory (containing config.json,
            model.safetensors).
        tokenizer_path: Tokenizer path or hub name.
        text_to_predict: Input text to classify.

    Returns:
        Dict mapping each tag in TAG_COLS to
        ``{"predicted": bool, "probability": float}``, or None if loading
        the model/tokenizer failed.
    """
    print(f"--- 1. 正在加载模型和分词器: {checkpoint_path} ---")

    try:
        config = AutoConfig.from_pretrained(checkpoint_path)

        if config.num_labels != len(TAG_COLS):
            # BUG FIX: capture the original value BEFORE overwriting it, so
            # the warning actually reports old -> new instead of new -> new.
            original_num_labels = config.num_labels
            config.num_labels = len(TAG_COLS)
            print(f"警告: 检查点配置的 num_labels 已从 {original_num_labels} 修正为 {len(TAG_COLS)}")

        tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

        model = BertForMultiLabelClassification.from_pretrained(
            checkpoint_path,
            config=config
        )
    except Exception as e:
        print(f"加载模型或分词器失败,请检查路径中是否包含所有必需文件(如 model.safetensors, config.json, vocab.txt): {e}")
        return None

    # Disable dropout for deterministic inference.
    model.eval()

    inputs = tokenizer(
        text_to_predict,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    with torch.no_grad():
        # forward() returns bare logits when no labels are passed.
        outputs = model(**inputs)
        logits = outputs.cpu().numpy()

    # Sigmoid converts per-label logits to independent probabilities.
    probs = 1 / (1 + np.exp(-logits))

    preds = (probs > PREDICTION_THRESHOLD).astype(int)

    result = {}

    for i, tag in enumerate(TAG_COLS):
        is_predicted = preds[0][i] == 1
        probability = probs[0][i]

        result[tag] = {
            "predicted": is_predicted,
            # Round to 4 decimal places via string formatting.
            "probability": float(f"{probability:.4f}")
        }

    print("--- 5. 预测结果 ---")

    predicted_tags = [tag for tag, info in result.items() if info["predicted"]]

    if predicted_tags:
        print(f"预测标签类别: {predicted_tags}")
        print(f"对应概率:")
        for tag in predicted_tags:
            print(f"  - {tag}: {result[tag]['probability']}")
    else:
        print("未预测任何标签(所有标签概率均低于 0.5)。")
        print(f"所有标签的最高概率: {max(p['probability'] for p in result.values()):.4f}")

    # BUG FIX: the docstring promised a result dict but the function
    # previously returned None implicitly.
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Directory holding the fine-tuned multi-label checkpoint.
    checkpoint_dir = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_multilabel_model_finetuned_from_dapt/final"
    tokenizer_name = 'valuesimplex-ai-lab/FinBERT2-base'

    # Example finance-domain sentence to classify.
    sample_text = "密切关注安全环保对原料市场的影响,提前落实应对预案;"

    # Only attempt prediction when the checkpoint directory is present.
    if os.path.exists(checkpoint_dir):
        predict_multilabel(checkpoint_dir, tokenizer_name, sample_text)
    else:
        print(f"错误:模型检查点目录不存在: {checkpoint_dir}")