pmolchanov commited on
Commit
b67b129
·
verified ·
1 Parent(s): ad8a965

Update modeling_hymba.py

Browse files
Files changed (1) hide show
  1. modeling_hymba.py +86 -86
modeling_hymba.py CHANGED
@@ -1679,102 +1679,102 @@ class HymbaBlock(nn.Module):
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
- try:
1683
- if use_precomputed_states:
1684
- scan_outputs = selective_state_update(
1685
- cache_params.ssm_states[self.layer_idx],
1686
- hidden_states[..., 0],
1687
- discrete_time_step[..., 0],
1688
- A,
1689
- B[:, 0],
1690
- C[:, 0],
1691
- self.D[index],
1692
- gate[..., 0],
1693
- time_proj_bias,
1694
- dt_softplus=True,
1695
- ).unsqueeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1696
  else:
1697
- outputs = selective_scan_fn(
1698
- hidden_states,
1699
- discrete_time_step,
1700
- A,
1701
- B.transpose(1, 2),
1702
- C.transpose(1, 2),
1703
- self.D[index].float(),
1704
- z=gate,
1705
- delta_bias=time_proj_bias,
1706
- delta_softplus=True,
1707
- return_last_state=True,
1708
- )
1709
-
1710
- if len(outputs) == 3:
1711
- scan_outputs, ssm_state, _ = outputs
1712
- else:
1713
- scan_outputs, ssm_state = outputs
1714
-
1715
- if ssm_state is not None and cache_params is not None:
1716
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
- except Exception as e:
1718
- print("\n\n\n\n")
1719
- print(e)
1720
- print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1721
- print(f"{self.D[index]} ")
1722
- # cache_params.ssm_states[self.layer_idx],
1723
- # hidden_states[..., 0],
1724
- # discrete_time_step[..., 0],
1725
- # A,
1726
- # B[:, 0],
1727
- # C[:, 0],
1728
- # self.D[index],
1729
- # gate[..., 0],
1730
- # time_proj_bias,
1731
- print("=== Variable Values ===")
1732
- try:
1733
- print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1734
- except Exception as e:
1735
- print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
1736
 
1737
- try:
1738
- print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1739
- except Exception as e:
1740
- print(f"Error accessing hidden_states[..., 0]: {e}")
1741
 
1742
- try:
1743
- print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1744
- except Exception as e:
1745
- print(f"Error accessing discrete_time_step[..., 0]: {e}")
1746
 
1747
- try:
1748
- print(f"A: {A}")
1749
- except Exception as e:
1750
- print(f"Error accessing A: {e}")
1751
 
1752
- try:
1753
- print(f"B[:, 0]: {B[:, 0]}")
1754
- except Exception as e:
1755
- print(f"Error accessing B[:, 0]: {e}")
1756
 
1757
- try:
1758
- print(f"C[:, 0]: {C[:, 0]}")
1759
- except Exception as e:
1760
- print(f"Error accessing C[:, 0]: {e}")
1761
 
1762
- try:
1763
- print(f"D[index]: {self.D[index]}")
1764
- except Exception as e:
1765
- print(f"Error accessing D[{index}]: {e}")
1766
 
1767
- try:
1768
- print(f"gate[..., 0]: {gate[..., 0]}")
1769
- except Exception as e:
1770
- print(f"Error accessing gate[..., 0]: {e}")
1771
 
1772
- try:
1773
- print(f"time_proj_bias: {time_proj_bias}")
1774
- except Exception as e:
1775
- print(f"Error accessing time_proj_bias: {e}")
1776
 
1777
- print("\n\n\n\n")
1778
 
1779
  scan_outputs = scan_outputs.transpose(1, 2)
1780
 
 
1679
  A = -torch.exp(self.A_log[index].float())
1680
 
1681
  time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
1682
+ # try:
1683
+ if use_precomputed_states:
1684
+ scan_outputs = selective_state_update(
1685
+ cache_params.ssm_states[self.layer_idx],
1686
+ hidden_states[..., 0],
1687
+ discrete_time_step[..., 0],
1688
+ A,
1689
+ B[:, 0],
1690
+ C[:, 0],
1691
+ self.D[index],
1692
+ gate[..., 0],
1693
+ time_proj_bias,
1694
+ dt_softplus=True,
1695
+ ).unsqueeze(-1)
1696
+ else:
1697
+ outputs = selective_scan_fn(
1698
+ hidden_states,
1699
+ discrete_time_step,
1700
+ A,
1701
+ B.transpose(1, 2),
1702
+ C.transpose(1, 2),
1703
+ self.D[index].float(),
1704
+ z=gate,
1705
+ delta_bias=time_proj_bias,
1706
+ delta_softplus=True,
1707
+ return_last_state=True,
1708
+ )
1709
+
1710
+ if len(outputs) == 3:
1711
+ scan_outputs, ssm_state, _ = outputs
1712
  else:
1713
+ scan_outputs, ssm_state = outputs
1714
+
1715
+ if ssm_state is not None and cache_params is not None:
1716
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
+ # except Exception as e:
1718
+ # print("\n\n\n\n")
1719
+ # print(e)
1720
+ # print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1721
+ # print(f"{self.D[index]} ")
1722
+ # # cache_params.ssm_states[self.layer_idx],
1723
+ # # hidden_states[..., 0],
1724
+ # # discrete_time_step[..., 0],
1725
+ # # A,
1726
+ # # B[:, 0],
1727
+ # # C[:, 0],
1728
+ # # self.D[index],
1729
+ # # gate[..., 0],
1730
+ # # time_proj_bias,
1731
+ # print("=== Variable Values ===")
1732
+ # try:
1733
+ # print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1734
+ # except Exception as e:
1735
+ # print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1736
 
1737
+ # try:
1738
+ # print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1739
+ # except Exception as e:
1740
+ # print(f"Error accessing hidden_states[..., 0]: {e}")
1741
 
1742
+ # try:
1743
+ # print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1744
+ # except Exception as e:
1745
+ # print(f"Error accessing discrete_time_step[..., 0]: {e}")
1746
 
1747
+ # try:
1748
+ # print(f"A: {A}")
1749
+ # except Exception as e:
1750
+ # print(f"Error accessing A: {e}")
1751
 
1752
+ # try:
1753
+ # print(f"B[:, 0]: {B[:, 0]}")
1754
+ # except Exception as e:
1755
+ # print(f"Error accessing B[:, 0]: {e}")
1756
 
1757
+ # try:
1758
+ # print(f"C[:, 0]: {C[:, 0]}")
1759
+ # except Exception as e:
1760
+ # print(f"Error accessing C[:, 0]: {e}")
1761
 
1762
+ # try:
1763
+ # print(f"D[index]: {self.D[index]}")
1764
+ # except Exception as e:
1765
+ # print(f"Error accessing D[{index}]: {e}")
1766
 
1767
+ # try:
1768
+ # print(f"gate[..., 0]: {gate[..., 0]}")
1769
+ # except Exception as e:
1770
+ # print(f"Error accessing gate[..., 0]: {e}")
1771
 
1772
+ # try:
1773
+ # print(f"time_proj_bias: {time_proj_bias}")
1774
+ # except Exception as e:
1775
+ # print(f"Error accessing time_proj_bias: {e}")
1776
 
1777
+ # print("\n\n\n\n")
1778
 
1779
  scan_outputs = scan_outputs.transpose(1, 2)
1780