Update mamba.py
Browse files
mamba.py
CHANGED
|
@@ -160,7 +160,7 @@ class MambaBlock(nn.Module):
|
|
| 160 |
|
| 161 |
deltaBC = self.x_proj(x)
|
| 162 |
|
| 163 |
-
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
|
| 164 |
delta = F.softplus(self.dt_proj(delta))
|
| 165 |
|
| 166 |
if self.config.pscan:
|
|
@@ -196,7 +196,7 @@ class MambaBlock(nn.Module):
|
|
| 196 |
|
| 197 |
BX = deltaB * (x.unsqueeze(-1))
|
| 198 |
|
| 199 |
-
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
|
| 200 |
hs = []
|
| 201 |
|
| 202 |
for t in range(0, L):
|
|
@@ -233,10 +233,10 @@ class MambaBlock(nn.Module):
|
|
| 233 |
z = F.silu(z)
|
| 234 |
|
| 235 |
output = y * z
|
| 236 |
-
output = self.out_proj(output)
|
| 237 |
|
| 238 |
|
| 239 |
-
inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)
|
| 240 |
cache = (h, inputs)
|
| 241 |
|
| 242 |
return output, cache
|
|
|
|
| 160 |
|
| 161 |
deltaBC = self.x_proj(x)
|
| 162 |
|
| 163 |
+
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
|
| 164 |
delta = F.softplus(self.dt_proj(delta))
|
| 165 |
|
| 166 |
if self.config.pscan:
|
|
|
|
| 196 |
|
| 197 |
BX = deltaB * (x.unsqueeze(-1))
|
| 198 |
|
| 199 |
+
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
|
| 200 |
hs = []
|
| 201 |
|
| 202 |
for t in range(0, L):
|
|
|
|
| 233 |
z = F.silu(z)
|
| 234 |
|
| 235 |
output = y * z
|
| 236 |
+
output = self.out_proj(output)
|
| 237 |
|
| 238 |
|
| 239 |
+
inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)
|
| 240 |
cache = (h, inputs)
|
| 241 |
|
| 242 |
return output, cache
|