robinfaro committed on
Commit
45c846c
·
verified ·
1 Parent(s): 8311a4f

Upload moe.py

Browse files
Files changed (1) hide show
  1. moe.py +2 -1
moe.py CHANGED
@@ -86,13 +86,14 @@ class MaskedMoE(MoE):
86
 
87
 
88
  def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
 
89
  inputs_squashed = inputs.view(-1, inputs.shape[-1])
90
  router_logits = self.router(inputs_squashed)
91
  mask = torch.cat(
92
  (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
93
  dim=1
94
  )
95
- mask = mask.repeat_interleave(self._sequence_length, dim=0)
96
  router_logits = router_logits*mask
97
 
98
  # note that selected experts will be the same for all orders:
 
86
 
87
 
88
  def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
89
+ seq_len = inputs.shape[1]
90
  inputs_squashed = inputs.view(-1, inputs.shape[-1])
91
  router_logits = self.router(inputs_squashed)
92
  mask = torch.cat(
93
  (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
94
  dim=1
95
  )
96
+ mask = mask.repeat_interleave(seq_len, dim=0)
97
  router_logits = router_logits*mask
98
 
99
  # note that selected experts will be the same for all orders: