pmolchanov commited on
Commit
32d9f3b
·
verified ·
1 Parent(s): b67b129

Update modeling_hymba.py

Browse files
Files changed (1) hide show
  1. modeling_hymba.py +94 -86
modeling_hymba.py CHANGED
@@ -1679,102 +1679,110 @@ class HymbaBlock(nn.Module):
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
- # try:
1683
- if use_precomputed_states:
1684
- scan_outputs = selective_state_update(
1685
- cache_params.ssm_states[self.layer_idx],
1686
- hidden_states[..., 0],
1687
- discrete_time_step[..., 0],
1688
- A,
1689
- B[:, 0],
1690
- C[:, 0],
1691
- self.D[index],
1692
- gate[..., 0],
1693
- time_proj_bias,
1694
- dt_softplus=True,
1695
- ).unsqueeze(-1)
1696
- else:
1697
- outputs = selective_scan_fn(
1698
- hidden_states,
1699
- discrete_time_step,
1700
- A,
1701
- B.transpose(1, 2),
1702
- C.transpose(1, 2),
1703
- self.D[index].float(),
1704
- z=gate,
1705
- delta_bias=time_proj_bias,
1706
- delta_softplus=True,
1707
- return_last_state=True,
1708
- )
1709
-
1710
- if len(outputs) == 3:
1711
- scan_outputs, ssm_state, _ = outputs
1712
  else:
1713
- scan_outputs, ssm_state = outputs
1714
-
1715
- if ssm_state is not None and cache_params is not None:
1716
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
- # except Exception as e:
1718
- # print("\n\n\n\n")
1719
- # print(e)
1720
- # print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1721
- # print(f"{self.D[index]} ")
1722
- # # cache_params.ssm_states[self.layer_idx],
1723
- # # hidden_states[..., 0],
1724
- # # discrete_time_step[..., 0],
1725
- # # A,
1726
- # # B[:, 0],
1727
- # # C[:, 0],
1728
- # # self.D[index],
1729
- # # gate[..., 0],
1730
- # # time_proj_bias,
1731
- # print("=== Variable Values ===")
1732
- # try:
1733
- # print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1734
- # except Exception as e:
1735
- # print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1736
 
1737
- # try:
1738
- # print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1739
- # except Exception as e:
1740
- # print(f"Error accessing hidden_states[..., 0]: {e}")
 
1741
 
1742
- # try:
1743
- # print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1744
- # except Exception as e:
1745
- # print(f"Error accessing discrete_time_step[..., 0]: {e}")
 
1746
 
1747
- # try:
1748
- # print(f"A: {A}")
1749
- # except Exception as e:
1750
- # print(f"Error accessing A: {e}")
 
1751
 
1752
- # try:
1753
- # print(f"B[:, 0]: {B[:, 0]}")
1754
- # except Exception as e:
1755
- # print(f"Error accessing B[:, 0]: {e}")
 
1756
 
1757
- # try:
1758
- # print(f"C[:, 0]: {C[:, 0]}")
1759
- # except Exception as e:
1760
- # print(f"Error accessing C[:, 0]: {e}")
 
1761
 
1762
- # try:
1763
- # print(f"D[index]: {self.D[index]}")
1764
- # except Exception as e:
1765
- # print(f"Error accessing D[{index}]: {e}")
 
1766
 
1767
- # try:
1768
- # print(f"gate[..., 0]: {gate[..., 0]}")
1769
- # except Exception as e:
1770
- # print(f"Error accessing gate[..., 0]: {e}")
 
1771
 
1772
- # try:
1773
- # print(f"time_proj_bias: {time_proj_bias}")
1774
- # except Exception as e:
1775
- # print(f"Error accessing time_proj_bias: {e}")
1776
 
1777
- # print("\n\n\n\n")
1778
 
1779
  scan_outputs = scan_outputs.transpose(1, 2)
1780
 
 
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
+ try:
1683
+ if use_precomputed_states:
1684
+ scan_outputs = selective_state_update(
1685
+ cache_params.ssm_states[self.layer_idx],
1686
+ hidden_states[..., 0],
1687
+ discrete_time_step[..., 0],
1688
+ A,
1689
+ B[:, 0],
1690
+ C[:, 0],
1691
+ self.D[index],
1692
+ gate[..., 0],
1693
+ time_proj_bias,
1694
+ dt_softplus=True,
1695
+ ).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1696
  else:
1697
+ outputs = selective_scan_fn(
1698
+ hidden_states,
1699
+ discrete_time_step,
1700
+ A,
1701
+ B.transpose(1, 2),
1702
+ C.transpose(1, 2),
1703
+ self.D[index].float(),
1704
+ z=gate,
1705
+ delta_bias=time_proj_bias,
1706
+ delta_softplus=True,
1707
+ return_last_state=True,
1708
+ )
1709
+
1710
+ if len(outputs) == 3:
1711
+ scan_outputs, ssm_state, _ = outputs
1712
+ else:
1713
+ scan_outputs, ssm_state = outputs
1714
+
1715
+ if ssm_state is not None and cache_params is not None:
1716
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
+ except Exception as e:
1718
+ print("\n\n\n\n")
1719
+ print(e)
1720
+ print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1721
+ print(f"{self.D[index]} ")
1722
+ # cache_params.ssm_states[self.layer_idx],
1723
+ # hidden_states[..., 0],
1724
+ # discrete_time_step[..., 0],
1725
+ # A,
1726
+ # B[:, 0],
1727
+ # C[:, 0],
1728
+ # self.D[index],
1729
+ # gate[..., 0],
1730
+ # time_proj_bias,
1731
+ print("=== Variable Values ===")
1732
+ try:
1733
+ print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1734
+ print(f"{cache_params.ssm_states[self.layer_idx].shape}")
1735
+ except Exception as e:
1736
+ print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
1737
 
1738
+ try:
1739
+ print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1740
+ print(f"hidden_states[..., 0] shape: {hidden_states[..., 0].shape}")
1741
+ except Exception as e:
1742
+ print(f"Error accessing hidden_states[..., 0]: {e}")
1743
 
1744
+ try:
1745
+ print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1746
+ print(f"discrete_time_step[..., 0].shape: {discrete_time_step[..., 0].shape}")
1747
+ except Exception as e:
1748
+ print(f"Error accessing discrete_time_step[..., 0]: {e}")
1749
 
1750
+ try:
1751
+ print(f"A: {A}")
1752
+ print(f"A.shape: {A.shape}")
1753
+ except Exception as e:
1754
+ print(f"Error accessing A: {e}")
1755
 
1756
+ try:
1757
+ print(f"B[:, 0]: {B[:, 0].shape}")
1758
+ print(f"B[:, 0].shape: {B[:, 0].shape}")
1759
+ except Exception as e:
1760
+ print(f"Error accessing B[:, 0]: {e}")
1761
 
1762
+ try:
1763
+ print(f"C[:, 0]: {C[:, 0]}")
1764
+ print(f"C[:, 0].shape: {C[:, 0].shape}")
1765
+ except Exception as e:
1766
+ print(f"Error accessing C[:, 0]: {e}")
1767
 
1768
+ try:
1769
+ print(f"D[index]: {self.D[index]}")
1770
+ print(f"D[index].shape: {self.D[index].shape}")
1771
+ except Exception as e:
1772
+ print(f"Error accessing D[{index}]: {e}")
1773
 
1774
+ try:
1775
+ print(f"gate[..., 0]: {gate[..., 0]}")
1776
+ print(f"gate[..., 0].shape: {gate[..., 0].shape}")
1777
+ except Exception as e:
1778
+ print(f"Error accessing gate[..., 0]: {e}")
1779
 
1780
+ try:
1781
+ print(f"time_proj_bias: {time_proj_bias}")
1782
+ except Exception as e:
1783
+ print(f"Error accessing time_proj_bias: {e}")
1784
 
1785
+ print("\n\n\n\n")
1786
 
1787
  scan_outputs = scan_outputs.transpose(1, 2)
1788