# Jerry-v0.01-0.18B / make_model.py
import copy
import torch
import torch.nn as nn
from Embedding import Embedding
from MultiHeadAttention import MultiHeadAttention
from DiffMultiHeadAttention import DiffMultiHeadAttention
from Encoder import PositionWiseFeedForward, EncoderLayer, Encoder
from Generator import Projector, Generator
def make_model(vocab_size, embedding_dim, key_dim, head_number, position_information_type,
               enable_affine, enable_talking_head, use_diff, self_attention_block_size,
               feed_forward_dim, enable_layer_norm, deep, dropout_rate):
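    """Assemble an autoregressive Transformer language model.

    Builds the token embedding, `deep` stacked self-attention encoder layers
    (differential attention if `use_diff` is set, standard multi-head attention
    otherwise) with future masking, and an output projector, then wires them
    into a Generator and applies Xavier initialization to matrix-shaped parameters.
    """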
    # Embedding layer
embedding = Embedding(
vocab_size = vocab_size,
embedding_dim = embedding_dim,
enable_affine = enable_affine,
position_information_type = position_information_type,
dropout_rate = dropout_rate
)
    # Multi-head self-attention layer (differential attention when use_diff is set)
if use_diff:
Attention = DiffMultiHeadAttention
else:
Attention = MultiHeadAttention
multi_head_attention = Attention(
embedding_dim = embedding_dim,
key_dim = key_dim,
head_number = head_number,
position_information_type = position_information_type,
enable_affine = enable_affine,
enable_talking_head = enable_talking_head,
self_attention_block_size = self_attention_block_size,
        cross_attention_block_size = 0,  # pure self-attention here, so this is unused
dropout_rate = dropout_rate
)
    # Position-wise feed-forward network for mixing information
position_wise_feed_forward = PositionWiseFeedForward(
embedding_dim = embedding_dim,
feed_forward_dim = feed_forward_dim,
enable_affine = enable_affine
)
    # Encoder layer
encoder_layer = EncoderLayer(
multi_head_attention = copy.deepcopy(multi_head_attention),
        mask_future = True,  # causal self-attention: future positions are always masked
position_wise_feed_forward = copy.deepcopy(position_wise_feed_forward),
enable_layer_norm = enable_layer_norm,
dropout_rate = dropout_rate
)
    # Stack encoder layers to form the encoder
encoder_layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(deep)])
encoder = Encoder(encoder_layers = encoder_layers)
    # Projector: maps hidden states of size embedding_dim to vocabulary logits
projector = Projector(
embedding_dim = embedding_dim,
vocab_out_size = vocab_size,
enable_affine = enable_affine
)
    # The generator model itself (embedding -> encoder -> projector)
model = Generator(
embedding = embedding,
encoder = encoder,
projector = projector
)
    # Initialize model parameters
for p in model.parameters():
        # Biases and affine (1-D) parameters keep their default initialization
        # Only matrix-shaped (2-D) parameters are randomly initialized
if p.dim() == 2:
nn.init.xavier_uniform_(p)
return model
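

# Minimal usage sketch. Note: every hyperparameter value below is an illustrative
# assumption, not the released Jerry-v0.01-0.18B configuration, and "rotary" is
# only a guess at a valid position_information_type.
if __name__ == "__main__":
    model = make_model(
        vocab_size = 32000,
        embedding_dim = 512,
        key_dim = 64,
        head_number = 8,
        position_information_type = "rotary",  # assumption: accepted values depend on Embedding/attention
        enable_affine = True,
        enable_talking_head = False,
        use_diff = False,
        self_attention_block_size = 0,         # assumption: 0 disables block-wise attention
        feed_forward_dim = 2048,
        enable_layer_norm = True,
        deep = 6,
        dropout_rate = 0.1,
    )
    # Report the total parameter count of the assembled model
    print(sum(p.numel() for p in model.parameters()), "parameters")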