import os

import numpy as np
import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    BertModel,
    BertPreTrainedModel,
    BertTokenizerFast,
    pipeline,
)

BINARY_LABELS = ['Non-Envir', 'Envir']
NUM_LABELS = 2


class BertForMultiLabelClassification(BertPreTrainedModel):
    """BERT-based multi-label classification model using BCEWithLogitsLoss."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)

        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.post_init()

        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Classify from the pooled [CLS] representation.
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)

        # Apply the BCE loss when multi-hot labels are supplied; previously
        # the loss function was instantiated but never used in forward().
        if labels is not None:
            loss = self.loss_fct(logits, labels.float())
            return loss, logits

        return logits
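

# Illustrative only: a minimal multi-label inference sketch for the class above.
# Unlike the binary path below (softmax + argmax over mutually exclusive
# classes), multi-label prediction applies a sigmoid per label and thresholds
# each probability independently. The function name, `label_names`, and the
# 0.5 threshold are assumptions for illustration, not part of the original.
def predict_multilabel_example(model, tokenizer, text, label_names, threshold=0.5):
    model.eval()
    inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs)  # forward returns raw logits when labels is None
    probs = torch.sigmoid(logits)[0]  # independent per-label probabilities
    return [name for name, p in zip(label_names, probs) if p.item() >= threshold]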


def predict_binary_classification(checkpoint_path: str, tokenizer_path: str, text_to_predict: str):
    """Load a BERT binary-classification checkpoint and classify a single text.

    Args:
        checkpoint_path: Directory of the BERT checkpoint (must contain
            config.json and model.safetensors).
        tokenizer_path: Tokenizer path or name.
        text_to_predict: Input text to classify.

    Returns:
        A dict with the predicted label and probabilities, or None on failure.
    """
    print(f"--- 1. Loading binary-classification model and tokenizer: {checkpoint_path} ---")

    try:
        config = AutoConfig.from_pretrained(checkpoint_path, num_labels=NUM_LABELS)
        tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

        model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_path,
            config=config,
            ignore_mismatched_sizes=True,
        )
    except Exception as e:
        print(f"Failed to load model or tokenizer; check that the path contains all required files: {e}")
        return None

    model.eval()

    inputs = tokenizer(
        text_to_predict,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Softmax over the two classes; take the batch's single example.
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    predicted_index = int(np.argmax(probabilities))
    predicted_label = BINARY_LABELS[predicted_index]
    predicted_prob = probabilities[predicted_index]

    print("--- 2. Prediction results ---")
    print(f"Input text: {text_to_predict}")
    print(f"Predicted class: {predicted_label}")
    print(f"Probability: {predicted_prob:.4f}")

    result = {
        'prediction': predicted_label,
        'probability': round(float(predicted_prob), 4),
        'all_probabilities': {
            BINARY_LABELS[i]: round(float(probabilities[i]), 4) for i in range(NUM_LABELS)
        },
    }
    return result
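

# Alternative, sketched for comparison: the same single-text inference via the
# transformers pipeline API. `top_k=None` asks the text-classification pipeline
# to return scores for every label rather than only the best one. Note the
# returned label names come from config.id2label, so they default to
# LABEL_0/LABEL_1 unless id2label was set before saving the checkpoint.
def predict_with_pipeline_example(checkpoint_path: str, tokenizer_path: str, text: str):
    clf = pipeline(
        "text-classification",
        model=checkpoint_path,
        tokenizer=tokenizer_path,
        top_k=None,  # return probabilities for all labels
    )
    return clf(text)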


if __name__ == "__main__":
    MODEL_CHECKPOINT = "/home/hsichen/part_time/BERT_finetune/outputs/finbert2_bilabel_finetuned_model_from_dapt/final"
    TOKENIZER = 'valuesimplex-ai-lab/FinBERT2-base'

    # Sample Chinese financial text to classify.
    SAMPLE_TEXT = "密切关注安全环保对原料市场的影响,提前落实应对预案;"

    if not os.path.exists(MODEL_CHECKPOINT):
        print(f"Error: model checkpoint directory does not exist: {MODEL_CHECKPOINT}")
    else:
        predict_binary_classification(MODEL_CHECKPOINT, TOKENIZER, SAMPLE_TEXT)