Mishamq committed on
Commit
9e08f93
·
verified ·
1 Parent(s): 3d0420f

Upload modeling_hybridna.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_hybridna.py +27 -163
modeling_hybridna.py CHANGED
@@ -35,8 +35,8 @@ from transformers.modeling_attn_mask_utils import (
35
  AttentionMaskConverter,
36
  )
37
  from transformers.modeling_outputs import (
38
- MoeCausalLMOutputWithPast,
39
- MoeModelOutputWithPast,
40
  SequenceClassifierOutputWithPast,
41
  )
42
  from transformers.modeling_utils import PreTrainedModel
@@ -90,85 +90,6 @@ logger = logging.get_logger(__name__)
90
  _CONFIG_FOR_DOC = "HybriDNAConfig"
91
 
92
 
93
- # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
94
- def load_balancing_loss_func(
95
- router_logits: torch.Tensor,
96
- num_experts: torch.Tensor = None,
97
- top_k=2,
98
- attention_mask: Optional[torch.Tensor] = None,
99
- ) -> float:
100
- r"""
101
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
102
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
103
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
104
- experts is too unbalanced.
105
- Args:
106
- router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
107
- Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
108
- shape [batch_size X sequence_length, num_experts].
109
- attention_mask (`torch.Tensor`, None):
110
- The attention_mask used in forward function
111
- shape [batch_size X sequence_length] if not None.
112
- num_experts (`int`, *optional*):
113
- Number of experts
114
- Returns:
115
- The auxiliary loss.
116
- """
117
- if router_logits is None or not isinstance(router_logits, tuple):
118
- return 0
119
-
120
- if isinstance(router_logits, tuple):
121
- compute_device = router_logits[0].device
122
- concatenated_router_logits = torch.cat(
123
- [layer_router.to(compute_device) for layer_router in router_logits], dim=0
124
- )
125
-
126
- routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)
127
-
128
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
129
-
130
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
131
-
132
- if attention_mask is None:
133
- # Compute the percentage of tokens routed to each experts
134
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
135
-
136
- # Compute the average probability of routing to these experts
137
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
138
- else:
139
- batch_size, sequence_length = attention_mask.shape
140
- num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)
141
-
142
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
143
- expert_attention_mask = (
144
- attention_mask[None, :, :, None, None]
145
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
146
- .reshape(-1, top_k, num_experts)
147
- .to(compute_device)
148
- )
149
-
150
- # Compute the percentage of tokens routed to each experts
151
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
152
- expert_attention_mask, dim=0
153
- )
154
-
155
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
156
- router_per_expert_attention_mask = (
157
- attention_mask[None, :, :, None]
158
- .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
159
- .reshape(-1, num_experts)
160
- .to(compute_device)
161
- )
162
-
163
- # Compute the average probability of routing to these experts
164
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
165
- router_per_expert_attention_mask, dim=0
166
- )
167
-
168
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
169
- return overall_loss * num_experts
170
-
171
-
172
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
173
  def _get_unpad_data(attention_mask):
174
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -1221,9 +1142,7 @@ class HybriDNAMLP(nn.Module):
1221
  class HybriDNAAttentionDecoderLayer(nn.Module):
1222
  def __init__(self, config: HybriDNAConfig, layer_idx: int):
1223
  super().__init__()
1224
- # Remove MoE support: always use vanilla MLP
1225
  self.self_attn = HYBRIDNA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1226
- # Previously: ffn_layer_class = HybriDNASparseMoeBlock if num_experts > 1 else HybriDNAMLP
1227
  self.feed_forward = HybriDNAMLP(config)
1228
  self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1229
  self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1231,10 +1150,9 @@ class HybriDNAAttentionDecoderLayer(nn.Module):
1231
  def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
1232
  position_ids: Optional[torch.LongTensor] = None,
1233
  past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1234
- output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False,
1235
  use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
1236
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
1237
- # ...existing code before feed-forward...
1238
  residual = hidden_states
1239
  hidden_states = self.input_layernorm(hidden_states)
1240
 
@@ -1251,28 +1169,24 @@ class HybriDNAAttentionDecoderLayer(nn.Module):
1251
  # residual connection after attention
1252
  hidden_states = residual + hidden_states
1253
 
1254
- # feed-forward (experts/MLP)
1255
  residual = hidden_states
1256
  hidden_states = self.pre_ff_layernorm(hidden_states)
1257
- # Remove MoE tuple-check: directly compute feed-forward output
1258
  hidden_states = self.feed_forward(hidden_states)
1259
  hidden_states = residual + hidden_states
1260
 
1261
  outputs = (hidden_states,)
1262
  if output_attentions:
1263
- outputs += (self_attn_weights,) # remains unchanged from attention branch
1264
  if use_cache:
1265
  outputs += (present_key_value,)
1266
- # Remove router_logits branch completely
1267
  return outputs
1268
 
1269
 
1270
  class HybriDNAMambaDecoderLayer(nn.Module):
1271
  def __init__(self, config: HybriDNAConfig, layer_idx: int):
1272
  super().__init__()
1273
- # Remove MoE support: always use vanilla MLP
1274
  self.mamba = HybriDNAMamba2Mixer(config=config, layer_idx=layer_idx)
1275
- # Previously: ffn_layer_class = HybriDNASparseMoeBlock if num_experts > 1 else HybriDNAMLP
1276
  self.feed_forward = HybriDNAMLP(config)
1277
  self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1278
  self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1280,10 +1194,9 @@ class HybriDNAMambaDecoderLayer(nn.Module):
1280
  def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
1281
  position_ids: Optional[torch.LongTensor] = None,
1282
  past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1283
- output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False,
1284
  use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
1285
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
1286
- # ...existing code before feed-forward...
1287
  residual = hidden_states
1288
 
1289
  hidden_states = self.input_layernorm(hidden_states)
@@ -1297,10 +1210,9 @@ class HybriDNAMambaDecoderLayer(nn.Module):
1297
  # residual connection after mamba
1298
  hidden_states = residual + hidden_states
1299
 
1300
- # feed-forward (experts/MLP)
1301
  residual = hidden_states
1302
  hidden_states = self.pre_ff_layernorm(hidden_states)
1303
- # Remove MoE tuple-check: directly compute feed-forward output
1304
  hidden_states = self.feed_forward(hidden_states)
1305
  hidden_states = residual + hidden_states
1306
 
@@ -1309,11 +1221,10 @@ class HybriDNAMambaDecoderLayer(nn.Module):
1309
  outputs += (self_attn_weights,)
1310
  if use_cache:
1311
  outputs += (past_key_value,)
1312
- # Remove router_logits branch completely
1313
  return outputs
1314
 
1315
 
1316
- JAMBA_START_DOCSTRING = r"""
1317
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1318
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1319
  etc.)
@@ -1330,7 +1241,7 @@ JAMBA_START_DOCSTRING = r"""
1330
 
1331
  @add_start_docstrings(
1332
  "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
1333
- JAMBA_START_DOCSTRING,
1334
  )
1335
  class HybriDNAPreTrainedModel(PreTrainedModel):
1336
  config_class = HybriDNAConfig
@@ -1354,7 +1265,7 @@ class HybriDNAPreTrainedModel(PreTrainedModel):
1354
  module.weight.data[module.padding_idx].zero_()
1355
 
1356
 
1357
- JAMBA_INPUTS_DOCSTRING = r"""
1358
  Args:
1359
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1360
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -1404,9 +1315,6 @@ JAMBA_INPUTS_DOCSTRING = r"""
1404
  output_hidden_states (`bool`, *optional*):
1405
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1406
  more detail.
1407
- output_router_logits (`bool`, *optional*):
1408
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
1409
- should not be returned during inference.
1410
  return_dict (`bool`, *optional*):
1411
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1412
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
@@ -1420,9 +1328,8 @@ ALL_DECODER_LAYER_TYPES = {"attention": HybriDNAAttentionDecoderLayer, "mamba":
1420
 
1421
  @add_start_docstrings(
1422
  "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
1423
- JAMBA_START_DOCSTRING,
1424
  )
1425
- # Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->HybriDNA
1426
  class HybriDNAModel(HybriDNAPreTrainedModel):
1427
  """
1428
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HybriDNADecoderLayer`]
@@ -1455,7 +1362,7 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1455
  def set_input_embeddings(self, value):
1456
  self.embed_tokens = value
1457
 
1458
- @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
1459
  def forward(
1460
  self,
1461
  input_ids: torch.LongTensor = None,
@@ -1466,14 +1373,10 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1466
  use_cache: Optional[bool] = None,
1467
  output_attentions: Optional[bool] = None,
1468
  output_hidden_states: Optional[bool] = None,
1469
- output_router_logits: Optional[bool] = None,
1470
  return_dict: Optional[bool] = None,
1471
  cache_position: Optional[torch.LongTensor] = None,
1472
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1473
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1474
- output_router_logits = (
1475
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1476
- )
1477
  output_hidden_states = (
1478
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1479
  )
@@ -1512,7 +1415,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1512
 
1513
  all_hidden_states = () if output_hidden_states else None
1514
  all_self_attns = () if output_attentions else None
1515
- all_router_logits = () if output_router_logits else None
1516
 
1517
  for decoder_layer in self.layers:
1518
  if output_hidden_states:
@@ -1526,7 +1428,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1526
  position_ids,
1527
  past_key_values,
1528
  output_attentions,
1529
- output_router_logits,
1530
  use_cache,
1531
  cache_position,
1532
  )
@@ -1537,7 +1438,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1537
  position_ids=position_ids,
1538
  past_key_value=past_key_values,
1539
  output_attentions=output_attentions,
1540
- output_router_logits=output_router_logits,
1541
  use_cache=use_cache,
1542
  cache_position=cache_position,
1543
  )
@@ -1549,11 +1449,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1549
  # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1550
  all_self_attns += (layer_outputs[1],)
1551
 
1552
- if output_router_logits:
1553
- if layer_outputs[-1] is not None:
1554
- # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
1555
- all_router_logits += (layer_outputs[-1],)
1556
-
1557
  hidden_states = self.final_layernorm(hidden_states)
1558
 
1559
  # add hidden states from the last decoder layer
@@ -1570,15 +1465,14 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1570
  if not return_dict:
1571
  return tuple(
1572
  v
1573
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1574
  if v is not None
1575
  )
1576
- return MoeModelOutputWithPast(
1577
  last_hidden_state=hidden_states,
1578
  past_key_values=next_cache,
1579
  hidden_states=all_hidden_states,
1580
  attentions=all_self_attns,
1581
- router_logits=all_router_logits,
1582
  )
1583
 
1584
  def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
@@ -1617,7 +1511,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
1617
  return causal_mask
1618
 
1619
 
1620
- # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->HybriDNA
1621
  class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1622
  _tied_weights_keys = ["lm_head.weight"]
1623
 
@@ -1626,9 +1519,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1626
  self.model = HybriDNAModel(config)
1627
  self.vocab_size = config.vocab_size
1628
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1629
- self.router_aux_loss_coef = config.router_aux_loss_coef
1630
- self.num_experts = config.num_experts
1631
- self.num_experts_per_tok = config.num_experts_per_tok
1632
  # Initialize weights and apply final processing
1633
  self.post_init()
1634
 
@@ -1650,9 +1540,8 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1650
  def get_decoder(self):
1651
  return self.model
1652
 
1653
- @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
1654
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1655
- # Ignore copy
1656
  def forward(
1657
  self,
1658
  input_ids: torch.LongTensor = None,
@@ -1664,11 +1553,10 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1664
  use_cache: Optional[bool] = None,
1665
  output_attentions: Optional[bool] = None,
1666
  output_hidden_states: Optional[bool] = None,
1667
- output_router_logits: Optional[bool] = None,
1668
  return_dict: Optional[bool] = None,
1669
  cache_position: Optional[torch.LongTensor] = None,
1670
  num_logits_to_keep: Optional[Union[int, None]] = None,
1671
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1672
  r"""
1673
  Args:
1674
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1682,22 +1570,17 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1682
  Returns:
1683
  Example:
1684
  ```python
1685
- >>> from transformers import AutoTokenizer, HybriDNAForCausalLM
1686
- >>> model = HybriDNAForCausalLM.from_pretrained("ai21labs/HybriDNA-v0.1")
1687
- >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/HybriDNA-v0.1")
1688
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1689
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1690
  >>> # Generate
1691
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1692
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1693
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1694
  ```"""
1695
 
1696
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1697
- output_router_logits = (
1698
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
1699
- )
1700
-
1701
  output_hidden_states = (
1702
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1703
  )
@@ -1713,7 +1596,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1713
  use_cache=use_cache,
1714
  output_attentions=output_attentions,
1715
  output_hidden_states=output_hidden_states,
1716
- output_router_logits=output_router_logits,
1717
  cache_position=cache_position,
1718
  return_dict=return_dict,
1719
  )
@@ -1738,31 +1620,16 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1738
  shift_labels = shift_labels.to(shift_logits.device)
1739
  loss = loss_fct(shift_logits, shift_labels)
1740
 
1741
- aux_loss = None
1742
- if output_router_logits:
1743
- aux_loss = load_balancing_loss_func(
1744
- outputs.router_logits if return_dict else outputs[-1],
1745
- self.num_experts,
1746
- self.num_experts_per_tok,
1747
- attention_mask,
1748
- )
1749
- if labels is not None:
1750
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1751
-
1752
  if not return_dict:
1753
  output = (logits,) + outputs[1:]
1754
- if output_router_logits:
1755
- output = (aux_loss,) + output
1756
  return (loss,) + output if loss is not None else output
1757
 
1758
- return MoeCausalLMOutputWithPast(
1759
  loss=loss,
1760
- aux_loss=aux_loss,
1761
  logits=logits,
1762
  past_key_values=outputs.past_key_values,
1763
  hidden_states=outputs.hidden_states,
1764
  attentions=outputs.attentions,
1765
- router_logits=outputs.router_logits,
1766
  )
1767
 
1768
  def prepare_inputs_for_generation(
@@ -1771,7 +1638,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1771
  past_key_values=None,
1772
  attention_mask=None,
1773
  inputs_embeds=None,
1774
- output_router_logits=False,
1775
  cache_position=None,
1776
  **kwargs,
1777
  ):
@@ -1827,7 +1693,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1827
  "past_key_values": past_key_values,
1828
  "use_cache": kwargs.get("use_cache"),
1829
  "attention_mask": attention_mask,
1830
- "output_router_logits": output_router_logits,
1831
  "num_logits_to_keep": self.config.num_logits_to_keep,
1832
  "cache_position": cache_position,
1833
  }
@@ -1846,9 +1711,8 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1846
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1847
  each row of the batch).
1848
  """,
1849
- JAMBA_START_DOCSTRING,
1850
  )
1851
- # Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->HybriDNA, MIXTRAL->JAMBA
1852
  class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
1853
  def __init__(self, config):
1854
  super().__init__(config)
@@ -1865,7 +1729,7 @@ class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
1865
  def set_input_embeddings(self, value):
1866
  self.model.embed_tokens = value
1867
 
1868
- @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
1869
  def forward(
1870
  self,
1871
  input_ids: torch.LongTensor = None,
@@ -1963,7 +1827,7 @@ class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
1963
  The input sequence is concatenated with its reverse complement before being processed by the model.
1964
  [`HybriDNAForSequenceClassificationRCEcho`]
1965
  """,
1966
- JAMBA_START_DOCSTRING,
1967
  )
1968
  class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
1969
  def __init__(self, config):
@@ -2005,7 +1869,7 @@ class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
2005
  rc = torch.flip(rc, dims=[1])
2006
  return rc
2007
 
2008
- @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
2009
  def forward(
2010
  self,
2011
  input_ids: torch.LongTensor = None,
 
35
  AttentionMaskConverter,
36
  )
37
  from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
  SequenceClassifierOutputWithPast,
41
  )
42
  from transformers.modeling_utils import PreTrainedModel
 
90
  _CONFIG_FOR_DOC = "HybriDNAConfig"
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
94
  def _get_unpad_data(attention_mask):
95
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
1142
  class HybriDNAAttentionDecoderLayer(nn.Module):
1143
  def __init__(self, config: HybriDNAConfig, layer_idx: int):
1144
  super().__init__()
 
1145
  self.self_attn = HYBRIDNA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
1146
  self.feed_forward = HybriDNAMLP(config)
1147
  self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1148
  self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1150
  def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
1151
  position_ids: Optional[torch.LongTensor] = None,
1152
  past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1153
+ output_attentions: Optional[bool] = False,
1154
  use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
1155
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
1156
  residual = hidden_states
1157
  hidden_states = self.input_layernorm(hidden_states)
1158
 
 
1169
  # residual connection after attention
1170
  hidden_states = residual + hidden_states
1171
 
1172
+ # feed-forward
1173
  residual = hidden_states
1174
  hidden_states = self.pre_ff_layernorm(hidden_states)
 
1175
  hidden_states = self.feed_forward(hidden_states)
1176
  hidden_states = residual + hidden_states
1177
 
1178
  outputs = (hidden_states,)
1179
  if output_attentions:
1180
+ outputs += (self_attn_weights,)
1181
  if use_cache:
1182
  outputs += (present_key_value,)
 
1183
  return outputs
1184
 
1185
 
1186
  class HybriDNAMambaDecoderLayer(nn.Module):
1187
  def __init__(self, config: HybriDNAConfig, layer_idx: int):
1188
  super().__init__()
 
1189
  self.mamba = HybriDNAMamba2Mixer(config=config, layer_idx=layer_idx)
 
1190
  self.feed_forward = HybriDNAMLP(config)
1191
  self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1192
  self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
1194
  def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
1195
  position_ids: Optional[torch.LongTensor] = None,
1196
  past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1197
+ output_attentions: Optional[bool] = False,
1198
  use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
1199
  ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
 
1200
  residual = hidden_states
1201
 
1202
  hidden_states = self.input_layernorm(hidden_states)
 
1210
  # residual connection after mamba
1211
  hidden_states = residual + hidden_states
1212
 
1213
+ # feed-forward
1214
  residual = hidden_states
1215
  hidden_states = self.pre_ff_layernorm(hidden_states)
 
1216
  hidden_states = self.feed_forward(hidden_states)
1217
  hidden_states = residual + hidden_states
1218
 
 
1221
  outputs += (self_attn_weights,)
1222
  if use_cache:
1223
  outputs += (past_key_value,)
 
1224
  return outputs
1225
 
1226
 
1227
+ HYBRIDNA_START_DOCSTRING = r"""
1228
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1229
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1230
  etc.)
 
1241
 
1242
  @add_start_docstrings(
1243
  "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
1244
+ HYBRIDNA_START_DOCSTRING,
1245
  )
1246
  class HybriDNAPreTrainedModel(PreTrainedModel):
1247
  config_class = HybriDNAConfig
 
1265
  module.weight.data[module.padding_idx].zero_()
1266
 
1267
 
1268
+ HYBRIDNA_INPUTS_DOCSTRING = r"""
1269
  Args:
1270
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1271
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
 
1315
  output_hidden_states (`bool`, *optional*):
1316
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1317
  more detail.
 
 
 
1318
  return_dict (`bool`, *optional*):
1319
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1320
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
 
1328
 
1329
  @add_start_docstrings(
1330
  "The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
1331
+ HYBRIDNA_START_DOCSTRING,
1332
  )
 
1333
  class HybriDNAModel(HybriDNAPreTrainedModel):
1334
  """
1335
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HybriDNADecoderLayer`]
 
1362
  def set_input_embeddings(self, value):
1363
  self.embed_tokens = value
1364
 
1365
+ @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
1366
  def forward(
1367
  self,
1368
  input_ids: torch.LongTensor = None,
 
1373
  use_cache: Optional[bool] = None,
1374
  output_attentions: Optional[bool] = None,
1375
  output_hidden_states: Optional[bool] = None,
 
1376
  return_dict: Optional[bool] = None,
1377
  cache_position: Optional[torch.LongTensor] = None,
1378
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1379
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
1380
  output_hidden_states = (
1381
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1382
  )
 
1415
 
1416
  all_hidden_states = () if output_hidden_states else None
1417
  all_self_attns = () if output_attentions else None
 
1418
 
1419
  for decoder_layer in self.layers:
1420
  if output_hidden_states:
 
1428
  position_ids,
1429
  past_key_values,
1430
  output_attentions,
 
1431
  use_cache,
1432
  cache_position,
1433
  )
 
1438
  position_ids=position_ids,
1439
  past_key_value=past_key_values,
1440
  output_attentions=output_attentions,
 
1441
  use_cache=use_cache,
1442
  cache_position=cache_position,
1443
  )
 
1449
  # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1450
  all_self_attns += (layer_outputs[1],)
1451
 
 
 
 
 
 
1452
  hidden_states = self.final_layernorm(hidden_states)
1453
 
1454
  # add hidden states from the last decoder layer
 
1465
  if not return_dict:
1466
  return tuple(
1467
  v
1468
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1469
  if v is not None
1470
  )
1471
+ return BaseModelOutputWithPast(
1472
  last_hidden_state=hidden_states,
1473
  past_key_values=next_cache,
1474
  hidden_states=all_hidden_states,
1475
  attentions=all_self_attns,
 
1476
  )
1477
 
1478
  def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
 
1511
  return causal_mask
1512
 
1513
 
 
1514
  class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
1515
  _tied_weights_keys = ["lm_head.weight"]
1516
 
 
1519
  self.model = HybriDNAModel(config)
1520
  self.vocab_size = config.vocab_size
1521
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
1522
  # Initialize weights and apply final processing
1523
  self.post_init()
1524
 
 
1540
  def get_decoder(self):
1541
  return self.model
1542
 
1543
+ @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
1544
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1545
  def forward(
1546
  self,
1547
  input_ids: torch.LongTensor = None,
 
1553
  use_cache: Optional[bool] = None,
1554
  output_attentions: Optional[bool] = None,
1555
  output_hidden_states: Optional[bool] = None,
 
1556
  return_dict: Optional[bool] = None,
1557
  cache_position: Optional[torch.LongTensor] = None,
1558
  num_logits_to_keep: Optional[Union[int, None]] = None,
1559
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1560
  r"""
1561
  Args:
1562
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1570
  Returns:
1571
  Example:
1572
  ```python
1573
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
1574
+ >>> model = AutoModelForCausalLM.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
1575
+ >>> tokenizer = AutoTokenizer.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
1576
+ >>> prompt = "ACGTACGTACGTACGT"
1577
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1578
  >>> # Generate
1579
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1580
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1581
  ```"""
1582
 
1583
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1584
  output_hidden_states = (
1585
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1586
  )
 
1596
  use_cache=use_cache,
1597
  output_attentions=output_attentions,
1598
  output_hidden_states=output_hidden_states,
 
1599
  cache_position=cache_position,
1600
  return_dict=return_dict,
1601
  )
 
1620
  shift_labels = shift_labels.to(shift_logits.device)
1621
  loss = loss_fct(shift_logits, shift_labels)
1622
 
 
 
 
 
 
 
 
 
 
 
 
1623
  if not return_dict:
1624
  output = (logits,) + outputs[1:]
 
 
1625
  return (loss,) + output if loss is not None else output
1626
 
1627
+ return CausalLMOutputWithPast(
1628
  loss=loss,
 
1629
  logits=logits,
1630
  past_key_values=outputs.past_key_values,
1631
  hidden_states=outputs.hidden_states,
1632
  attentions=outputs.attentions,
 
1633
  )
1634
 
1635
  def prepare_inputs_for_generation(
 
1638
  past_key_values=None,
1639
  attention_mask=None,
1640
  inputs_embeds=None,
 
1641
  cache_position=None,
1642
  **kwargs,
1643
  ):
 
1693
  "past_key_values": past_key_values,
1694
  "use_cache": kwargs.get("use_cache"),
1695
  "attention_mask": attention_mask,
 
1696
  "num_logits_to_keep": self.config.num_logits_to_keep,
1697
  "cache_position": cache_position,
1698
  }
 
1711
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1712
  each row of the batch).
1713
  """,
1714
+ HYBRIDNA_START_DOCSTRING,
1715
  )
 
1716
  class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
1717
  def __init__(self, config):
1718
  super().__init__(config)
 
1729
  def set_input_embeddings(self, value):
1730
  self.model.embed_tokens = value
1731
 
1732
+ @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
1733
  def forward(
1734
  self,
1735
  input_ids: torch.LongTensor = None,
 
1827
  The input sequence is concatenated with its reverse complement before being processed by the model.
1828
  [`HybriDNAForSequenceClassificationRCEcho`]
1829
  """,
1830
+ HYBRIDNA_START_DOCSTRING,
1831
  )
1832
  class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
1833
  def __init__(self, config):
 
1869
  rc = torch.flip(rc, dims=[1])
1870
  return rc
1871
 
1872
+ @add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
1873
  def forward(
1874
  self,
1875
  input_ids: torch.LongTensor = None,