Update modeling_hymba.py
Browse files- modeling_hymba.py +94 -86
modeling_hymba.py
CHANGED
|
@@ -1679,102 +1679,110 @@ class HymbaBlock(nn.Module):
|
|
| 1679 |
A = -torch.exp(self.A_log[index].float())
|
| 1680 |
|
| 1681 |
time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
|
| 1682 |
-
|
| 1683 |
-
|
| 1684 |
-
|
| 1685 |
-
|
| 1686 |
-
|
| 1687 |
-
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
|
| 1691 |
-
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
else:
|
| 1697 |
-
outputs = selective_scan_fn(
|
| 1698 |
-
hidden_states,
|
| 1699 |
-
discrete_time_step,
|
| 1700 |
-
A,
|
| 1701 |
-
B.transpose(1, 2),
|
| 1702 |
-
C.transpose(1, 2),
|
| 1703 |
-
self.D[index].float(),
|
| 1704 |
-
z=gate,
|
| 1705 |
-
delta_bias=time_proj_bias,
|
| 1706 |
-
delta_softplus=True,
|
| 1707 |
-
return_last_state=True,
|
| 1708 |
-
)
|
| 1709 |
-
|
| 1710 |
-
if len(outputs) == 3:
|
| 1711 |
-
scan_outputs, ssm_state, _ = outputs
|
| 1712 |
else:
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
|
| 1721 |
-
|
| 1722 |
-
|
| 1723 |
-
|
| 1724 |
-
|
| 1725 |
-
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
|
| 1730 |
-
|
| 1731 |
-
|
| 1732 |
-
|
| 1733 |
-
|
| 1734 |
-
|
| 1735 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1736 |
|
| 1737 |
-
|
| 1738 |
-
|
| 1739 |
-
|
| 1740 |
-
|
|
|
|
| 1741 |
|
| 1742 |
-
|
| 1743 |
-
|
| 1744 |
-
|
| 1745 |
-
|
|
|
|
| 1746 |
|
| 1747 |
-
|
| 1748 |
-
|
| 1749 |
-
|
| 1750 |
-
|
|
|
|
| 1751 |
|
| 1752 |
-
|
| 1753 |
-
|
| 1754 |
-
|
| 1755 |
-
|
|
|
|
| 1756 |
|
| 1757 |
-
|
| 1758 |
-
|
| 1759 |
-
|
| 1760 |
-
|
|
|
|
| 1761 |
|
| 1762 |
-
|
| 1763 |
-
|
| 1764 |
-
|
| 1765 |
-
|
|
|
|
| 1766 |
|
| 1767 |
-
|
| 1768 |
-
|
| 1769 |
-
|
| 1770 |
-
|
|
|
|
| 1771 |
|
| 1772 |
-
|
| 1773 |
-
|
| 1774 |
-
|
| 1775 |
-
|
| 1776 |
|
| 1777 |
-
|
| 1778 |
|
| 1779 |
scan_outputs = scan_outputs.transpose(1, 2)
|
| 1780 |
|
|
|
|
| 1679 |
A = -torch.exp(self.A_log[index].float())
|
| 1680 |
|
| 1681 |
time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
|
| 1682 |
+
try:
|
| 1683 |
+
if use_precomputed_states:
|
| 1684 |
+
scan_outputs = selective_state_update(
|
| 1685 |
+
cache_params.ssm_states[self.layer_idx],
|
| 1686 |
+
hidden_states[..., 0],
|
| 1687 |
+
discrete_time_step[..., 0],
|
| 1688 |
+
A,
|
| 1689 |
+
B[:, 0],
|
| 1690 |
+
C[:, 0],
|
| 1691 |
+
self.D[index],
|
| 1692 |
+
gate[..., 0],
|
| 1693 |
+
time_proj_bias,
|
| 1694 |
+
dt_softplus=True,
|
| 1695 |
+
).unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1696 |
else:
|
| 1697 |
+
outputs = selective_scan_fn(
|
| 1698 |
+
hidden_states,
|
| 1699 |
+
discrete_time_step,
|
| 1700 |
+
A,
|
| 1701 |
+
B.transpose(1, 2),
|
| 1702 |
+
C.transpose(1, 2),
|
| 1703 |
+
self.D[index].float(),
|
| 1704 |
+
z=gate,
|
| 1705 |
+
delta_bias=time_proj_bias,
|
| 1706 |
+
delta_softplus=True,
|
| 1707 |
+
return_last_state=True,
|
| 1708 |
+
)
|
| 1709 |
+
|
| 1710 |
+
if len(outputs) == 3:
|
| 1711 |
+
scan_outputs, ssm_state, _ = outputs
|
| 1712 |
+
else:
|
| 1713 |
+
scan_outputs, ssm_state = outputs
|
| 1714 |
+
|
| 1715 |
+
if ssm_state is not None and cache_params is not None:
|
| 1716 |
+
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
| 1717 |
+
except Exception as e:
|
| 1718 |
+
print("\n\n\n\n")
|
| 1719 |
+
print(e)
|
| 1720 |
+
print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
|
| 1721 |
+
print(f"{self.D[index]} ")
|
| 1722 |
+
# cache_params.ssm_states[self.layer_idx],
|
| 1723 |
+
# hidden_states[..., 0],
|
| 1724 |
+
# discrete_time_step[..., 0],
|
| 1725 |
+
# A,
|
| 1726 |
+
# B[:, 0],
|
| 1727 |
+
# C[:, 0],
|
| 1728 |
+
# self.D[index],
|
| 1729 |
+
# gate[..., 0],
|
| 1730 |
+
# time_proj_bias,
|
| 1731 |
+
print("=== Variable Values ===")
|
| 1732 |
+
try:
|
| 1733 |
+
print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
|
| 1734 |
+
print(f"{cache_params.ssm_states[self.layer_idx].shape}")
|
| 1735 |
+
except Exception as e:
|
| 1736 |
+
print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
|
| 1737 |
|
| 1738 |
+
try:
|
| 1739 |
+
print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
|
| 1740 |
+
print(f"hidden_states[..., 0] shape: {hidden_states[..., 0].shape}")
|
| 1741 |
+
except Exception as e:
|
| 1742 |
+
print(f"Error accessing hidden_states[..., 0]: {e}")
|
| 1743 |
|
| 1744 |
+
try:
|
| 1745 |
+
print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
|
| 1746 |
+
print(f"discrete_time_step[..., 0].shape: {discrete_time_step[..., 0].shape}")
|
| 1747 |
+
except Exception as e:
|
| 1748 |
+
print(f"Error accessing discrete_time_step[..., 0]: {e}")
|
| 1749 |
|
| 1750 |
+
try:
|
| 1751 |
+
print(f"A: {A}")
|
| 1752 |
+
print(f"A.shape: {A.shape}")
|
| 1753 |
+
except Exception as e:
|
| 1754 |
+
print(f"Error accessing A: {e}")
|
| 1755 |
|
| 1756 |
+
try:
|
| 1757 |
+
print(f"B[:, 0]: {B[:, 0].shape}")
|
| 1758 |
+
print(f"B[:, 0].shape: {B[:, 0].shape}")
|
| 1759 |
+
except Exception as e:
|
| 1760 |
+
print(f"Error accessing B[:, 0]: {e}")
|
| 1761 |
|
| 1762 |
+
try:
|
| 1763 |
+
print(f"C[:, 0]: {C[:, 0]}")
|
| 1764 |
+
print(f"C[:, 0].shape: {C[:, 0].shape}")
|
| 1765 |
+
except Exception as e:
|
| 1766 |
+
print(f"Error accessing C[:, 0]: {e}")
|
| 1767 |
|
| 1768 |
+
try:
|
| 1769 |
+
print(f"D[index]: {self.D[index]}")
|
| 1770 |
+
print(f"D[index].shape: {self.D[index].shape}")
|
| 1771 |
+
except Exception as e:
|
| 1772 |
+
print(f"Error accessing D[{index}]: {e}")
|
| 1773 |
|
| 1774 |
+
try:
|
| 1775 |
+
print(f"gate[..., 0]: {gate[..., 0]}")
|
| 1776 |
+
print(f"gate[..., 0].shape: {gate[..., 0].shape}")
|
| 1777 |
+
except Exception as e:
|
| 1778 |
+
print(f"Error accessing gate[..., 0]: {e}")
|
| 1779 |
|
| 1780 |
+
try:
|
| 1781 |
+
print(f"time_proj_bias: {time_proj_bias}")
|
| 1782 |
+
except Exception as e:
|
| 1783 |
+
print(f"Error accessing time_proj_bias: {e}")
|
| 1784 |
|
| 1785 |
+
print("\n\n\n\n")
|
| 1786 |
|
| 1787 |
scan_outputs = scan_outputs.transpose(1, 2)
|
| 1788 |
|