Upload symbio_model.py with huggingface_hub
Browse files- symbio_model.py +1 -3
symbio_model.py
CHANGED
|
@@ -176,11 +176,9 @@ class MonarchMatrix(nn.Module):
|
|
| 176 |
|
| 177 |
def forward(self, x, causal_mask=None):
|
| 178 |
B, T, D_head = x.shape
|
| 179 |
-
M = self.realize()
|
| 180 |
if causal_mask is not None:
|
| 181 |
M = M * causal_mask[:T, :T]
|
| 182 |
-
else:
|
| 183 |
-
M = M[:T, :T]
|
| 184 |
x_flat = x.permute(1, 0, 2).reshape(T, B * D_head)
|
| 185 |
y_flat = M @ x_flat
|
| 186 |
return y_flat.reshape(T, B, D_head).permute(1, 0, 2)
|
|
|
|
| 176 |
|
| 177 |
def forward(self, x, causal_mask=None):
|
| 178 |
B, T, D_head = x.shape
|
| 179 |
+
M = self.realize()[:T, :T]
|
| 180 |
if causal_mask is not None:
|
| 181 |
M = M * causal_mask[:T, :T]
|
|
|
|
|
|
|
| 182 |
x_flat = x.permute(1, 0, 2).reshape(T, B * D_head)
|
| 183 |
y_flat = M @ x_flat
|
| 184 |
return y_flat.reshape(T, B, D_head).permute(1, 0, 2)
|