File size: 1,432 Bytes
f7e81b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""

NewsImportanceModel — 新闻重要性双头模型

用于 torch.load("model.pt") 时需要此类定义

"""

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification


class NewsImportanceModel(nn.Module):
    """Dual-head importance model on top of a FinBERT-style encoder.

    Attaches two heads to a pretrained sequence-classification backbone:
    a ``num_bins``-way importance classifier and a sigmoid-bounded
    regressor producing a scalar score in (0, 1).

    This class must be importable when the checkpoint is restored via
    ``torch.load("model.pt")`` (pickle needs the class definition).
    """

    def __init__(self, base_model_name: str = "LocalOptimum/chinese-crypto-sentiment", num_bins: int = 4):
        """Load the pretrained backbone and attach both heads.

        Args:
            base_model_name: Hugging Face model id or local path of the
                pretrained sequence-classification checkpoint.
            num_bins: Number of bins for the classification head.
        """
        super().__init__()
        base = AutoModelForSequenceClassification.from_pretrained(base_model_name)
        # `base_model` resolves to the bare encoder for ANY architecture
        # (bert, roberta, electra, ...), unlike the previous
        # bert/roberta-only attribute probing, which raised AttributeError
        # for other backbones. The attribute is still named `self.bert`
        # so existing torch.load checkpoints keep mapping correctly.
        self.bert = base.base_model
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        # Head 1: num_bins-way importance classification logits.
        self.classifier = nn.Linear(hidden_size, num_bins)
        # Head 2: scalar importance score squashed into (0, 1) by the
        # final Sigmoid.
        self.regressor = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode the batch and run both heads.

        Args:
            input_ids: Token id tensor, presumably (batch, seq) — standard
                transformers input.
            attention_mask: Padding mask matching ``input_ids``.
            token_type_ids: Optional segment ids; ``None`` is accepted by
                the underlying encoder.

        Returns:
            Tuple ``(logits, score)``: classification logits of shape
            ``(batch, num_bins)`` and a regression score of shape
            ``(batch,)`` in (0, 1).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the [CLS] (first-token) hidden state as the pooled
        # sentence representation for both heads.
        pooled = outputs.last_hidden_state[:, 0]
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        # squeeze(-1): (batch, 1) -> (batch,)
        score = self.regressor(pooled).squeeze(-1)
        return logits, score