# Uploaded by LocalOptimum: "Upload chinese-crypto-importance v1.0" (commit f7e81b6, verified)
"""
NewsImportanceModel — 新闻重要性双头模型
用于 torch.load("model.pt") 时需要此类定义
"""
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification
class NewsImportanceModel(nn.Module):
    """Dual-head model on top of a FinBERT-style encoder: bin classification + regression.

    The encoder is taken from a sequence-classification checkpoint (its own
    classification head is not used).  Two new heads read the hidden state at
    sequence position 0:

    * ``classifier`` -- a linear layer producing ``num_bins`` logits.
    * ``regressor``  -- a small MLP ending in a sigmoid, so the score is in (0, 1).

    NOTE(review): the module docstring says this class is needed for
    ``torch.load("model.pt")``, i.e. the checkpoint is a pickled module.
    Unpickling executes arbitrary code; only load files from a trusted source.
    """

    def __init__(
        self,
        base_model_name: str = "LocalOptimum/chinese-crypto-sentiment",
        num_bins: int = 4,
        *,
        dropout: float = 0.1,
        regressor_hidden: int = 128,
    ) -> None:
        """Build the encoder and both heads.

        Args:
            base_model_name: Name or path passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            num_bins: Number of output classes of the classification head.
            dropout: Dropout probability used before the heads and inside the
                regression head.
            regressor_hidden: Width of the hidden layer in the regression head.
        """
        super().__init__()
        base = AutoModelForSequenceClassification.from_pretrained(base_model_name)

        # Keep only the encoder body.  BERT / RoBERTa are looked up by name as
        # before; any other architecture falls back to the generic
        # ``base_model`` accessor instead of failing with AttributeError.
        if hasattr(base, "bert"):
            encoder = base.bert
        elif hasattr(base, "roberta"):
            encoder = base.roberta
        else:
            encoder = base.base_model

        # The attribute is always called ``bert`` (whatever the architecture)
        # so parameter names in saved checkpoints stay the same.
        self.bert = encoder
        hidden_size = self.bert.config.hidden_size

        # Creation order of the layers is kept (dropout, classifier, regressor)
        # so seeded initialisation is unchanged.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_bins)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, regressor_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(regressor_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids for the encoder (assumed ``(batch, seq_len)``
                -- TODO confirm against the tokenizer output).
            attention_mask: Attention mask matching ``input_ids``.
            token_type_ids: Optional segment ids.  Forwarded to the encoder
                only when given, so encoders without this argument still work.

        Returns:
            A ``(logits, score)`` tuple: ``logits`` is the classifier output
            with ``num_bins`` values per example, ``score`` is the sigmoid
            regression output with its trailing size-1 dimension squeezed away.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Omitting the argument is equivalent to passing None for encoders that
        # accept it, and avoids a TypeError for encoders that do not.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.bert(**encoder_kwargs)

        # Hidden state at sequence position 0 serves as the sentence summary.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)
        score = self.regressor(pooled).squeeze(-1)
        return logits, score