File size: 3,033 Bytes
2d36faf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F
from Affine import Affine

#借来一用,简单改改
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale.

    Adapted from the Qwen2 implementation (equivalent to T5LayerNorm)::

        y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)

    Unlike that implementation, the input is not upcast to float32 first, so
    the normalization is computed in the input's own dtype.

    Args:
        embedding_dim: size of the last input dimension (length of ``weight``).
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, embedding_dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(embedding_dim))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        mean_square = hidden_states.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inv_rms
        return self.weight * normalized

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
        
#针对每个词嵌入的前馈网络
class PositionWiseFeedForward(nn.Module):
    """Bias-free two-layer feed-forward network applied to the last dimension.

    Computes ``relu(w2(relu(w1(x))))``. When ``enable_affine`` is truthy, the
    input of each linear layer is first passed through its own ``Affine``
    module (``a1`` before ``w1``, ``a2`` before ``w2``).

    Args:
        embedding_dim: size of the last dimension of both input and output.
        feed_forward_dim: width of the hidden layer.
        enable_affine: whether to create and apply the ``Affine`` modules.
    """

    def __init__(self, embedding_dim, feed_forward_dim, enable_affine):
        super().__init__()
        self.w1 = nn.Linear(embedding_dim, feed_forward_dim, bias=False)
        self.w2 = nn.Linear(feed_forward_dim, embedding_dim, bias=False)
        self.enable_affine = enable_affine
        if enable_affine:
            # Affine comes from the project-local `Affine` module; 1.0 is its
            # only constructor argument here (semantics defined there).
            self.a1 = Affine(1.0)
            self.a2 = Affine(1.0)

    def forward(self, x):
        use_affine = self.enable_affine
        hidden = F.relu(self.w1(self.a1(x) if use_affine else x))
        projected = self.w2(self.a2(hidden) if use_affine else hidden)
        # NOTE(review): the standard Transformer FFN has no activation after
        # w2; this ReLU restricts the residual update to non-negative values.
        # Kept as-is — confirm it is intentional.
        return F.relu(projected)

#编码器层
class EncoderLayer(nn.Module):
    """One encoder layer: attention and feed-forward sub-layers with residuals.

    Each sub-layer output goes through dropout and is added to that sub-layer's
    input. When layer norm is enabled, a single ``Qwen2RMSNorm`` is applied
    once, after both residual additions.

    Args:
        multi_head_attention: module called as
            ``multi_head_attention(query, q_mask, query, mask_future, session_id)``
            (``query`` is passed both first and third — presumably as query and
            key/value, i.e. self-attention; confirm against the attention
            module). Must expose ``embedding_dim`` when layer norm is enabled.
        mask_future: stored and forwarded unchanged to ``multi_head_attention``.
        position_wise_feed_forward: module applied to the post-attention states.
        enable_layer_norm: if truthy, normalize the layer output with
            ``Qwen2RMSNorm``.
        dropout_rate: dropout probability applied to each sub-layer output.
    """

    def __init__(self, multi_head_attention, mask_future, position_wise_feed_forward, enable_layer_norm, dropout_rate):
        super().__init__()
        self.multi_head_attention = multi_head_attention
        self.position_wise_feed_forward = position_wise_feed_forward
        self.mask_future = mask_future
        # Plain truthiness rather than `== True` (PEP 8): any truthy flag enables the norm.
        if enable_layer_norm:
            self.layer_norm = Qwen2RMSNorm(multi_head_attention.embedding_dim)
        else:
            self.layer_norm = None

        self.dropout_layer = nn.Dropout(p=dropout_rate)

    def forward(self, query, q_mask, session_id):
        """Apply attention and feed-forward sub-layers (plus optional norm) to ``query``.

        ``q_mask`` and ``session_id`` are passed through to the attention module.
        """
        # Never use `+=` here. An in-place add would modify `query` after the
        # sub-layers have saved it for the backward pass, which breaks
        # gradient computation.
        attention_out = self.multi_head_attention(query, q_mask, query, self.mask_future, session_id)
        query = query + self.dropout_layer(attention_out)
        query = query + self.dropout_layer(self.position_wise_feed_forward(query))
        if self.layer_norm is not None:
            query = self.layer_norm(query)
        return query

#编码器
class Encoder(nn.Module):
    """Stack of encoder layers applied in sequence.

    Args:
        encoder_layers: iterable of layers, each called as
            ``layer(query, q_mask, session_id)``. An ``nn.Module`` (e.g. an
            ``nn.ModuleList``) is stored as-is. Any other iterable is
            materialized into a list. If every element is an ``nn.Module``,
            that list is wrapped in ``nn.ModuleList`` so the layers' parameters
            are registered with this module (``parameters()``, ``state_dict()``,
            ``.to()``, ``.train()/.eval()``).
    """

    def __init__(self, encoder_layers):
        super().__init__()
        if not isinstance(encoder_layers, nn.Module):
            # A plain list/tuple/generator is not tracked by nn.Module, and a
            # generator would be exhausted after the first forward pass.
            layers = list(encoder_layers)
            if all(isinstance(layer, nn.Module) for layer in layers):
                encoder_layers = nn.ModuleList(layers)
            else:
                encoder_layers = layers
        self.encoder_layers = encoder_layers

    def forward(self, query, q_mask, session_id):
        """Feed ``query`` through every layer in order and return the result."""
        for encoder_layer in self.encoder_layers:
            query = encoder_layer(query, q_mask, session_id)
        return query