LisaMegaWatts commited on
Commit
280ed9e
·
verified ·
1 Parent(s): b1d797e

Upload symbio_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. symbio_model.py +1 -3
symbio_model.py CHANGED
@@ -176,11 +176,9 @@ class MonarchMatrix(nn.Module):
176
 
177
  def forward(self, x, causal_mask=None):
178
  B, T, D_head = x.shape
179
- M = self.realize()
180
  if causal_mask is not None:
181
  M = M * causal_mask[:T, :T]
182
- else:
183
- M = M[:T, :T]
184
  x_flat = x.permute(1, 0, 2).reshape(T, B * D_head)
185
  y_flat = M @ x_flat
186
  return y_flat.reshape(T, B, D_head).permute(1, 0, 2)
 
176
 
177
  def forward(self, x, causal_mask=None):
178
  B, T, D_head = x.shape
179
+ M = self.realize()[:T, :T]
180
  if causal_mask is not None:
181
  M = M * causal_mask[:T, :T]
 
 
182
  x_flat = x.permute(1, 0, 2).reshape(T, B * D_head)
183
  y_flat = M @ x_flat
184
  return y_flat.reshape(T, B, D_head).permute(1, 0, 2)