File size: 6,149 Bytes
d82e190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from configuration_pdeeppp import PDeepPPConfig

# Module-level logger obtained through transformers' logging utilities.
# NOTE(review): not referenced elsewhere in the visible code of this module.
logger = logging.get_logger(__name__)

class SelfAttentionGlobalFeatures(nn.Module):
    """Self-attention block followed by a two-layer projection.

    The input attends over itself (query = key = value). The attention output
    is added back as a residual and layer-normalised, then projected
    ``input_size -> hidden_size -> output_size`` with dropout between the two
    linear layers.

    Reads ``input_size``, ``num_heads``, ``hidden_size``, ``output_size`` and
    ``dropout`` from ``config``. Submodule attribute names are part of the
    checkpoint (state-dict) layout and must not be renamed.
    """

    def __init__(self, config):
        super().__init__()
        width = config.input_size
        hidden = config.hidden_size
        self.self_attention = nn.MultiheadAttention(
            embed_dim=width, num_heads=config.num_heads, batch_first=True
        )
        self.fc1 = nn.Linear(width, hidden)
        self.fc2 = nn.Linear(hidden, config.output_size)
        self.layer_norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return projected features for ``x``.

        ``x`` is assumed to be ``(batch, seq, input_size)``, because the
        attention layer is built with ``batch_first=True`` (TODO confirm).
        The output keeps the leading dimensions and has ``output_size``
        features.
        """
        attended = self.self_attention(x, x, x)[0]
        normed = self.layer_norm(x + attended)
        hidden = self.dropout(self.fc1(normed))
        return self.fc2(hidden)

class TransConv1d(nn.Module):
    """Transformer branch of PDeepPP.

    Pipeline:
      1. ``SelfAttentionGlobalFeatures`` maps ``input_size`` to ``output_size``.
      2. ``config.num_transformer_layers`` encoder layers of width
         ``output_size`` follow.
      3. ``fc1`` runs, then a residual ``fc2`` block and a LayerNorm.

    The sequence dimension is preserved.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention_global_features = SelfAttentionGlobalFeatures(config)
        # ``nn.TransformerEncoder`` clones this template layer into its own
        # ``layers``. The template itself is never called in ``forward``, but it
        # is still a registered submodule (its parameters appear in the state
        # dict), so it is kept for checkpoint compatibility.
        self.transformer_encoder = nn.TransformerEncoderLayer(
            d_model=config.output_size,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_size * 2,
            dropout=config.dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            self.transformer_encoder,
            num_layers=config.num_transformer_layers,
        )
        self.fc1 = nn.Linear(config.output_size, config.output_size)
        self.fc2 = nn.Linear(config.output_size, config.output_size)
        self.layer_norm = nn.LayerNorm(config.output_size)

    def forward(self, x):
        """Encode ``x`` and return per-position features of size ``output_size``.

        ``x`` is assumed to be ``(batch, seq, input_size)``, because every
        attention layer here uses ``batch_first=True`` (TODO confirm against
        callers).
        """
        x = self.self_attention_global_features(x)
        # NOTE(review): there is no skip connection around the transformer
        # stack. A `residual = x` that used to sit here was overwritten before
        # it was ever read, so it has been removed; only fc2 below is residual.
        x = self.transformer(x)
        x = self.fc1(x)
        # Residual block around fc2; the skip is taken from the fc1 output.
        x = self.layer_norm(self.fc2(x) + x)
        return x

class PosCNN(nn.Module):
    """CNN branch of PDeepPP that produces one feature vector per sequence.

    Steps:
      1. A kernel-3 ``Conv1d`` (``input_size -> 64`` channels) with ReLU.
      2. An optional learned additive position encoding.
      3. Global average pooling over the sequence.
      4. A final ``Linear(64, output_size)``.

    Args:
        config: object providing ``input_size`` and ``output_size``.
        use_position_encoding: if True, add a learned ``(64, input_size)``
            position table (initialised to zeros) after the convolution.
    """

    def __init__(self, config, use_position_encoding=True):
        super().__init__()
        self.use_position_encoding = use_position_encoding
        self.conv1d = nn.Conv1d(
            in_channels=config.input_size,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.relu = nn.ReLU()
        self.global_pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, config.output_size)

        if self.use_position_encoding:
            # NOTE(review): the positional axis is sized by
            # ``config.input_size`` (the embedding width), not by a maximum
            # sequence length, so sequences longer than ``input_size`` are
            # rejected in ``forward``. The shape is kept for checkpoint
            # compatibility.
            self.position_encoding = nn.Parameter(torch.zeros(64, config.input_size))

    def forward(self, x):
        """Return a ``(batch, output_size)`` summary of ``x``.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``.

        Raises:
            ValueError: if position encoding is enabled and ``seq_len`` exceeds
                the length of the position table (``config.input_size``).
        """
        # Conv1d expects channels first: (batch, input_size, seq_len).
        x = x.permute(0, 2, 1)
        x = self.relu(self.conv1d(x))  # (batch, 64, seq_len)

        if self.use_position_encoding:
            seq_len = x.size(2)
            max_len = self.position_encoding.size(1)
            if seq_len > max_len:
                # Previously this surfaced as an opaque broadcasting error.
                raise ValueError(
                    f"PosCNN position encoding covers at most {max_len} positions "
                    f"(config.input_size), got a sequence of length {seq_len}"
                )
            x = x + self.position_encoding[:, :seq_len].unsqueeze(0)

        x = self.global_pooling(x).squeeze(-1)  # (batch, 64)
        return self.fc(x)

class PDeepPPPreTrainedModel(PreTrainedModel):
    """
    Abstract base class holding the methods shared by all PDeepPP models.

    It binds the model to ``PDeepPPConfig`` and defines weight initialisation
    for the ``PreTrainedModel`` machinery.
    """
    config_class = PDeepPPConfig
    base_model_prefix = "PDeepPP"
    # NOTE(review): no module in the visible code defines a
    # ``gradient_checkpointing`` attribute or calls a checkpoint function.
    # Enabling gradient checkpointing may therefore fail or have no effect;
    # confirm before relying on it.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise the weights of ``module`` in place.

        - Linear layers get N(0, 0.02) weights and a zero bias.
        - LayerNorm layers get unit weight and a zero bias.
        - Other module types (e.g. ``Conv1d``) are left unchanged.
        - Parameters that are ``None`` are skipped.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) or bias=False leaves these as None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

class PDeepPPModel(PDeepPPPreTrainedModel):
    """PDeepPP binary classifier combining a transformer branch and a CNN branch.

    - ``TransConv1d`` produces per-position features.
    - ``PosCNN`` produces one pooled feature vector per sequence.
    - The pooled vector is broadcast along the sequence and concatenated with
      the per-position features.
    - ``cnn_layers`` reduces the result to a single logit per example.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.transformer = TransConv1d(config)
        self.cnn = PosCNN(config)
        # NOTE(review): after the first AdaptiveMaxPool1d(1) the sequence axis
        # has length 1. The second Conv1d (kernel 3, padding 1) therefore only
        # uses its centre tap.
        self.cnn_layers = nn.Sequential(
            nn.Conv1d(config.output_size * 2, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Dropout(config.dropout / 2),
            nn.Flatten(),
            nn.Linear(64, 1),
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_embeds=None,
        labels=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, config.input_size)`):
                Input embeddings. Required despite the `None` default.
            labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Binary labels for computing the classification loss.
            return_dict (`bool`, *optional*):
                Whether to return a dict. Defaults to `config.use_return_dict`.

        The loss, when `labels` is given, is
        `lambda_ * BCE + (1 - lambda_) * (lambda_ * H - lambda_ * H_c)`, where:
        - `H` is the mean binary entropy of the predicted probabilities;
        - `H_c` is the mean of `-p * log(p)`.

        Returns:
            If `return_dict`: a dict with keys `"loss"` (`None` without labels)
            and `"logits"` (shape `(batch_size,)`). Otherwise `(loss, logits)`
            when labels are given, or just `logits`.

        Raises:
            ValueError: if `input_embeds` is not provided.
        """
        if input_embeds is None:
            raise ValueError("PDeepPPModel.forward requires `input_embeds`.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Per-position features: (batch, seq_len, output_size).
        transformer_output = self.transformer(input_embeds)
        # One pooled vector per sequence, broadcast along the sequence axis:
        # (batch, output_size) -> (batch, seq_len, output_size).
        cnn_output = self.cnn(input_embeds)
        cnn_output = cnn_output.unsqueeze(1).expand(-1, transformer_output.size(1), -1)
        # Concatenate on the feature axis, then move features to channels for
        # Conv1d: (batch, 2 * output_size, seq_len).
        combined = torch.cat([transformer_output, cnn_output], dim=2)
        combined = combined.permute(0, 2, 1)
        # cnn_layers ends in Linear(64, 1): (batch, 1) -> (batch,).
        logits = self.cnn_layers(combined).squeeze(1)

        loss = None
        if labels is not None:
            lam = self.config.lambda_
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

            # Custom entropy-based regulariser on the predicted probabilities.
            probs = torch.sigmoid(logits)
            ent = -(probs * torch.log(probs + 1e-12)
                    + (1 - probs) * torch.log(1 - probs + 1e-12)).mean()
            cond_ent = -(probs * torch.log(probs + 1e-12)).mean()
            reg_loss = lam * ent - lam * cond_ent

            loss = lam * loss + (1 - lam) * reg_loss

        if return_dict:
            return {
                "loss": loss,
                "logits": logits,
            }
        return (loss, logits) if loss is not None else logits
        
# Register PDeepPPModel with the "AutoModel" auto class, so this custom code can
# be loaded through transformers' AutoModel (custom-code / trust_remote_code).
# NOTE(review): for remote-code loading, transformers generally expects sibling
# modules to be imported relatively (``from .configuration_pdeeppp import ...``).
# Confirm that the absolute ``configuration_pdeeppp`` import at the top of this
# file resolves in that setting.
PDeepPPModel.register_for_auto_class("AutoModel")