Abdullah-Nazhat commited on
Commit
e787de7
·
verified ·
1 Parent(s): d11d0a9

Update mamba.py

Browse files
Files changed (1) hide show
  1. mamba.py +4 -4
mamba.py CHANGED
@@ -160,7 +160,7 @@ class MambaBlock(nn.Module):
160
 
161
  deltaBC = self.x_proj(x)
162
 
163
- delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
164
  delta = F.softplus(self.dt_proj(delta))
165
 
166
  if self.config.pscan:
@@ -196,7 +196,7 @@ class MambaBlock(nn.Module):
196
 
197
  BX = deltaB * (x.unsqueeze(-1))
198
 
199
- h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
200
  hs = []
201
 
202
  for t in range(0, L):
@@ -233,10 +233,10 @@ class MambaBlock(nn.Module):
233
  z = F.silu(z)
234
 
235
  output = y * z
236
- output = self.out_proj(output) # (B, D)
237
 
238
 
239
- inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
240
  cache = (h, inputs)
241
 
242
  return output, cache
 
160
 
161
  deltaBC = self.x_proj(x)
162
 
163
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
164
  delta = F.softplus(self.dt_proj(delta))
165
 
166
  if self.config.pscan:
 
196
 
197
  BX = deltaB * (x.unsqueeze(-1))
198
 
199
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
200
  hs = []
201
 
202
  for t in range(0, L):
 
233
  z = F.silu(z)
234
 
235
  output = y * z
236
+ output = self.out_proj(output)
237
 
238
 
239
+ inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)
240
  cache = (h, inputs)
241
 
242
  return output, cache