pmolchanov commited on
Commit
c02a352
·
verified ·
1 Parent(s): 5f87817

Update modeling_hymba.py

Browse files
Files changed (1) hide show
  1. modeling_hymba.py +61 -61
modeling_hymba.py CHANGED
@@ -1714,76 +1714,76 @@ class HymbaBlock(nn.Module):
1714
 
1715
  if ssm_state is not None and cache_params is not None:
1716
  cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
- if use_precomputed_states and self.layer_idx==31:
1718
- # except Exception as e:
1719
- print("\n\n\n\n")
1720
- # print(e)
1721
- print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1722
- print(f"{self.D[index]} ")
1723
- # cache_params.ssm_states[self.layer_idx],
1724
- # hidden_states[..., 0],
1725
- # discrete_time_step[..., 0],
1726
- # A,
1727
- # B[:, 0],
1728
- # C[:, 0],
1729
- # self.D[index],
1730
- # gate[..., 0],
1731
- # time_proj_bias,
1732
- print("=== Variable Values ===")
1733
- try:
1734
- print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1735
- print(f"{cache_params.ssm_states[self.layer_idx].shape}")
1736
- except Exception as e:
1737
- print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
1738
 
1739
- try:
1740
- print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1741
- print(f"hidden_states[..., 0] shape: {hidden_states[..., 0].shape}")
1742
- except Exception as e:
1743
- print(f"Error accessing hidden_states[..., 0]: {e}")
1744
 
1745
- try:
1746
- print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1747
- print(f"discrete_time_step[..., 0].shape: {discrete_time_step[..., 0].shape}")
1748
- except Exception as e:
1749
- print(f"Error accessing discrete_time_step[..., 0]: {e}")
1750
 
1751
- try:
1752
- print(f"A: {A}")
1753
- print(f"A.shape: {A.shape}")
1754
- except Exception as e:
1755
- print(f"Error accessing A: {e}")
1756
 
1757
- try:
1758
- print(f"B[:, 0]: {B[:, 0].shape}")
1759
- print(f"B[:, 0].shape: {B[:, 0].shape}")
1760
- except Exception as e:
1761
- print(f"Error accessing B[:, 0]: {e}")
1762
 
1763
- try:
1764
- print(f"C[:, 0]: {C[:, 0]}")
1765
- print(f"C[:, 0].shape: {C[:, 0].shape}")
1766
- except Exception as e:
1767
- print(f"Error accessing C[:, 0]: {e}")
1768
 
1769
- try:
1770
- print(f"D[index]: {self.D[index]}")
1771
- print(f"D[index].shape: {self.D[index].shape}")
1772
- except Exception as e:
1773
- print(f"Error accessing D[{index}]: {e}")
1774
 
1775
- try:
1776
- print(f"gate[..., 0]: {gate[..., 0]}")
1777
- print(f"gate[..., 0].shape: {gate[..., 0].shape}")
1778
- except Exception as e:
1779
- print(f"Error accessing gate[..., 0]: {e}")
1780
 
1781
- try:
1782
- print(f"time_proj_bias: {time_proj_bias}")
1783
- except Exception as e:
1784
- print(f"Error accessing time_proj_bias: {e}")
1785
 
1786
- print("\n\n\n\n")
1787
 
1788
  scan_outputs = scan_outputs.transpose(1, 2)
1789
 
 
1714
 
1715
  if ssm_state is not None and cache_params is not None:
1716
  cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1717
+ # if use_precomputed_states and self.layer_idx==31:
1718
+ # # except Exception as e:
1719
+ # print("\n\n\n\n")
1720
+ # # print(e)
1721
+ # print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
1722
+ # print(f"{self.D[index]} ")
1723
+ # # cache_params.ssm_states[self.layer_idx],
1724
+ # # hidden_states[..., 0],
1725
+ # # discrete_time_step[..., 0],
1726
+ # # A,
1727
+ # # B[:, 0],
1728
+ # # C[:, 0],
1729
+ # # self.D[index],
1730
+ # # gate[..., 0],
1731
+ # # time_proj_bias,
1732
+ # print("=== Variable Values ===")
1733
+ # try:
1734
+ # print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
1735
+ # print(f"{cache_params.ssm_states[self.layer_idx].shape}")
1736
+ # except Exception as e:
1737
+ # print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
1738
 
1739
+ # try:
1740
+ # print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
1741
+ # print(f"hidden_states[..., 0] shape: {hidden_states[..., 0].shape}")
1742
+ # except Exception as e:
1743
+ # print(f"Error accessing hidden_states[..., 0]: {e}")
1744
 
1745
+ # try:
1746
+ # print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
1747
+ # print(f"discrete_time_step[..., 0].shape: {discrete_time_step[..., 0].shape}")
1748
+ # except Exception as e:
1749
+ # print(f"Error accessing discrete_time_step[..., 0]: {e}")
1750
 
1751
+ # try:
1752
+ # print(f"A: {A}")
1753
+ # print(f"A.shape: {A.shape}")
1754
+ # except Exception as e:
1755
+ # print(f"Error accessing A: {e}")
1756
 
1757
+ # try:
1758
+ # print(f"B[:, 0]: {B[:, 0].shape}")
1759
+ # print(f"B[:, 0].shape: {B[:, 0].shape}")
1760
+ # except Exception as e:
1761
+ # print(f"Error accessing B[:, 0]: {e}")
1762
 
1763
+ # try:
1764
+ # print(f"C[:, 0]: {C[:, 0]}")
1765
+ # print(f"C[:, 0].shape: {C[:, 0].shape}")
1766
+ # except Exception as e:
1767
+ # print(f"Error accessing C[:, 0]: {e}")
1768
 
1769
+ # try:
1770
+ # print(f"D[index]: {self.D[index]}")
1771
+ # print(f"D[index].shape: {self.D[index].shape}")
1772
+ # except Exception as e:
1773
+ # print(f"Error accessing D[{index}]: {e}")
1774
 
1775
+ # try:
1776
+ # print(f"gate[..., 0]: {gate[..., 0]}")
1777
+ # print(f"gate[..., 0].shape: {gate[..., 0].shape}")
1778
+ # except Exception as e:
1779
+ # print(f"Error accessing gate[..., 0]: {e}")
1780
 
1781
+ # try:
1782
+ # print(f"time_proj_bias: {time_proj_bias}")
1783
+ # except Exception as e:
1784
+ # print(f"Error accessing time_proj_bias: {e}")
1785
 
1786
+ # print("\n\n\n\n")
1787
 
1788
  scan_outputs = scan_outputs.transpose(1, 2)
1789