asd52403 commited on
Commit
4e45761
·
1 Parent(s): ae5c4e3

add a todo

Browse files
Files changed (1) hide show
  1. inference/model.py +2 -0
inference/model.py CHANGED
@@ -1244,6 +1244,8 @@ class Expert_int(nn.Module):
1244
  self.w2 = Linear_rescale_int(layer_id, inter_dim, dim, 1, 1, torch.int32)
1245
  self.w3 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)
1246
 
 
 
1247
  def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor:
1248
  """
1249
  Forward pass for the Expert layer.
 
1244
  self.w2 = Linear_rescale_int(layer_id, inter_dim, dim, 1, 1, torch.int32)
1245
  self.w3 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32)
1246
 
1247
+
1248
+ # todo: add row id in the forward function
1249
  def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor:
1250
  """
1251
  Forward pass for the Expert layer.