Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from Affine import Affine | |
| # Projector | |
class Projector(nn.Module):
    """Bias-free linear projection, optionally followed by an ``Affine`` layer.

    Args:
        embedding_dim: Input feature size of the linear layer.
        vocab_out_size: Output feature size of the linear layer.
        enable_affine: When truthy, an ``Affine(1.0)`` module is created and
            applied to the projected output in ``forward``.
    """

    def __init__(self, embedding_dim, vocab_out_size, enable_affine):
        super().__init__()
        self.enable_affine = enable_affine
        self.project = nn.Linear(embedding_dim, vocab_out_size, bias=False)
        # The affine submodule only exists when requested, so the module's
        # parameter set differs between the two configurations.
        # NOTE(review): Affine is defined in another module; the meaning of
        # its 1.0 argument is not visible here -- confirm against Affine.
        if enable_affine:
            self.affine = Affine(1.0)

    def forward(self, x):
        """Project ``x`` and, if enabled, pass the result through ``affine``.

        Args:
            x: Input accepted by ``nn.Linear`` (last dimension is assumed to
                be ``embedding_dim`` -- TODO confirm with callers).

        Returns:
            The projected output, post-processed by ``self.affine`` when
            ``enable_affine`` is truthy.
        """
        projected = self.project(x)
        return self.affine(projected) if self.enable_affine else projected
| # Generator | |
class Generator(nn.Module):
    """Chain an embedding, an encoder and a projector into one module.

    Args:
        embedding: Callable applied to the raw inputs.
        encoder: Callable invoked as ``encoder(embedded, q_mask)``.
        projector: Callable applied to the encoder output.
    """

    def __init__(self, embedding, encoder, projector):
        super().__init__()
        # Tag stored on the instance; how it is consumed is not visible in
        # this file.
        self.model_type = "generator"
        self.embedding = embedding
        self.encoder = encoder
        self.projector = projector

    def forward(self, inputs, q_mask):
        """Run inputs through embedding -> encoder -> projector.

        Args:
            inputs: Passed directly to ``self.embedding`` (presumably token
                ids -- verify against the caller).
            q_mask: Forwarded unchanged to the encoder as its second
                argument; its exact semantics are defined by the encoder.

        Returns:
            Whatever ``self.projector`` returns for the encoded inputs.
        """
        embedded = self.embedding(inputs)
        encoded = self.encoder(embedded, q_mask)
        projected = self.projector(encoded)
        return projected