Maxtimer97 committed
Commit ebdbb00 · 1 Parent(s): 52551dd

Corrected x repeat before conv

Files changed (2):
  1. configuration_hymba.py +3 -0
  2. modeling_hymba.py +58 -24
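
What the change does: x is split off with grouped channels (xB_size) and expanded to the full intermediate_size with repeat_kv; the commit adds a repeat_x_before_conv flag and sizes the depthwise conv1d to intermediate_size when the repeat happens before the convolution, or to xB_size when it happens after, so the conv always sees the channel count x actually has at that point. A minimal shape-arithmetic sketch (all sizes are illustrative placeholders, and repeat_group is assumed to be num_attention_heads // num_key_value_heads, which the diff itself does not show):

# Shape arithmetic behind the fix; numbers are placeholders, not Hymba's real config values.
hidden_size = 1024
mamba_expand = 2
attn_factor = 0.5
num_attention_heads = 16
num_key_value_heads = 4

intermediate_size = int(mamba_expand * hidden_size * (1 - attn_factor))        # 1024
xB_size = int(num_key_value_heads / num_attention_heads * intermediate_size)   # 256
repeat_group = num_attention_heads // num_key_value_heads                      # 4 (assumed definition)

# Repeating x's grouped channels repeat_group times yields intermediate_size channels,
# so the depthwise conv needs intermediate_size channels when the repeat runs first
# and xB_size channels when it runs after the conv.
assert xB_size * repeat_group == intermediate_size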
configuration_hymba.py CHANGED
@@ -47,6 +47,7 @@ class HymbaConfig(PretrainedConfig):
         global_attn_idx=None,
         num_mamba=1,
         pure_attn=False,
+        repeat_x_before_conv=True,
         attn_implementation_new='sdpa',
         rope_type=None,
         attn_factor=0.5,
@@ -113,6 +114,8 @@ class HymbaConfig(PretrainedConfig):

         self.pure_attn = pure_attn

+        self.repeat_x_before_conv = repeat_x_before_conv
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
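
A minimal usage sketch of the new configuration flag, assuming HymbaConfig can be imported from this repo's configuration_hymba module and that the other constructor arguments can be left at their defaults:

from configuration_hymba import HymbaConfig

# repeat_x_before_conv defaults to True (repeat x before the conv, conv sized to
# intermediate_size); False keeps the conv at xB_size channels and moves the
# repeat to after the convolution.
config = HymbaConfig(repeat_x_before_conv=False)
print(config.repeat_x_before_conv)  # False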
modeling_hymba.py CHANGED
@@ -420,8 +420,23 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
     def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None):
         self.dtype = dtype
         # self.layers_block_type = config.layers_block_type
+
+        self.pure_attn = config.pure_attn
+
+        if self.pure_attn:
+            self.attn_hidden_size = config.hidden_size
+            self.intermediate_size = int(config.mamba_expand * config.hidden_size)
+        else:
+            self.attn_hidden_size = int(config.hidden_size * config.attn_factor)
+            config.attn_hidden_size = self.attn_hidden_size
+            self.intermediate_size = int(config.mamba_expand * config.hidden_size * (1-config.attn_factor))
+
+        self.xB_size = int(config.num_key_value_heads/config.num_attention_heads * self.intermediate_size)
+
+        # self.num_xb_head = self.xB_size // self.ssm_state_size
+
+
         self.has_previous_state = False # only used by mamba
-        intermediate_size = config.mamba_expand * config.hidden_size
         ssm_state_size = config.mamba_d_state
         conv_kernel_size = config.mamba_d_conv
         self.conv_states = []
@@ -439,12 +454,12 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
                 if hasattr(config, 'conv_dim'):
                     conv_dim = config.conv_dim[str(i)]
                 else:
-                    conv_dim = intermediate_size
+                    conv_dim = self.xB_size
                 self.conv_states += [
                     torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype)
                 ]
                 self.ssm_states += [
-                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+                    torch.zeros(batch_size, self.intermediate_size, ssm_state_size, device=device, dtype=dtype)
                 ]
             else:
                 self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
@@ -1592,19 +1607,30 @@ class HymbaBlock(nn.Module):

         if not self.pure_attn:

+            self.repeat_x_before_conv = config.repeat_x_before_conv
             num_ssm_param = 1

             if not hasattr(config, 'conv_dim'):
                 config.conv_dim = {str(i):0 for i in range(config.num_hidden_layers)}

-            self.conv1d = nn.Conv1d(
-                in_channels=self.xB_size,
-                out_channels=self.xB_size,
-                bias=self.use_conv_bias,
-                kernel_size=self.conv_kernel_size,
-                groups=self.xB_size,
-                padding=self.conv_kernel_size - 1
-            )
+            if self.repeat_x_before_conv:
+                self.conv1d = nn.Conv1d(
+                    in_channels=self.intermediate_size,
+                    out_channels=self.intermediate_size,
+                    bias=self.use_conv_bias,
+                    kernel_size=self.conv_kernel_size,
+                    groups=self.intermediate_size,
+                    padding=self.conv_kernel_size - 1
+                )
+            else:
+                self.conv1d = nn.Conv1d(
+                    in_channels=self.xB_size,
+                    out_channels=self.xB_size,
+                    bias=self.use_conv_bias,
+                    kernel_size=self.conv_kernel_size,
+                    groups=self.xB_size,
+                    padding=self.conv_kernel_size - 1
+                )

             config.conv_dim[str(self.layer_idx)] = self.xB_size

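The two nn.Conv1d constructions above differ only in their channel count. A small sketch of why that matters (sizes are placeholders, not the model's defaults): with groups equal to in_channels the convolution is depthwise, so it only accepts inputs whose channel dimension matches exactly, which is why the layer is built with intermediate_size channels when x is repeated before the conv and with xB_size channels otherwise.

import torch
import torch.nn as nn

intermediate_size, xB_size = 1024, 256   # placeholder sizes
conv_kernel_size, seq_len, batch = 4, 8, 2

# groups == in_channels == out_channels makes the Conv1d depthwise: one kernel per
# channel, so the layer only accepts inputs with exactly that many channels.
conv_before = nn.Conv1d(intermediate_size, intermediate_size, conv_kernel_size,
                        groups=intermediate_size, padding=conv_kernel_size - 1)
conv_after = nn.Conv1d(xB_size, xB_size, conv_kernel_size,
                       groups=xB_size, padding=conv_kernel_size - 1)

x_repeated = torch.randn(batch, intermediate_size, seq_len)  # x after repeat_kv
x_grouped = torch.randn(batch, xB_size, seq_len)             # x before repeat_kv

print(conv_before(x_repeated)[..., :seq_len].shape)  # torch.Size([2, 1024, 8])
print(conv_after(x_grouped)[..., :seq_len].shape)    # torch.Size([2, 256, 8])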
@@ -1724,7 +1750,7 @@

         index = 0
         # ssm_parameters = self.x_proj[index](hidden_states.transpose(1, 2))
-        B, C, x, timestep = torch.split(
+        B, C, x, time_step = torch.split(
             hidden_states.transpose(1,2), [self.xB_size, self.intermediate_size, self.xB_size, self.time_step_rank], dim=-1
         )

@@ -1734,14 +1760,18 @@
         B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
         C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.ssm_state_size).contiguous()

-        x = rearrange(x, "b l d -> b d l")
-        x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.ssm_state_size)
-        x = repeat_kv(x, self.repeat_group)
-        x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
+        x = rearrange(x, "b l d -> b d l").contiguous()
+
+        if self.repeat_x_before_conv:
+            # b d l
+            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.ssm_state_size)
+            x = repeat_kv(x, self.repeat_group)
+            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")

         #Run convolution
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))

+
         if use_precomputed_states:
             x = causal_conv1d_update(
                 x.squeeze(-1),
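For reference, a self-contained sketch of the rearrange → repeat_kv → rearrange round trip used in the hunk above; the repeat_kv shown here is the standard grouped-KV expand-and-reshape (the model's own helper may differ in detail), and the sizes are placeholders:

import torch
from einops import rearrange

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, n_group, seq_len, dstate) -> (batch, n_group * n_rep, seq_len, dstate)
    batch, n_group, seq_len, dstate = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_group, n_rep, seq_len, dstate)
    return hidden_states.reshape(batch, n_group * n_rep, seq_len, dstate)

batch, seq_len, dstate, n_group, repeat_group = 2, 8, 64, 4, 4
x = torch.randn(batch, n_group * dstate, seq_len)  # (b, xB_size, l) with xB_size = 256

x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=dstate)
x = repeat_kv(x, repeat_group)
x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
print(x.shape)  # torch.Size([2, 1024, 8]), i.e. (b, intermediate_size, l)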
@@ -1754,18 +1784,22 @@

            cache_params.mamba_past_length[self.layer_idx] += seq_len
        else:
-            if cache_params is not None:
-                conv_states = nn.functional.pad(
-                    x, (self.conv_kernel_size - x.shape[-1], 0)
-                )
-
-                cache_params.conv_states[self.layer_idx].copy_(conv_states)
-                cache_params.mamba_past_length[self.layer_idx] += seq_len
+            # if cache_params is not None:
+            #     conv_states = nn.functional.pad(
+            #         x, (self.conv_kernel_size - x.shape[-1], 0)
+            #     )
+            #     cache_params.conv_states[self.layer_idx].copy_(conv_states)
+            #     cache_params.mamba_past_length[self.layer_idx] += seq_len

            x = causal_conv1d_fn(
                x, conv_weights, self.conv1d.bias, activation=self.activation
            )

+            if not self.repeat_x_before_conv:
+                x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.ssm_state_size)
+                x = repeat_kv(x, self.repeat_group)
+                x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
+
        ## Handle padding for Mamba: Set padding tokens to 0
        if seq_len > 1 and attention_mask is not None and (attention_mask == 0).any():
            x = x * attention_mask.unsqueeze(1).to(x)
@@ -1792,7 +1826,7 @@
        if use_precomputed_states:
            scan_outputs = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
-                x,
+                x.squeeze(),
                discrete_time_step,
                A,
                B,
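
A closing shape check on the cache entries touched by this commit (placeholder sizes): ssm_states are now allocated with intermediate_size channels, matching x after the repeat, while conv_states fall back to xB_size when config.conv_dim has not been populated.

import torch

batch, xB_size, intermediate_size = 2, 256, 1024   # placeholder sizes
conv_kernel_size, ssm_state_size = 4, 64

conv_state = torch.zeros(batch, xB_size, conv_kernel_size)         # default conv_dim = self.xB_size
ssm_state = torch.zeros(batch, intermediate_size, ssm_state_size)  # always self.intermediate_size
print(conv_state.shape, ssm_state.shape)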