pmolchanov commited on
Commit
28ef867
·
verified ·
1 Parent(s): ffc758e

Update modeling_hymba.py

Browse files

testing error during generation

Files changed (1) hide show
  1. modeling_hymba.py +36 -33
modeling_hymba.py CHANGED
@@ -1679,40 +1679,43 @@ class HymbaBlock(nn.Module):
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
- if use_precomputed_states:
1683
- scan_outputs = selective_state_update(
1684
- cache_params.ssm_states[self.layer_idx],
1685
- hidden_states[..., 0],
1686
- discrete_time_step[..., 0],
1687
- A,
1688
- B[:, 0],
1689
- C[:, 0],
1690
- self.D[index],
1691
- gate[..., 0],
1692
- time_proj_bias,
1693
- dt_softplus=True,
1694
- ).unsqueeze(-1)
1695
- else:
1696
- outputs = selective_scan_fn(
1697
- hidden_states,
1698
- discrete_time_step,
1699
- A,
1700
- B.transpose(1, 2),
1701
- C.transpose(1, 2),
1702
- self.D[index].float(),
1703
- z=gate,
1704
- delta_bias=time_proj_bias,
1705
- delta_softplus=True,
1706
- return_last_state=True,
1707
- )
1708
-
1709
- if len(outputs) == 3:
1710
- scan_outputs, ssm_state, _ = outputs
1711
  else:
1712
- scan_outputs, ssm_state = outputs
1713
-
1714
- if ssm_state is not None and cache_params is not None:
1715
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1716
 
1717
  scan_outputs = scan_outputs.transpose(1, 2)
1718
 
 
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
+ try:
1683
+ if use_precomputed_states:
1684
+ scan_outputs = selective_state_update(
1685
+ cache_params.ssm_states[self.layer_idx],
1686
+ hidden_states[..., 0],
1687
+ discrete_time_step[..., 0],
1688
+ A,
1689
+ B[:, 0],
1690
+ C[:, 0],
1691
+ self.D[index],
1692
+ gate[..., 0],
1693
+ time_proj_bias,
1694
+ dt_softplus=True,
1695
+ ).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1696
  else:
1697
+ outputs = selective_scan_fn(
1698
+ hidden_states,
1699
+ discrete_time_step,
1700
+ A,
1701
+ B.transpose(1, 2),
1702
+ C.transpose(1, 2),
1703
+ self.D[index].float(),
1704
+ z=gate,
1705
+ delta_bias=time_proj_bias,
1706
+ delta_softplus=True,
1707
+ return_last_state=True,
1708
+ )
1709
+
1710
+ if len(outputs) == 3:
1711
+ scan_outputs, ssm_state, _ = outputs
1712
+ else:
1713
+ scan_outputs, ssm_state = outputs
1714
+
1715
+ if ssm_state is not None and cache_params is not None:
1716
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
+ except:
1718
+ print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {delta_bias} ")
1719
 
1720
  scan_outputs = scan_outputs.transpose(1, 2)
1721