| """ | |
| NewsImportanceModel — 新闻重要性双头模型 | |
| 用于 torch.load("model.pt") 时需要此类定义 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModelForSequenceClassification | |
class NewsImportanceModel(nn.Module):
    """Two-headed news-importance model on top of a pretrained encoder.

    The encoder is pulled out of a Hugging Face sequence-classification
    checkpoint (only the backbone is kept) and two new heads are attached
    to the first-token hidden state:

    * ``classifier`` -- a linear layer producing ``num_bins`` logits.
    * ``regressor`` -- a small MLP ending in a sigmoid, so scores lie in
      ``[0, 1]``.

    The encoder is always stored as ``self.bert``, whatever its actual
    architecture, so parameter names in saved checkpoints stay stable.
    """

    def __init__(self, base_model_name: str = "LocalOptimum/chinese-crypto-sentiment", num_bins: int = 4):
        """Load the pretrained backbone and build both heads.

        Args:
            base_model_name: Hub id or local path handed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_bins: Number of output classes of the classification head.

        Raises:
            AttributeError: If no encoder sub-module can be located on the
                loaded model.
        """
        super().__init__()
        base = AutoModelForSequenceClassification.from_pretrained(base_model_name)
        self.bert = self._extract_encoder(base)
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_bins)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _extract_encoder(base):
        """Return the bare encoder wrapped by a sequence-classification model."""
        # Keep the historical lookup order (bert, then roberta) and only then
        # fall back to the backbone attribute name the model declares itself.
        candidates = ("bert", "roberta", getattr(base, "base_model_prefix", ""))
        for attr in candidates:
            encoder = getattr(base, attr, None) if attr else None
            if encoder is not None:
                return encoder
        raise AttributeError(
            f"{type(base).__name__} exposes no encoder sub-module "
            f"(looked for: {', '.join(a for a in candidates if a)})"
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: "torch.Tensor | None" = None,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Attention mask, passed to the encoder unchanged.
            token_type_ids: Optional segment ids; forwarded to the encoder
                only when provided.

        Returns:
            ``(logits, score)``: the classifier output and the regressor
            output with its trailing size-1 dimension squeezed away.
        """
        encoder_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Some encoders do not accept token_type_ids at all, so the keyword is
        # only sent when the caller actually supplied a value.
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_inputs)

        # First-token hidden state serves as the sequence representation.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        score = self.regressor(pooled).squeeze(-1)
        return logits, score