Update modeling_hymba.py
Browse files- modeling_hymba.py +3 -2
modeling_hymba.py
CHANGED
|
@@ -1679,7 +1679,7 @@ class HymbaBlock(nn.Module):
|
|
| 1679 |
A = -torch.exp(self.A_log[index].float())
|
| 1680 |
|
| 1681 |
time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
|
| 1682 |
-
|
| 1683 |
if use_precomputed_states:
|
| 1684 |
scan_outputs = selective_state_update(
|
| 1685 |
cache_params.ssm_states[self.layer_idx],
|
|
@@ -1714,7 +1714,8 @@ class HymbaBlock(nn.Module):
|
|
| 1714 |
|
| 1715 |
if ssm_state is not None and cache_params is not None:
|
| 1716 |
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
| 1717 |
-
|
|
|
|
| 1718 |
print("\n\n\n\n")
|
| 1719 |
print(e)
|
| 1720 |
print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
|
|
|
|
| 1679 |
A = -torch.exp(self.A_log[index].float())
|
| 1680 |
|
| 1681 |
time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
|
| 1682 |
+
if 1:
|
| 1683 |
if use_precomputed_states:
|
| 1684 |
scan_outputs = selective_state_update(
|
| 1685 |
cache_params.ssm_states[self.layer_idx],
|
|
|
|
| 1714 |
|
| 1715 |
if ssm_state is not None and cache_params is not None:
|
| 1716 |
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
| 1717 |
+
if use_precomputed_states:
|
| 1718 |
+
# except Exception as e:
|
| 1719 |
print("\n\n\n\n")
|
| 1720 |
print(e)
|
| 1721 |
print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
|