Upload modeling_hybridna.py with huggingface_hub
Browse files- modeling_hybridna.py +27 -163
modeling_hybridna.py
CHANGED
|
@@ -35,8 +35,8 @@ from transformers.modeling_attn_mask_utils import (
|
|
| 35 |
AttentionMaskConverter,
|
| 36 |
)
|
| 37 |
from transformers.modeling_outputs import (
|
| 38 |
-
|
| 39 |
-
|
| 40 |
SequenceClassifierOutputWithPast,
|
| 41 |
)
|
| 42 |
from transformers.modeling_utils import PreTrainedModel
|
|
@@ -90,85 +90,6 @@ logger = logging.get_logger(__name__)
|
|
| 90 |
_CONFIG_FOR_DOC = "HybriDNAConfig"
|
| 91 |
|
| 92 |
|
| 93 |
-
# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
|
| 94 |
-
def load_balancing_loss_func(
|
| 95 |
-
router_logits: torch.Tensor,
|
| 96 |
-
num_experts: torch.Tensor = None,
|
| 97 |
-
top_k=2,
|
| 98 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 99 |
-
) -> float:
|
| 100 |
-
r"""
|
| 101 |
-
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
| 102 |
-
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
|
| 103 |
-
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
| 104 |
-
experts is too unbalanced.
|
| 105 |
-
Args:
|
| 106 |
-
router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
|
| 107 |
-
Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
|
| 108 |
-
shape [batch_size X sequence_length, num_experts].
|
| 109 |
-
attention_mask (`torch.Tensor`, None):
|
| 110 |
-
The attention_mask used in forward function
|
| 111 |
-
shape [batch_size X sequence_length] if not None.
|
| 112 |
-
num_experts (`int`, *optional*):
|
| 113 |
-
Number of experts
|
| 114 |
-
Returns:
|
| 115 |
-
The auxiliary loss.
|
| 116 |
-
"""
|
| 117 |
-
if router_logits is None or not isinstance(router_logits, tuple):
|
| 118 |
-
return 0
|
| 119 |
-
|
| 120 |
-
if isinstance(router_logits, tuple):
|
| 121 |
-
compute_device = router_logits[0].device
|
| 122 |
-
concatenated_router_logits = torch.cat(
|
| 123 |
-
[layer_router.to(compute_device) for layer_router in router_logits], dim=0
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)
|
| 127 |
-
|
| 128 |
-
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 129 |
-
|
| 130 |
-
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 131 |
-
|
| 132 |
-
if attention_mask is None:
|
| 133 |
-
# Compute the percentage of tokens routed to each experts
|
| 134 |
-
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 135 |
-
|
| 136 |
-
# Compute the average probability of routing to these experts
|
| 137 |
-
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 138 |
-
else:
|
| 139 |
-
batch_size, sequence_length = attention_mask.shape
|
| 140 |
-
num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)
|
| 141 |
-
|
| 142 |
-
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 143 |
-
expert_attention_mask = (
|
| 144 |
-
attention_mask[None, :, :, None, None]
|
| 145 |
-
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
| 146 |
-
.reshape(-1, top_k, num_experts)
|
| 147 |
-
.to(compute_device)
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
# Compute the percentage of tokens routed to each experts
|
| 151 |
-
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
| 152 |
-
expert_attention_mask, dim=0
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 156 |
-
router_per_expert_attention_mask = (
|
| 157 |
-
attention_mask[None, :, :, None]
|
| 158 |
-
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 159 |
-
.reshape(-1, num_experts)
|
| 160 |
-
.to(compute_device)
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# Compute the average probability of routing to these experts
|
| 164 |
-
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
| 165 |
-
router_per_expert_attention_mask, dim=0
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 169 |
-
return overall_loss * num_experts
|
| 170 |
-
|
| 171 |
-
|
| 172 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 173 |
def _get_unpad_data(attention_mask):
|
| 174 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
@@ -1221,9 +1142,7 @@ class HybriDNAMLP(nn.Module):
|
|
| 1221 |
class HybriDNAAttentionDecoderLayer(nn.Module):
|
| 1222 |
def __init__(self, config: HybriDNAConfig, layer_idx: int):
|
| 1223 |
super().__init__()
|
| 1224 |
-
# Remove MoE support: always use vanilla MLP
|
| 1225 |
self.self_attn = HYBRIDNA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 1226 |
-
# Previously: ffn_layer_class = HybriDNASparseMoeBlock if num_experts > 1 else HybriDNAMLP
|
| 1227 |
self.feed_forward = HybriDNAMLP(config)
|
| 1228 |
self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1229 |
self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
@@ -1231,10 +1150,9 @@ class HybriDNAAttentionDecoderLayer(nn.Module):
|
|
| 1231 |
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
| 1232 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1233 |
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
|
| 1234 |
-
output_attentions: Optional[bool] = False,
|
| 1235 |
use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
|
| 1236 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1237 |
-
# ...existing code before feed-forward...
|
| 1238 |
residual = hidden_states
|
| 1239 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1240 |
|
|
@@ -1251,28 +1169,24 @@ class HybriDNAAttentionDecoderLayer(nn.Module):
|
|
| 1251 |
# residual connection after attention
|
| 1252 |
hidden_states = residual + hidden_states
|
| 1253 |
|
| 1254 |
-
# feed-forward
|
| 1255 |
residual = hidden_states
|
| 1256 |
hidden_states = self.pre_ff_layernorm(hidden_states)
|
| 1257 |
-
# Remove MoE tuple-check: directly compute feed-forward output
|
| 1258 |
hidden_states = self.feed_forward(hidden_states)
|
| 1259 |
hidden_states = residual + hidden_states
|
| 1260 |
|
| 1261 |
outputs = (hidden_states,)
|
| 1262 |
if output_attentions:
|
| 1263 |
-
outputs += (self_attn_weights,)
|
| 1264 |
if use_cache:
|
| 1265 |
outputs += (present_key_value,)
|
| 1266 |
-
# Remove router_logits branch completely
|
| 1267 |
return outputs
|
| 1268 |
|
| 1269 |
|
| 1270 |
class HybriDNAMambaDecoderLayer(nn.Module):
|
| 1271 |
def __init__(self, config: HybriDNAConfig, layer_idx: int):
|
| 1272 |
super().__init__()
|
| 1273 |
-
# Remove MoE support: always use vanilla MLP
|
| 1274 |
self.mamba = HybriDNAMamba2Mixer(config=config, layer_idx=layer_idx)
|
| 1275 |
-
# Previously: ffn_layer_class = HybriDNASparseMoeBlock if num_experts > 1 else HybriDNAMLP
|
| 1276 |
self.feed_forward = HybriDNAMLP(config)
|
| 1277 |
self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1278 |
self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
@@ -1280,10 +1194,9 @@ class HybriDNAMambaDecoderLayer(nn.Module):
|
|
| 1280 |
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
| 1281 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1282 |
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
|
| 1283 |
-
output_attentions: Optional[bool] = False,
|
| 1284 |
use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
|
| 1285 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 1286 |
-
# ...existing code before feed-forward...
|
| 1287 |
residual = hidden_states
|
| 1288 |
|
| 1289 |
hidden_states = self.input_layernorm(hidden_states)
|
|
@@ -1297,10 +1210,9 @@ class HybriDNAMambaDecoderLayer(nn.Module):
|
|
| 1297 |
# residual connection after mamba
|
| 1298 |
hidden_states = residual + hidden_states
|
| 1299 |
|
| 1300 |
-
# feed-forward
|
| 1301 |
residual = hidden_states
|
| 1302 |
hidden_states = self.pre_ff_layernorm(hidden_states)
|
| 1303 |
-
# Remove MoE tuple-check: directly compute feed-forward output
|
| 1304 |
hidden_states = self.feed_forward(hidden_states)
|
| 1305 |
hidden_states = residual + hidden_states
|
| 1306 |
|
|
@@ -1309,11 +1221,10 @@ class HybriDNAMambaDecoderLayer(nn.Module):
|
|
| 1309 |
outputs += (self_attn_weights,)
|
| 1310 |
if use_cache:
|
| 1311 |
outputs += (past_key_value,)
|
| 1312 |
-
# Remove router_logits branch completely
|
| 1313 |
return outputs
|
| 1314 |
|
| 1315 |
|
| 1316 |
-
|
| 1317 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1318 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1319 |
etc.)
|
|
@@ -1330,7 +1241,7 @@ JAMBA_START_DOCSTRING = r"""
|
|
| 1330 |
|
| 1331 |
@add_start_docstrings(
|
| 1332 |
"The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
|
| 1333 |
-
|
| 1334 |
)
|
| 1335 |
class HybriDNAPreTrainedModel(PreTrainedModel):
|
| 1336 |
config_class = HybriDNAConfig
|
|
@@ -1354,7 +1265,7 @@ class HybriDNAPreTrainedModel(PreTrainedModel):
|
|
| 1354 |
module.weight.data[module.padding_idx].zero_()
|
| 1355 |
|
| 1356 |
|
| 1357 |
-
|
| 1358 |
Args:
|
| 1359 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1360 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
@@ -1404,9 +1315,6 @@ JAMBA_INPUTS_DOCSTRING = r"""
|
|
| 1404 |
output_hidden_states (`bool`, *optional*):
|
| 1405 |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1406 |
more detail.
|
| 1407 |
-
output_router_logits (`bool`, *optional*):
|
| 1408 |
-
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 1409 |
-
should not be returned during inference.
|
| 1410 |
return_dict (`bool`, *optional*):
|
| 1411 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1412 |
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
|
@@ -1420,9 +1328,8 @@ ALL_DECODER_LAYER_TYPES = {"attention": HybriDNAAttentionDecoderLayer, "mamba":
|
|
| 1420 |
|
| 1421 |
@add_start_docstrings(
|
| 1422 |
"The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
|
| 1423 |
-
|
| 1424 |
)
|
| 1425 |
-
# Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->HybriDNA
|
| 1426 |
class HybriDNAModel(HybriDNAPreTrainedModel):
|
| 1427 |
"""
|
| 1428 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HybriDNADecoderLayer`]
|
|
@@ -1455,7 +1362,7 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1455 |
def set_input_embeddings(self, value):
|
| 1456 |
self.embed_tokens = value
|
| 1457 |
|
| 1458 |
-
@add_start_docstrings_to_model_forward(
|
| 1459 |
def forward(
|
| 1460 |
self,
|
| 1461 |
input_ids: torch.LongTensor = None,
|
|
@@ -1466,14 +1373,10 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1466 |
use_cache: Optional[bool] = None,
|
| 1467 |
output_attentions: Optional[bool] = None,
|
| 1468 |
output_hidden_states: Optional[bool] = None,
|
| 1469 |
-
output_router_logits: Optional[bool] = None,
|
| 1470 |
return_dict: Optional[bool] = None,
|
| 1471 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1472 |
-
) -> Union[Tuple,
|
| 1473 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1474 |
-
output_router_logits = (
|
| 1475 |
-
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1476 |
-
)
|
| 1477 |
output_hidden_states = (
|
| 1478 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1479 |
)
|
|
@@ -1512,7 +1415,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1512 |
|
| 1513 |
all_hidden_states = () if output_hidden_states else None
|
| 1514 |
all_self_attns = () if output_attentions else None
|
| 1515 |
-
all_router_logits = () if output_router_logits else None
|
| 1516 |
|
| 1517 |
for decoder_layer in self.layers:
|
| 1518 |
if output_hidden_states:
|
|
@@ -1526,7 +1428,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1526 |
position_ids,
|
| 1527 |
past_key_values,
|
| 1528 |
output_attentions,
|
| 1529 |
-
output_router_logits,
|
| 1530 |
use_cache,
|
| 1531 |
cache_position,
|
| 1532 |
)
|
|
@@ -1537,7 +1438,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1537 |
position_ids=position_ids,
|
| 1538 |
past_key_value=past_key_values,
|
| 1539 |
output_attentions=output_attentions,
|
| 1540 |
-
output_router_logits=output_router_logits,
|
| 1541 |
use_cache=use_cache,
|
| 1542 |
cache_position=cache_position,
|
| 1543 |
)
|
|
@@ -1549,11 +1449,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1549 |
# append attentions only of attention layers. Mamba layers return `None` as the attention weights
|
| 1550 |
all_self_attns += (layer_outputs[1],)
|
| 1551 |
|
| 1552 |
-
if output_router_logits:
|
| 1553 |
-
if layer_outputs[-1] is not None:
|
| 1554 |
-
# append router logits only of expert layers. Regular MLP layers return `None` as the router logits
|
| 1555 |
-
all_router_logits += (layer_outputs[-1],)
|
| 1556 |
-
|
| 1557 |
hidden_states = self.final_layernorm(hidden_states)
|
| 1558 |
|
| 1559 |
# add hidden states from the last decoder layer
|
|
@@ -1570,15 +1465,14 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1570 |
if not return_dict:
|
| 1571 |
return tuple(
|
| 1572 |
v
|
| 1573 |
-
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns
|
| 1574 |
if v is not None
|
| 1575 |
)
|
| 1576 |
-
return
|
| 1577 |
last_hidden_state=hidden_states,
|
| 1578 |
past_key_values=next_cache,
|
| 1579 |
hidden_states=all_hidden_states,
|
| 1580 |
attentions=all_self_attns,
|
| 1581 |
-
router_logits=all_router_logits,
|
| 1582 |
)
|
| 1583 |
|
| 1584 |
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
|
@@ -1617,7 +1511,6 @@ class HybriDNAModel(HybriDNAPreTrainedModel):
|
|
| 1617 |
return causal_mask
|
| 1618 |
|
| 1619 |
|
| 1620 |
-
# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->HybriDNA
|
| 1621 |
class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
| 1622 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1623 |
|
|
@@ -1626,9 +1519,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1626 |
self.model = HybriDNAModel(config)
|
| 1627 |
self.vocab_size = config.vocab_size
|
| 1628 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1629 |
-
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 1630 |
-
self.num_experts = config.num_experts
|
| 1631 |
-
self.num_experts_per_tok = config.num_experts_per_tok
|
| 1632 |
# Initialize weights and apply final processing
|
| 1633 |
self.post_init()
|
| 1634 |
|
|
@@ -1650,9 +1540,8 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1650 |
def get_decoder(self):
|
| 1651 |
return self.model
|
| 1652 |
|
| 1653 |
-
@add_start_docstrings_to_model_forward(
|
| 1654 |
-
@replace_return_docstrings(output_type=
|
| 1655 |
-
# Ignore copy
|
| 1656 |
def forward(
|
| 1657 |
self,
|
| 1658 |
input_ids: torch.LongTensor = None,
|
|
@@ -1664,11 +1553,10 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1664 |
use_cache: Optional[bool] = None,
|
| 1665 |
output_attentions: Optional[bool] = None,
|
| 1666 |
output_hidden_states: Optional[bool] = None,
|
| 1667 |
-
output_router_logits: Optional[bool] = None,
|
| 1668 |
return_dict: Optional[bool] = None,
|
| 1669 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1670 |
num_logits_to_keep: Optional[Union[int, None]] = None,
|
| 1671 |
-
) -> Union[Tuple,
|
| 1672 |
r"""
|
| 1673 |
Args:
|
| 1674 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -1682,22 +1570,17 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1682 |
Returns:
|
| 1683 |
Example:
|
| 1684 |
```python
|
| 1685 |
-
>>> from transformers import AutoTokenizer,
|
| 1686 |
-
>>> model =
|
| 1687 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("
|
| 1688 |
-
>>> prompt = "
|
| 1689 |
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1690 |
>>> # Generate
|
| 1691 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1692 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1693 |
-
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
| 1694 |
```"""
|
| 1695 |
|
| 1696 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1697 |
-
output_router_logits = (
|
| 1698 |
-
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1699 |
-
)
|
| 1700 |
-
|
| 1701 |
output_hidden_states = (
|
| 1702 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1703 |
)
|
|
@@ -1713,7 +1596,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1713 |
use_cache=use_cache,
|
| 1714 |
output_attentions=output_attentions,
|
| 1715 |
output_hidden_states=output_hidden_states,
|
| 1716 |
-
output_router_logits=output_router_logits,
|
| 1717 |
cache_position=cache_position,
|
| 1718 |
return_dict=return_dict,
|
| 1719 |
)
|
|
@@ -1738,31 +1620,16 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1738 |
shift_labels = shift_labels.to(shift_logits.device)
|
| 1739 |
loss = loss_fct(shift_logits, shift_labels)
|
| 1740 |
|
| 1741 |
-
aux_loss = None
|
| 1742 |
-
if output_router_logits:
|
| 1743 |
-
aux_loss = load_balancing_loss_func(
|
| 1744 |
-
outputs.router_logits if return_dict else outputs[-1],
|
| 1745 |
-
self.num_experts,
|
| 1746 |
-
self.num_experts_per_tok,
|
| 1747 |
-
attention_mask,
|
| 1748 |
-
)
|
| 1749 |
-
if labels is not None:
|
| 1750 |
-
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 1751 |
-
|
| 1752 |
if not return_dict:
|
| 1753 |
output = (logits,) + outputs[1:]
|
| 1754 |
-
if output_router_logits:
|
| 1755 |
-
output = (aux_loss,) + output
|
| 1756 |
return (loss,) + output if loss is not None else output
|
| 1757 |
|
| 1758 |
-
return
|
| 1759 |
loss=loss,
|
| 1760 |
-
aux_loss=aux_loss,
|
| 1761 |
logits=logits,
|
| 1762 |
past_key_values=outputs.past_key_values,
|
| 1763 |
hidden_states=outputs.hidden_states,
|
| 1764 |
attentions=outputs.attentions,
|
| 1765 |
-
router_logits=outputs.router_logits,
|
| 1766 |
)
|
| 1767 |
|
| 1768 |
def prepare_inputs_for_generation(
|
|
@@ -1771,7 +1638,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1771 |
past_key_values=None,
|
| 1772 |
attention_mask=None,
|
| 1773 |
inputs_embeds=None,
|
| 1774 |
-
output_router_logits=False,
|
| 1775 |
cache_position=None,
|
| 1776 |
**kwargs,
|
| 1777 |
):
|
|
@@ -1827,7 +1693,6 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1827 |
"past_key_values": past_key_values,
|
| 1828 |
"use_cache": kwargs.get("use_cache"),
|
| 1829 |
"attention_mask": attention_mask,
|
| 1830 |
-
"output_router_logits": output_router_logits,
|
| 1831 |
"num_logits_to_keep": self.config.num_logits_to_keep,
|
| 1832 |
"cache_position": cache_position,
|
| 1833 |
}
|
|
@@ -1846,9 +1711,8 @@ class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
|
| 1846 |
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1847 |
each row of the batch).
|
| 1848 |
""",
|
| 1849 |
-
|
| 1850 |
)
|
| 1851 |
-
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->HybriDNA, MIXTRAL->JAMBA
|
| 1852 |
class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
|
| 1853 |
def __init__(self, config):
|
| 1854 |
super().__init__(config)
|
|
@@ -1865,7 +1729,7 @@ class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
|
|
| 1865 |
def set_input_embeddings(self, value):
|
| 1866 |
self.model.embed_tokens = value
|
| 1867 |
|
| 1868 |
-
@add_start_docstrings_to_model_forward(
|
| 1869 |
def forward(
|
| 1870 |
self,
|
| 1871 |
input_ids: torch.LongTensor = None,
|
|
@@ -1963,7 +1827,7 @@ class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
|
|
| 1963 |
The input sequence is concatenated with its reverse complement before being processed by the model.
|
| 1964 |
[`HybriDNAForSequenceClassificationRCEcho`]
|
| 1965 |
""",
|
| 1966 |
-
|
| 1967 |
)
|
| 1968 |
class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
|
| 1969 |
def __init__(self, config):
|
|
@@ -2005,7 +1869,7 @@ class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
|
|
| 2005 |
rc = torch.flip(rc, dims=[1])
|
| 2006 |
return rc
|
| 2007 |
|
| 2008 |
-
@add_start_docstrings_to_model_forward(
|
| 2009 |
def forward(
|
| 2010 |
self,
|
| 2011 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 35 |
AttentionMaskConverter,
|
| 36 |
)
|
| 37 |
from transformers.modeling_outputs import (
|
| 38 |
+
BaseModelOutputWithPast,
|
| 39 |
+
CausalLMOutputWithPast,
|
| 40 |
SequenceClassifierOutputWithPast,
|
| 41 |
)
|
| 42 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
| 90 |
_CONFIG_FOR_DOC = "HybriDNAConfig"
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 94 |
def _get_unpad_data(attention_mask):
|
| 95 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
|
| 1142 |
class HybriDNAAttentionDecoderLayer(nn.Module):
|
| 1143 |
def __init__(self, config: HybriDNAConfig, layer_idx: int):
|
| 1144 |
super().__init__()
|
|
|
|
| 1145 |
self.self_attn = HYBRIDNA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
|
|
|
| 1146 |
self.feed_forward = HybriDNAMLP(config)
|
| 1147 |
self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1148 |
self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 1150 |
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
| 1151 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1152 |
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
|
| 1153 |
+
output_attentions: Optional[bool] = False,
|
| 1154 |
use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
|
| 1155 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
|
|
| 1156 |
residual = hidden_states
|
| 1157 |
hidden_states = self.input_layernorm(hidden_states)
|
| 1158 |
|
|
|
|
| 1169 |
# residual connection after attention
|
| 1170 |
hidden_states = residual + hidden_states
|
| 1171 |
|
| 1172 |
+
# feed-forward
|
| 1173 |
residual = hidden_states
|
| 1174 |
hidden_states = self.pre_ff_layernorm(hidden_states)
|
|
|
|
| 1175 |
hidden_states = self.feed_forward(hidden_states)
|
| 1176 |
hidden_states = residual + hidden_states
|
| 1177 |
|
| 1178 |
outputs = (hidden_states,)
|
| 1179 |
if output_attentions:
|
| 1180 |
+
outputs += (self_attn_weights,)
|
| 1181 |
if use_cache:
|
| 1182 |
outputs += (present_key_value,)
|
|
|
|
| 1183 |
return outputs
|
| 1184 |
|
| 1185 |
|
| 1186 |
class HybriDNAMambaDecoderLayer(nn.Module):
|
| 1187 |
def __init__(self, config: HybriDNAConfig, layer_idx: int):
|
| 1188 |
super().__init__()
|
|
|
|
| 1189 |
self.mamba = HybriDNAMamba2Mixer(config=config, layer_idx=layer_idx)
|
|
|
|
| 1190 |
self.feed_forward = HybriDNAMLP(config)
|
| 1191 |
self.input_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1192 |
self.pre_ff_layernorm = HybriDNARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 1194 |
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
| 1195 |
position_ids: Optional[torch.LongTensor] = None,
|
| 1196 |
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
|
| 1197 |
+
output_attentions: Optional[bool] = False,
|
| 1198 |
use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None
|
| 1199 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
|
|
| 1200 |
residual = hidden_states
|
| 1201 |
|
| 1202 |
hidden_states = self.input_layernorm(hidden_states)
|
|
|
|
| 1210 |
# residual connection after mamba
|
| 1211 |
hidden_states = residual + hidden_states
|
| 1212 |
|
| 1213 |
+
# feed-forward
|
| 1214 |
residual = hidden_states
|
| 1215 |
hidden_states = self.pre_ff_layernorm(hidden_states)
|
|
|
|
| 1216 |
hidden_states = self.feed_forward(hidden_states)
|
| 1217 |
hidden_states = residual + hidden_states
|
| 1218 |
|
|
|
|
| 1221 |
outputs += (self_attn_weights,)
|
| 1222 |
if use_cache:
|
| 1223 |
outputs += (past_key_value,)
|
|
|
|
| 1224 |
return outputs
|
| 1225 |
|
| 1226 |
|
| 1227 |
+
HYBRIDNA_START_DOCSTRING = r"""
|
| 1228 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1229 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1230 |
etc.)
|
|
|
|
| 1241 |
|
| 1242 |
@add_start_docstrings(
|
| 1243 |
"The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
|
| 1244 |
+
HYBRIDNA_START_DOCSTRING,
|
| 1245 |
)
|
| 1246 |
class HybriDNAPreTrainedModel(PreTrainedModel):
|
| 1247 |
config_class = HybriDNAConfig
|
|
|
|
| 1265 |
module.weight.data[module.padding_idx].zero_()
|
| 1266 |
|
| 1267 |
|
| 1268 |
+
HYBRIDNA_INPUTS_DOCSTRING = r"""
|
| 1269 |
Args:
|
| 1270 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1271 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
|
|
| 1315 |
output_hidden_states (`bool`, *optional*):
|
| 1316 |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1317 |
more detail.
|
|
|
|
|
|
|
|
|
|
| 1318 |
return_dict (`bool`, *optional*):
|
| 1319 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1320 |
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
|
|
|
| 1328 |
|
| 1329 |
@add_start_docstrings(
|
| 1330 |
"The bare HybriDNA Model outputting raw hidden-states without any specific head on top.",
|
| 1331 |
+
HYBRIDNA_START_DOCSTRING,
|
| 1332 |
)
|
|
|
|
| 1333 |
class HybriDNAModel(HybriDNAPreTrainedModel):
|
| 1334 |
"""
|
| 1335 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HybriDNADecoderLayer`]
|
|
|
|
| 1362 |
def set_input_embeddings(self, value):
|
| 1363 |
self.embed_tokens = value
|
| 1364 |
|
| 1365 |
+
@add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
|
| 1366 |
def forward(
|
| 1367 |
self,
|
| 1368 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 1373 |
use_cache: Optional[bool] = None,
|
| 1374 |
output_attentions: Optional[bool] = None,
|
| 1375 |
output_hidden_states: Optional[bool] = None,
|
|
|
|
| 1376 |
return_dict: Optional[bool] = None,
|
| 1377 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1378 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 1379 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
|
|
|
|
|
| 1380 |
output_hidden_states = (
|
| 1381 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1382 |
)
|
|
|
|
| 1415 |
|
| 1416 |
all_hidden_states = () if output_hidden_states else None
|
| 1417 |
all_self_attns = () if output_attentions else None
|
|
|
|
| 1418 |
|
| 1419 |
for decoder_layer in self.layers:
|
| 1420 |
if output_hidden_states:
|
|
|
|
| 1428 |
position_ids,
|
| 1429 |
past_key_values,
|
| 1430 |
output_attentions,
|
|
|
|
| 1431 |
use_cache,
|
| 1432 |
cache_position,
|
| 1433 |
)
|
|
|
|
| 1438 |
position_ids=position_ids,
|
| 1439 |
past_key_value=past_key_values,
|
| 1440 |
output_attentions=output_attentions,
|
|
|
|
| 1441 |
use_cache=use_cache,
|
| 1442 |
cache_position=cache_position,
|
| 1443 |
)
|
|
|
|
| 1449 |
# append attentions only of attention layers. Mamba layers return `None` as the attention weights
|
| 1450 |
all_self_attns += (layer_outputs[1],)
|
| 1451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1452 |
hidden_states = self.final_layernorm(hidden_states)
|
| 1453 |
|
| 1454 |
# add hidden states from the last decoder layer
|
|
|
|
| 1465 |
if not return_dict:
|
| 1466 |
return tuple(
|
| 1467 |
v
|
| 1468 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
| 1469 |
if v is not None
|
| 1470 |
)
|
| 1471 |
+
return BaseModelOutputWithPast(
|
| 1472 |
last_hidden_state=hidden_states,
|
| 1473 |
past_key_values=next_cache,
|
| 1474 |
hidden_states=all_hidden_states,
|
| 1475 |
attentions=all_self_attns,
|
|
|
|
| 1476 |
)
|
| 1477 |
|
| 1478 |
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
|
|
|
| 1511 |
return causal_mask
|
| 1512 |
|
| 1513 |
|
|
|
|
| 1514 |
class HybriDNAForCausalLM(HybriDNAPreTrainedModel, GenerationMixin):
|
| 1515 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1516 |
|
|
|
|
| 1519 |
self.model = HybriDNAModel(config)
|
| 1520 |
self.vocab_size = config.vocab_size
|
| 1521 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
| 1522 |
# Initialize weights and apply final processing
|
| 1523 |
self.post_init()
|
| 1524 |
|
|
|
|
| 1540 |
def get_decoder(self):
|
| 1541 |
return self.model
|
| 1542 |
|
| 1543 |
+
@add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
|
| 1544 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
|
| 1545 |
def forward(
|
| 1546 |
self,
|
| 1547 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 1553 |
use_cache: Optional[bool] = None,
|
| 1554 |
output_attentions: Optional[bool] = None,
|
| 1555 |
output_hidden_states: Optional[bool] = None,
|
|
|
|
| 1556 |
return_dict: Optional[bool] = None,
|
| 1557 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1558 |
num_logits_to_keep: Optional[Union[int, None]] = None,
|
| 1559 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1560 |
r"""
|
| 1561 |
Args:
|
| 1562 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
| 1570 |
Returns:
|
| 1571 |
Example:
|
| 1572 |
```python
|
| 1573 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 1574 |
+
>>> model = AutoModelForCausalLM.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
|
| 1575 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Mishamq/HybriDNA-300M", trust_remote_code=True)
|
| 1576 |
+
>>> prompt = "ACGTACGTACGTACGT"
|
| 1577 |
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1578 |
>>> # Generate
|
| 1579 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1580 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
|
| 1581 |
```"""
|
| 1582 |
|
| 1583 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1584 |
output_hidden_states = (
|
| 1585 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1586 |
)
|
|
|
|
| 1596 |
use_cache=use_cache,
|
| 1597 |
output_attentions=output_attentions,
|
| 1598 |
output_hidden_states=output_hidden_states,
|
|
|
|
| 1599 |
cache_position=cache_position,
|
| 1600 |
return_dict=return_dict,
|
| 1601 |
)
|
|
|
|
| 1620 |
shift_labels = shift_labels.to(shift_logits.device)
|
| 1621 |
loss = loss_fct(shift_logits, shift_labels)
|
| 1622 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1623 |
if not return_dict:
|
| 1624 |
output = (logits,) + outputs[1:]
|
|
|
|
|
|
|
| 1625 |
return (loss,) + output if loss is not None else output
|
| 1626 |
|
| 1627 |
+
return CausalLMOutputWithPast(
|
| 1628 |
loss=loss,
|
|
|
|
| 1629 |
logits=logits,
|
| 1630 |
past_key_values=outputs.past_key_values,
|
| 1631 |
hidden_states=outputs.hidden_states,
|
| 1632 |
attentions=outputs.attentions,
|
|
|
|
| 1633 |
)
|
| 1634 |
|
| 1635 |
def prepare_inputs_for_generation(
|
|
|
|
| 1638 |
past_key_values=None,
|
| 1639 |
attention_mask=None,
|
| 1640 |
inputs_embeds=None,
|
|
|
|
| 1641 |
cache_position=None,
|
| 1642 |
**kwargs,
|
| 1643 |
):
|
|
|
|
| 1693 |
"past_key_values": past_key_values,
|
| 1694 |
"use_cache": kwargs.get("use_cache"),
|
| 1695 |
"attention_mask": attention_mask,
|
|
|
|
| 1696 |
"num_logits_to_keep": self.config.num_logits_to_keep,
|
| 1697 |
"cache_position": cache_position,
|
| 1698 |
}
|
|
|
|
| 1711 |
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
| 1712 |
each row of the batch).
|
| 1713 |
""",
|
| 1714 |
+
HYBRIDNA_START_DOCSTRING,
|
| 1715 |
)
|
|
|
|
| 1716 |
class HybriDNAForSequenceClassification(HybriDNAPreTrainedModel):
|
| 1717 |
def __init__(self, config):
|
| 1718 |
super().__init__(config)
|
|
|
|
| 1729 |
def set_input_embeddings(self, value):
|
| 1730 |
self.model.embed_tokens = value
|
| 1731 |
|
| 1732 |
+
@add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
|
| 1733 |
def forward(
|
| 1734 |
self,
|
| 1735 |
input_ids: torch.LongTensor = None,
|
|
|
|
| 1827 |
The input sequence is concatenated with its reverse complement before being processed by the model.
|
| 1828 |
[`HybriDNAForSequenceClassificationRCEcho`]
|
| 1829 |
""",
|
| 1830 |
+
HYBRIDNA_START_DOCSTRING,
|
| 1831 |
)
|
| 1832 |
class HybriDNAForSequenceClassificationRCEcho(HybriDNAPreTrainedModel):
|
| 1833 |
def __init__(self, config):
|
|
|
|
| 1869 |
rc = torch.flip(rc, dims=[1])
|
| 1870 |
return rc
|
| 1871 |
|
| 1872 |
+
@add_start_docstrings_to_model_forward(HYBRIDNA_INPUTS_DOCSTRING)
|
| 1873 |
def forward(
|
| 1874 |
self,
|
| 1875 |
input_ids: torch.LongTensor = None,
|