import torch
import torch.nn as nn
# ========== Configuration ==========
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 1e-4
MAX_SEQ_LEN = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D_MODEL = 512
N_HEAD = 8
NUM_LAYERS = 6
DIM_FEEDFORWARD = 2048
# ========== Model definition ==========
class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=D_MODEL, nhead=N_HEAD,
                 num_layers=NUM_LAYERS, dim_feedforward=DIM_FEEDFORWARD):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Learned positional encoding, one slot per position up to MAX_SEQ_LEN.
        self.positional_encoding = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        # src/tgt: (batch, seq_len) token ids.
        # Scale embeddings by sqrt(d_model) and add the positional encoding.
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5) + self.positional_encoding[:, :src.size(1), :]
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5) + self.positional_encoding[:, :tgt.size(1), :]
        # nn.Transformer defaults to batch_first=False, so permute to (seq_len, batch, d_model).
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        # Back to (batch, seq_len, d_model), then project to target-vocabulary logits.
        return self.fc_out(output.permute(1, 0, 2))
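
# ========== Usage sketch (illustrative; not part of the original file) ==========
# The vocabulary sizes and random token batches below are made-up placeholders,
# only meant to show the expected tensor shapes for a single forward pass.
if __name__ == "__main__":
    SRC_VOCAB_SIZE = 8000  # hypothetical source vocabulary size
    TGT_VOCAB_SIZE = 8000  # hypothetical target vocabulary size
    model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE).to(DEVICE)

    # Dummy (batch, seq_len) token-id batches.
    src = torch.randint(0, SRC_VOCAB_SIZE, (4, MAX_SEQ_LEN), device=DEVICE)
    tgt = torch.randint(0, TGT_VOCAB_SIZE, (4, MAX_SEQ_LEN), device=DEVICE)

    # Causal mask so each target position only attends to earlier target positions.
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(1)).to(DEVICE)

    logits = model(src, tgt, tgt_mask=tgt_mask)
    print(logits.shape)  # expected: torch.Size([4, 60, 8000]) -> (batch, seq_len, tgt_vocab)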