Commit ebdbb00
Parent(s): 52551dd

Corrected x repeat before conv

Files changed:
- configuration_hymba.py  +3 -0
- modeling_hymba.py  +58 -24
configuration_hymba.py
CHANGED

@@ -47,6 +47,7 @@ class HymbaConfig(PretrainedConfig):
        global_attn_idx=None,
        num_mamba=1,
        pure_attn=False,
+        repeat_x_before_conv=True,
        attn_implementation_new='sdpa',
        rope_type=None,
        attn_factor=0.5,

@@ -113,6 +114,8 @@ class HymbaConfig(PretrainedConfig):

        self.pure_attn = pure_attn

+        self.repeat_x_before_conv = repeat_x_before_conv
+
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
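As a quick usage note (not part of the commit), the new field is an ordinary HymbaConfig constructor argument and can be toggled like any other. A minimal sketch, assuming configuration_hymba.py is importable and the remaining arguments keep their defaults:

# Minimal usage sketch (assumes configuration_hymba.py is on the Python path).
from configuration_hymba import HymbaConfig

config = HymbaConfig(repeat_x_before_conv=False)  # flag added by this commit; defaults to True
print(config.repeat_x_before_conv)                # -> False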
modeling_hymba.py
CHANGED

@@ -420,8 +420,23 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
    def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None):
        self.dtype = dtype
        # self.layers_block_type = config.layers_block_type
+
+        self.pure_attn = config.pure_attn
+
+        if self.pure_attn:
+            self.attn_hidden_size = config.hidden_size
+            self.intermediate_size = int(config.mamba_expand * config.hidden_size)
+        else:
+            self.attn_hidden_size = int(config.hidden_size * config.attn_factor)
+            config.attn_hidden_size = self.attn_hidden_size
+            self.intermediate_size = int(config.mamba_expand * config.hidden_size * (1-config.attn_factor))
+
+        self.xB_size = int(config.num_key_value_heads/config.num_attention_heads * self.intermediate_size)
+
+        # self.num_xb_head = self.xB_size // self.ssm_state_size
+
+
        self.has_previous_state = False # only used by mamba
-        intermediate_size = config.mamba_expand * config.hidden_size
        ssm_state_size = config.mamba_d_state
        conv_kernel_size = config.mamba_d_conv
        self.conv_states = []

@@ -439,12 +454,12 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
                if hasattr(config, 'conv_dim'):
                    conv_dim = config.conv_dim[str(i)]
                else:
-                    conv_dim =
+                    conv_dim = self.xB_size
                self.conv_states += [
                    torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]

@@ -1592,19 +1607,30 @@ class HymbaBlock(nn.Module):

        if not self.pure_attn:

+            self.repeat_x_before_conv = config.repeat_x_before_conv
            num_ssm_param = 1

            if not hasattr(config, 'conv_dim'):
                config.conv_dim = {str(i):0 for i in range(config.num_hidden_layers)}

-            self.
-
-
-
-
-
-
-
+            if self.repeat_x_before_conv:
+                self.conv1d = nn.Conv1d(
+                    in_channels=self.intermediate_size,
+                    out_channels=self.intermediate_size,
+                    bias=self.use_conv_bias,
+                    kernel_size=self.conv_kernel_size,
+                    groups=self.intermediate_size,
+                    padding=self.conv_kernel_size - 1
+                )
+            else:
+                self.conv1d = nn.Conv1d(
+                    in_channels=self.xB_size,
+                    out_channels=self.xB_size,
+                    bias=self.use_conv_bias,
+                    kernel_size=self.conv_kernel_size,
+                    groups=self.xB_size,
+                    padding=self.conv_kernel_size - 1
+                )

            config.conv_dim[str(self.layer_idx)] = self.xB_size


@@ -1724,7 +1750,7 @@ class HymbaBlock(nn.Module):

        index = 0
        # ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
-        B, C, x,
+        B, C, x, time_step = torch.split(
            hidden_states.transpose(1,2), [self.xB_size, self.intermediate_size, self.xB_size, self.time_step_rank], dim=-1
        )


@@ -1734,14 +1760,18 @@
        B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
        C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.ssm_state_size).contiguous()

-        x = rearrange(x, "b l d -> b d l")
-
-
-
+        x = rearrange(x, "b l d -> b d l").contiguous()
+
+        if self.repeat_x_before_conv:
+            # b d l
+            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.ssm_state_size)
+            x = repeat_kv(x, self.repeat_group)
+            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")

        #Run convolution
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))

+
        if use_precomputed_states:
            x = causal_conv1d_update(
                x.squeeze(-1),


@@ -1754,18 +1784,22 @@

            cache_params.mamba_past_length[self.layer_idx] += seq_len
        else:
-            if cache_params is not None:
-
-
-
-
-
-                cache_params.mamba_past_length[self.layer_idx] += seq_len
+            # if cache_params is not None:
+            #     conv_states = nn.functional.pad(
+            #         x, (self.conv_kernel_size - x.shape[-1], 0)
+            #     )
+            #     cache_params.conv_states[self.layer_idx].copy_(conv_states)
+            #     cache_params.mamba_past_length[self.layer_idx] += seq_len

            x = causal_conv1d_fn(
                x, conv_weights, self.conv1d.bias, activation=self.activation
            )

+            if not self.repeat_x_before_conv:
+                x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.ssm_state_size)
+                x = repeat_kv(x, self.repeat_group)
+                x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
+
            ## Handle padding for Mamba: Set padding tokens to 0
            if seq_len > 1 and attention_mask is not None and (attention_mask == 0).any():
                x = x * attention_mask.unsqueeze(1).to(x)


@@ -1792,7 +1826,7 @@
        if use_precomputed_states:
            scan_outputs = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
-                x,
+                x.squeeze(),
                discrete_time_step,
                A,
                B,