Multipack simplify for Mixtral (#1142)
Browse files- src/axolotl/core/trainer_builder.py +16 -7
- src/axolotl/monkeypatch/mixtral/__init__.py +4 -14
- src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +0 -383
- src/axolotl/monkeypatch/utils.py +34 -0
- src/axolotl/utils/collators.py +27 -0
- src/axolotl/utils/config.py +38 -2
- src/axolotl/utils/models.py +7 -3
- src/axolotl/utils/trainer.py +1 -0
- tests/e2e/patched/test_mixtral_samplepack.py +3 -15
- tests/e2e/patched/test_model_patches.py +1 -5
- tests/monkeypatch/test_llama_attn_hijack_flash.py +70 -1
src/axolotl/core/trainer_builder.py
CHANGED
|
@@ -12,7 +12,7 @@ from abc import abstractmethod
|
|
| 12 |
from dataclasses import dataclass, field
|
| 13 |
from functools import wraps
|
| 14 |
from pathlib import Path
|
| 15 |
-
from typing import Optional
|
| 16 |
|
| 17 |
import torch
|
| 18 |
import transformers
|
|
@@ -37,6 +37,7 @@ from axolotl.utils.collators import (
|
|
| 37 |
BatchSamplerDataCollatorForSeq2Seq,
|
| 38 |
DataCollatorForSeq2Seq,
|
| 39 |
MambaDataCollator,
|
|
|
|
| 40 |
)
|
| 41 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
| 42 |
from axolotl.utils.schedulers import (
|
|
@@ -896,14 +897,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
| 896 |
if is_eval and training_args.eval_sample_packing:
|
| 897 |
use_batch_sampler_collator = True
|
| 898 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
if use_batch_sampler_collator:
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
|
|
|
| 905 |
|
| 906 |
-
return
|
| 907 |
self.tokenizer,
|
| 908 |
return_tensors="pt",
|
| 909 |
**kwargs,
|
|
|
|
| 12 |
from dataclasses import dataclass, field
|
| 13 |
from functools import wraps
|
| 14 |
from pathlib import Path
|
| 15 |
+
from typing import Optional, Type, Union
|
| 16 |
|
| 17 |
import torch
|
| 18 |
import transformers
|
|
|
|
| 37 |
BatchSamplerDataCollatorForSeq2Seq,
|
| 38 |
DataCollatorForSeq2Seq,
|
| 39 |
MambaDataCollator,
|
| 40 |
+
V2BatchSamplerDataCollatorForSeq2Seq,
|
| 41 |
)
|
| 42 |
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
| 43 |
from axolotl.utils.schedulers import (
|
|
|
|
| 897 |
if is_eval and training_args.eval_sample_packing:
|
| 898 |
use_batch_sampler_collator = True
|
| 899 |
|
| 900 |
+
collator: Type[
|
| 901 |
+
Union[
|
| 902 |
+
V2BatchSamplerDataCollatorForSeq2Seq,
|
| 903 |
+
BatchSamplerDataCollatorForSeq2Seq,
|
| 904 |
+
DataCollatorForSeq2Seq,
|
| 905 |
+
]
|
| 906 |
+
]
|
| 907 |
if use_batch_sampler_collator:
|
| 908 |
+
if self.cfg.model_config_type == "mixtral":
|
| 909 |
+
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
| 910 |
+
else:
|
| 911 |
+
collator = BatchSamplerDataCollatorForSeq2Seq
|
| 912 |
+
else:
|
| 913 |
+
collator = DataCollatorForSeq2Seq
|
| 914 |
|
| 915 |
+
return collator(
|
| 916 |
self.tokenizer,
|
| 917 |
return_tensors="pt",
|
| 918 |
**kwargs,
|
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED
|
@@ -3,20 +3,10 @@ Patches to support multipack for mixtral
|
|
| 3 |
"""
|
| 4 |
import transformers
|
| 5 |
|
|
|
|
| 6 |
|
| 7 |
-
def replace_mixtral_attn_with_multipack_flash_attn():
|
| 8 |
-
from .modeling_mixtral import (
|
| 9 |
-
MixtralMultipackFlashAttention2,
|
| 10 |
-
mixtral_decoder_layer_forward,
|
| 11 |
-
mixtral_model_forward,
|
| 12 |
-
)
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
| 18 |
-
mixtral_model_forward
|
| 19 |
)
|
| 20 |
-
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
|
| 21 |
-
"flash_attention_2"
|
| 22 |
-
] = MixtralMultipackFlashAttention2
|
|
|
|
| 3 |
"""
|
| 4 |
import transformers
|
| 5 |
|
| 6 |
+
from axolotl.monkeypatch.utils import get_unpad_data
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def replace_mixtral_attn_with_multipack_flash_attn():
|
| 10 |
+
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
| 11 |
+
get_unpad_data
|
|
|
|
|
|
|
| 12 |
)
|
|
|
|
|
|
|
|
|
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
DELETED
|
@@ -1,383 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Mixtral modeling for multipack
|
| 3 |
-
"""
|
| 4 |
-
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
| 5 |
-
import logging
|
| 6 |
-
import warnings
|
| 7 |
-
from typing import List, Optional, Tuple, Union
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from einops import rearrange
|
| 11 |
-
from flash_attn import flash_attn_varlen_qkvpacked_func
|
| 12 |
-
from transformers import Cache, DynamicCache
|
| 13 |
-
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 14 |
-
from transformers.modeling_outputs import MoeModelOutputWithPast
|
| 15 |
-
from transformers.models.mixtral.modeling_mixtral import (
|
| 16 |
-
MixtralFlashAttention2,
|
| 17 |
-
apply_rotary_pos_emb,
|
| 18 |
-
repeat_kv,
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
| 22 |
-
|
| 23 |
-
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
| 27 |
-
"""
|
| 28 |
-
Custom multipack implementation w flash attention 2
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
def __init__(self, *args, **kwargs):
|
| 32 |
-
super().__init__(*args, **kwargs)
|
| 33 |
-
self._flash_attn_uses_top_left_mask = True
|
| 34 |
-
|
| 35 |
-
def forward(
|
| 36 |
-
self,
|
| 37 |
-
hidden_states: torch.Tensor,
|
| 38 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 39 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 40 |
-
past_key_value: Optional[Cache] = None,
|
| 41 |
-
output_attentions: bool = False,
|
| 42 |
-
use_cache: bool = False,
|
| 43 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
| 44 |
-
max_seqlen: Optional[torch.Tensor] = None,
|
| 45 |
-
**kwargs,
|
| 46 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 47 |
-
if "padding_mask" in kwargs:
|
| 48 |
-
warnings.warn(
|
| 49 |
-
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
| 50 |
-
)
|
| 51 |
-
bsz, q_len, _ = hidden_states.size()
|
| 52 |
-
|
| 53 |
-
query_states = self.q_proj(hidden_states)
|
| 54 |
-
key_states = self.k_proj(hidden_states)
|
| 55 |
-
value_states = self.v_proj(hidden_states)
|
| 56 |
-
|
| 57 |
-
query_states = query_states.view(
|
| 58 |
-
bsz, q_len, self.num_heads, self.head_dim
|
| 59 |
-
).transpose(1, 2)
|
| 60 |
-
key_states = key_states.view(
|
| 61 |
-
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 62 |
-
).transpose(1, 2)
|
| 63 |
-
value_states = value_states.view(
|
| 64 |
-
bsz, q_len, self.num_key_value_heads, self.head_dim
|
| 65 |
-
).transpose(1, 2)
|
| 66 |
-
|
| 67 |
-
kv_seq_len = key_states.shape[-2]
|
| 68 |
-
if past_key_value is not None:
|
| 69 |
-
if self.layer_idx is None:
|
| 70 |
-
raise ValueError(
|
| 71 |
-
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 72 |
-
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 73 |
-
"with a layer index."
|
| 74 |
-
)
|
| 75 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 76 |
-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 77 |
-
query_states, key_states = apply_rotary_pos_emb(
|
| 78 |
-
query_states, key_states, cos, sin, position_ids
|
| 79 |
-
)
|
| 80 |
-
|
| 81 |
-
if past_key_value is not None:
|
| 82 |
-
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
| 83 |
-
key_states, value_states = past_key_value.update(
|
| 84 |
-
key_states, value_states, self.layer_idx, cache_kwargs
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# repeat k/v heads if n_kv_heads < n_heads
|
| 88 |
-
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 89 |
-
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 90 |
-
|
| 91 |
-
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
| 92 |
-
# special handling using sample packing
|
| 93 |
-
qkv = torch.stack(
|
| 94 |
-
[query_states, key_states, value_states], dim=2
|
| 95 |
-
) # [bsz, nh, 3, q_len, hd]
|
| 96 |
-
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
| 97 |
-
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
| 98 |
-
|
| 99 |
-
attn_output = flash_attn_varlen_qkvpacked_func(
|
| 100 |
-
qkv,
|
| 101 |
-
cu_seqlens,
|
| 102 |
-
max_seqlen,
|
| 103 |
-
dropout_p=self.attention_dropout,
|
| 104 |
-
softmax_scale=None,
|
| 105 |
-
causal=True,
|
| 106 |
-
)
|
| 107 |
-
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
| 108 |
-
|
| 109 |
-
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 110 |
-
attn_output = self.o_proj(attn_output)
|
| 111 |
-
|
| 112 |
-
if not output_attentions:
|
| 113 |
-
attn_weights = None
|
| 114 |
-
|
| 115 |
-
return attn_output, attn_weights, past_key_value
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def mixtral_decoder_layer_forward(
|
| 119 |
-
self,
|
| 120 |
-
hidden_states: torch.Tensor,
|
| 121 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 122 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 123 |
-
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 124 |
-
output_attentions: Optional[bool] = False,
|
| 125 |
-
output_router_logits: Optional[bool] = False,
|
| 126 |
-
use_cache: Optional[bool] = False,
|
| 127 |
-
cu_seqlens: Optional[torch.Tensor] = None,
|
| 128 |
-
max_seqlen: Optional[torch.Tensor] = None,
|
| 129 |
-
**kwargs,
|
| 130 |
-
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 131 |
-
if "padding_mask" in kwargs:
|
| 132 |
-
warnings.warn(
|
| 133 |
-
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
| 134 |
-
)
|
| 135 |
-
"""
|
| 136 |
-
Args:
|
| 137 |
-
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 138 |
-
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 139 |
-
`(batch, sequence_length)` where padding elements are indicated by 0.
|
| 140 |
-
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 141 |
-
output_attentions (`bool`, *optional*):
|
| 142 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 143 |
-
returned tensors for more detail.
|
| 144 |
-
output_router_logits (`bool`, *optional*):
|
| 145 |
-
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 146 |
-
should not be returned during inference.
|
| 147 |
-
use_cache (`bool`, *optional*):
|
| 148 |
-
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 149 |
-
(see `past_key_values`).
|
| 150 |
-
"""
|
| 151 |
-
|
| 152 |
-
residual = hidden_states
|
| 153 |
-
|
| 154 |
-
hidden_states = self.input_layernorm(hidden_states)
|
| 155 |
-
|
| 156 |
-
# Self Attention
|
| 157 |
-
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 158 |
-
hidden_states=hidden_states,
|
| 159 |
-
attention_mask=attention_mask,
|
| 160 |
-
position_ids=position_ids,
|
| 161 |
-
past_key_value=past_key_value,
|
| 162 |
-
output_attentions=output_attentions,
|
| 163 |
-
use_cache=use_cache,
|
| 164 |
-
cu_seqlens=cu_seqlens,
|
| 165 |
-
max_seqlen=max_seqlen,
|
| 166 |
-
)
|
| 167 |
-
hidden_states = residual + hidden_states
|
| 168 |
-
|
| 169 |
-
# Fully Connected
|
| 170 |
-
residual = hidden_states
|
| 171 |
-
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 172 |
-
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
| 173 |
-
hidden_states = residual + hidden_states
|
| 174 |
-
|
| 175 |
-
outputs = (hidden_states,)
|
| 176 |
-
|
| 177 |
-
if output_attentions:
|
| 178 |
-
outputs += (self_attn_weights,)
|
| 179 |
-
|
| 180 |
-
if use_cache:
|
| 181 |
-
outputs += (present_key_value,)
|
| 182 |
-
|
| 183 |
-
if output_router_logits:
|
| 184 |
-
outputs += (router_logits,)
|
| 185 |
-
|
| 186 |
-
return outputs
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def mixtral_model_forward(
|
| 190 |
-
self,
|
| 191 |
-
input_ids: torch.LongTensor = None,
|
| 192 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 193 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 194 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 195 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 196 |
-
use_cache: Optional[bool] = None,
|
| 197 |
-
output_attentions: Optional[bool] = None,
|
| 198 |
-
output_hidden_states: Optional[bool] = None,
|
| 199 |
-
output_router_logits: Optional[bool] = None,
|
| 200 |
-
return_dict: Optional[bool] = None,
|
| 201 |
-
) -> Union[Tuple, MoeModelOutputWithPast]:
|
| 202 |
-
output_attentions = (
|
| 203 |
-
output_attentions
|
| 204 |
-
if output_attentions is not None
|
| 205 |
-
else self.config.output_attentions
|
| 206 |
-
)
|
| 207 |
-
output_router_logits = (
|
| 208 |
-
output_router_logits
|
| 209 |
-
if output_router_logits is not None
|
| 210 |
-
else self.config.output_router_logits
|
| 211 |
-
)
|
| 212 |
-
output_hidden_states = (
|
| 213 |
-
output_hidden_states
|
| 214 |
-
if output_hidden_states is not None
|
| 215 |
-
else self.config.output_hidden_states
|
| 216 |
-
)
|
| 217 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 218 |
-
|
| 219 |
-
return_dict = (
|
| 220 |
-
return_dict if return_dict is not None else self.config.use_return_dict
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
# retrieve input_ids and inputs_embeds
|
| 224 |
-
if input_ids is not None and inputs_embeds is not None:
|
| 225 |
-
raise ValueError(
|
| 226 |
-
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
| 227 |
-
)
|
| 228 |
-
if input_ids is not None:
|
| 229 |
-
batch_size, seq_length = input_ids.shape
|
| 230 |
-
elif inputs_embeds is not None:
|
| 231 |
-
batch_size, seq_length, _ = inputs_embeds.shape
|
| 232 |
-
else:
|
| 233 |
-
raise ValueError(
|
| 234 |
-
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
| 235 |
-
)
|
| 236 |
-
|
| 237 |
-
past_key_values_length = 0
|
| 238 |
-
|
| 239 |
-
if use_cache:
|
| 240 |
-
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 241 |
-
if use_legacy_cache:
|
| 242 |
-
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 243 |
-
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
| 244 |
-
|
| 245 |
-
cu_seqlens = None
|
| 246 |
-
max_seqlen = None
|
| 247 |
-
if position_ids is None:
|
| 248 |
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 249 |
-
position_ids = torch.arange(
|
| 250 |
-
past_key_values_length,
|
| 251 |
-
seq_length + past_key_values_length,
|
| 252 |
-
dtype=torch.long,
|
| 253 |
-
device=device,
|
| 254 |
-
)
|
| 255 |
-
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 256 |
-
else:
|
| 257 |
-
position_ids = position_ids.view(-1, seq_length).long()
|
| 258 |
-
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
| 259 |
-
cu_seqlens = cu_seqlens.squeeze()
|
| 260 |
-
|
| 261 |
-
if inputs_embeds is None:
|
| 262 |
-
inputs_embeds = self.embed_tokens(input_ids)
|
| 263 |
-
|
| 264 |
-
if (
|
| 265 |
-
attention_mask is not None
|
| 266 |
-
and self._attn_implementation == "flash_attention_2"
|
| 267 |
-
and use_cache
|
| 268 |
-
):
|
| 269 |
-
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
| 270 |
-
if is_padding_right:
|
| 271 |
-
raise ValueError(
|
| 272 |
-
"You are attempting to perform batched generation with padding_side='right'"
|
| 273 |
-
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
| 274 |
-
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
if self._attn_implementation == "flash_attention_2":
|
| 278 |
-
# 2d mask is passed through the layers
|
| 279 |
-
attention_mask = (
|
| 280 |
-
attention_mask
|
| 281 |
-
if (attention_mask is not None and 0 in attention_mask)
|
| 282 |
-
else None
|
| 283 |
-
)
|
| 284 |
-
else:
|
| 285 |
-
# 4d mask is passed through the layers
|
| 286 |
-
attention_mask = _prepare_4d_causal_attention_mask(
|
| 287 |
-
attention_mask,
|
| 288 |
-
(batch_size, seq_length),
|
| 289 |
-
inputs_embeds,
|
| 290 |
-
past_key_values_length,
|
| 291 |
-
sliding_window=self.config.sliding_window,
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
hidden_states = inputs_embeds
|
| 295 |
-
|
| 296 |
-
if self.gradient_checkpointing and self.training:
|
| 297 |
-
if use_cache:
|
| 298 |
-
LOG.warning_once(
|
| 299 |
-
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 300 |
-
)
|
| 301 |
-
use_cache = False
|
| 302 |
-
|
| 303 |
-
# decoder layers
|
| 304 |
-
all_hidden_states = () if output_hidden_states else None
|
| 305 |
-
all_self_attns = () if output_attentions else None
|
| 306 |
-
all_router_logits = () if output_router_logits else None
|
| 307 |
-
next_decoder_cache = None
|
| 308 |
-
|
| 309 |
-
for decoder_layer in self.layers:
|
| 310 |
-
if output_hidden_states:
|
| 311 |
-
all_hidden_states += (hidden_states,)
|
| 312 |
-
|
| 313 |
-
if self.gradient_checkpointing and self.training:
|
| 314 |
-
layer_outputs = self._gradient_checkpointing_func(
|
| 315 |
-
decoder_layer.__call__,
|
| 316 |
-
hidden_states,
|
| 317 |
-
attention_mask,
|
| 318 |
-
position_ids,
|
| 319 |
-
past_key_values,
|
| 320 |
-
output_attentions,
|
| 321 |
-
output_router_logits,
|
| 322 |
-
use_cache,
|
| 323 |
-
cu_seqlens,
|
| 324 |
-
max_seqlen,
|
| 325 |
-
)
|
| 326 |
-
else:
|
| 327 |
-
layer_outputs = decoder_layer(
|
| 328 |
-
hidden_states,
|
| 329 |
-
attention_mask=attention_mask,
|
| 330 |
-
position_ids=position_ids,
|
| 331 |
-
past_key_value=past_key_values,
|
| 332 |
-
output_attentions=output_attentions,
|
| 333 |
-
output_router_logits=output_router_logits,
|
| 334 |
-
use_cache=use_cache,
|
| 335 |
-
cu_seqlens=cu_seqlens,
|
| 336 |
-
max_seqlen=max_seqlen,
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
hidden_states = layer_outputs[0]
|
| 340 |
-
|
| 341 |
-
if use_cache:
|
| 342 |
-
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 343 |
-
|
| 344 |
-
if output_attentions:
|
| 345 |
-
all_self_attns += (layer_outputs[1],)
|
| 346 |
-
|
| 347 |
-
if output_router_logits:
|
| 348 |
-
all_router_logits += (layer_outputs[-1],)
|
| 349 |
-
|
| 350 |
-
hidden_states = self.norm(hidden_states)
|
| 351 |
-
|
| 352 |
-
# add hidden states from the last decoder layer
|
| 353 |
-
if output_hidden_states:
|
| 354 |
-
all_hidden_states += (hidden_states,)
|
| 355 |
-
|
| 356 |
-
next_cache = None
|
| 357 |
-
if use_cache:
|
| 358 |
-
next_cache = (
|
| 359 |
-
next_decoder_cache.to_legacy_cache()
|
| 360 |
-
if use_legacy_cache
|
| 361 |
-
else next_decoder_cache
|
| 362 |
-
)
|
| 363 |
-
|
| 364 |
-
if not return_dict:
|
| 365 |
-
return tuple(
|
| 366 |
-
v
|
| 367 |
-
for v in [
|
| 368 |
-
hidden_states,
|
| 369 |
-
next_cache,
|
| 370 |
-
all_hidden_states,
|
| 371 |
-
all_self_attns,
|
| 372 |
-
all_router_logits,
|
| 373 |
-
]
|
| 374 |
-
if v is not None
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
return MoeModelOutputWithPast(
|
| 378 |
-
last_hidden_state=hidden_states,
|
| 379 |
-
past_key_values=next_cache,
|
| 380 |
-
hidden_states=all_hidden_states,
|
| 381 |
-
attentions=all_self_attns,
|
| 382 |
-
router_logits=all_router_logits,
|
| 383 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/monkeypatch/utils.py
CHANGED
|
@@ -2,6 +2,40 @@
|
|
| 2 |
Shared utils for the monkeypatches
|
| 3 |
"""
|
| 4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def get_cu_seqlens(attn_mask):
|
|
|
|
| 2 |
Shared utils for the monkeypatches
|
| 3 |
"""
|
| 4 |
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@torch.jit.script
|
| 9 |
+
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
|
| 10 |
+
max_num = int(torch.max(attention_mask).item())
|
| 11 |
+
batch_size, _ = attention_mask.shape
|
| 12 |
+
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
|
| 13 |
+
|
| 14 |
+
for i in range(1, max_num + 1):
|
| 15 |
+
mask = attention_mask == i
|
| 16 |
+
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
|
| 17 |
+
|
| 18 |
+
result = counts.flatten()
|
| 19 |
+
nonzero_indices = torch.nonzero(result).squeeze(-1)
|
| 20 |
+
return result[nonzero_indices]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@torch.jit.script
|
| 24 |
+
def get_unpad_data(attention_mask: torch.Tensor):
|
| 25 |
+
device = attention_mask.device
|
| 26 |
+
seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
|
| 27 |
+
indices = torch.nonzero(attention_mask.flatten()).flatten()
|
| 28 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 29 |
+
cu_seqlens = (
|
| 30 |
+
F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 31 |
+
.to(device=device)
|
| 32 |
+
.detach()
|
| 33 |
+
)
|
| 34 |
+
return (
|
| 35 |
+
indices,
|
| 36 |
+
cu_seqlens,
|
| 37 |
+
max_seqlen_in_batch,
|
| 38 |
+
)
|
| 39 |
|
| 40 |
|
| 41 |
def get_cu_seqlens(attn_mask):
|
src/axolotl/utils/collators.py
CHANGED
|
@@ -152,6 +152,33 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|
| 152 |
return super().__call__(features, return_tensors=return_tensors)
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
@dataclass
|
| 156 |
class MambaDataCollator:
|
| 157 |
"""
|
|
|
|
| 152 |
return super().__call__(features, return_tensors=return_tensors)
|
| 153 |
|
| 154 |
|
| 155 |
+
@dataclass
|
| 156 |
+
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
| 157 |
+
"""
|
| 158 |
+
Collator for multipack specific to the using the BatchSampler
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __call__(self, features, return_tensors=None):
|
| 162 |
+
chunked_data = {}
|
| 163 |
+
for feature in features[0].keys():
|
| 164 |
+
if feature == "length":
|
| 165 |
+
continue
|
| 166 |
+
if feature == "attention_mask":
|
| 167 |
+
arrays = [
|
| 168 |
+
(i + 1) * np.array(item[feature])
|
| 169 |
+
for i, item in enumerate(features)
|
| 170 |
+
if feature in item
|
| 171 |
+
]
|
| 172 |
+
chunked_data[feature] = np.concatenate(arrays)
|
| 173 |
+
else:
|
| 174 |
+
arrays = [
|
| 175 |
+
np.array(item[feature]) for item in features if feature in item
|
| 176 |
+
]
|
| 177 |
+
chunked_data[feature] = np.concatenate(arrays)
|
| 178 |
+
features = [chunked_data]
|
| 179 |
+
return super().__call__(features, return_tensors=return_tensors)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
@dataclass
|
| 183 |
class MambaDataCollator:
|
| 184 |
"""
|
src/axolotl/utils/config.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
| 1 |
"""Module for working with config dicts"""
|
| 2 |
-
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
from transformers.utils import is_torch_bf16_gpu_available
|
| 8 |
|
| 9 |
from axolotl.utils.bench import log_gpu_memory_usage
|
|
|
|
| 10 |
from axolotl.utils.models import load_model_config
|
| 11 |
|
| 12 |
LOG = logging.getLogger("axolotl")
|
|
@@ -135,7 +137,7 @@ def normalize_config(cfg):
|
|
| 135 |
]
|
| 136 |
)
|
| 137 |
or cfg.is_mistral_derived_model
|
| 138 |
-
or "mistral" in cfg.base_model.lower()
|
| 139 |
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
| 140 |
)
|
| 141 |
|
|
@@ -484,6 +486,40 @@ def validate_config(cfg):
|
|
| 484 |
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
| 485 |
)
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
# TODO
|
| 488 |
# MPT 7b
|
| 489 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
|
| 1 |
"""Module for working with config dicts"""
|
| 2 |
+
import json
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from transformers.utils import is_torch_bf16_gpu_available
|
| 9 |
|
| 10 |
from axolotl.utils.bench import log_gpu_memory_usage
|
| 11 |
+
from axolotl.utils.dict import DictDefault
|
| 12 |
from axolotl.utils.models import load_model_config
|
| 13 |
|
| 14 |
LOG = logging.getLogger("axolotl")
|
|
|
|
| 137 |
]
|
| 138 |
)
|
| 139 |
or cfg.is_mistral_derived_model
|
| 140 |
+
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
| 141 |
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
| 142 |
)
|
| 143 |
|
|
|
|
| 486 |
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
| 487 |
)
|
| 488 |
|
| 489 |
+
if (
|
| 490 |
+
cfg.unfrozen_parameters
|
| 491 |
+
and cfg.gradient_checkpointing_kwargs
|
| 492 |
+
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
| 493 |
+
):
|
| 494 |
+
# https://github.com/huggingface/transformers/issues/21381
|
| 495 |
+
raise ValueError(
|
| 496 |
+
"`use_reentrant` must be false when used with partially frozen model."
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
| 500 |
+
with open(cfg.deepspeed, encoding="utf-8") as file:
|
| 501 |
+
contents = file.read()
|
| 502 |
+
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
| 503 |
+
if (
|
| 504 |
+
deepspeed_cfg.zero_optimization
|
| 505 |
+
and deepspeed_cfg.zero_optimization.stage == 3
|
| 506 |
+
):
|
| 507 |
+
if not (
|
| 508 |
+
(
|
| 509 |
+
deepspeed_cfg.bf16
|
| 510 |
+
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
| 511 |
+
is True
|
| 512 |
+
)
|
| 513 |
+
or (
|
| 514 |
+
deepspeed_cfg.fp16
|
| 515 |
+
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
| 516 |
+
is True
|
| 517 |
+
)
|
| 518 |
+
):
|
| 519 |
+
raise ValueError(
|
| 520 |
+
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
# TODO
|
| 524 |
# MPT 7b
|
| 525 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
src/axolotl/utils/models.py
CHANGED
|
@@ -305,12 +305,16 @@ def load_model(
|
|
| 305 |
)
|
| 306 |
|
| 307 |
# Modify mistral derived models
|
| 308 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
| 310 |
replace_mistral_attn_with_flash_attn,
|
| 311 |
)
|
| 312 |
|
| 313 |
-
LOG.info("patching with flash attention")
|
| 314 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
| 315 |
|
| 316 |
if (
|
|
@@ -322,7 +326,7 @@ def load_model(
|
|
| 322 |
replace_mixtral_attn_with_multipack_flash_attn,
|
| 323 |
)
|
| 324 |
|
| 325 |
-
LOG.info("patching with flash attention")
|
| 326 |
replace_mixtral_attn_with_multipack_flash_attn()
|
| 327 |
|
| 328 |
if (
|
|
|
|
| 305 |
)
|
| 306 |
|
| 307 |
# Modify mistral derived models
|
| 308 |
+
if (
|
| 309 |
+
cfg.model_config_type == "mistral"
|
| 310 |
+
and cfg.flash_attention
|
| 311 |
+
and cfg.sample_packing
|
| 312 |
+
):
|
| 313 |
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
| 314 |
replace_mistral_attn_with_flash_attn,
|
| 315 |
)
|
| 316 |
|
| 317 |
+
LOG.info("patching mistral with flash attention")
|
| 318 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
| 319 |
|
| 320 |
if (
|
|
|
|
| 326 |
replace_mixtral_attn_with_multipack_flash_attn,
|
| 327 |
)
|
| 328 |
|
| 329 |
+
LOG.info("patching mixtral with flash attention")
|
| 330 |
replace_mixtral_attn_with_multipack_flash_attn()
|
| 331 |
|
| 332 |
if (
|
src/axolotl/utils/trainer.py
CHANGED
|
@@ -152,6 +152,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|
| 152 |
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
| 153 |
or cfg.model_config_type == "mamba"
|
| 154 |
):
|
|
|
|
| 155 |
train_dataset = train_dataset.remove_columns("attention_mask")
|
| 156 |
if eval_dataset:
|
| 157 |
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
|
|
|
| 152 |
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
| 153 |
or cfg.model_config_type == "mamba"
|
| 154 |
):
|
| 155 |
+
LOG.info("dropping attention_mask column")
|
| 156 |
train_dataset = train_dataset.remove_columns("attention_mask")
|
| 157 |
if eval_dataset:
|
| 158 |
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
tests/e2e/patched/test_mixtral_samplepack.py
CHANGED
|
@@ -7,8 +7,6 @@ import os
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
| 10 |
-
from transformers.utils import is_torch_bf16_gpu_available
|
| 11 |
-
|
| 12 |
from axolotl.cli import load_datasets
|
| 13 |
from axolotl.common.cli import TrainerCliArgs
|
| 14 |
from axolotl.train import train
|
|
@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
|
|
| 60 |
"save_steps": 10,
|
| 61 |
"eval_steps": 10,
|
| 62 |
"sample_packing": True,
|
|
|
|
| 63 |
}
|
| 64 |
)
|
| 65 |
-
if is_torch_bf16_gpu_available():
|
| 66 |
-
cfg.bf16 = True
|
| 67 |
-
else:
|
| 68 |
-
cfg.fp16 = True
|
| 69 |
normalize_config(cfg)
|
| 70 |
cli_args = TrainerCliArgs()
|
| 71 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
|
|
| 101 |
"save_steps": 10,
|
| 102 |
"eval_steps": 10,
|
| 103 |
"sample_packing": True,
|
|
|
|
| 104 |
}
|
| 105 |
)
|
| 106 |
-
if is_torch_bf16_gpu_available():
|
| 107 |
-
cfg.bf16 = True
|
| 108 |
-
else:
|
| 109 |
-
cfg.fp16 = True
|
| 110 |
normalize_config(cfg)
|
| 111 |
cli_args = TrainerCliArgs()
|
| 112 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 113 |
|
| 114 |
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 115 |
assert (
|
| 116 |
-
"
|
| 117 |
-
in model.model.layers[0].self_attn.__class__.__module__
|
| 118 |
-
)
|
| 119 |
-
assert (
|
| 120 |
-
"MixtralMultipackFlashAttention2"
|
| 121 |
in model.model.layers[0].self_attn.__class__.__name__
|
| 122 |
)
|
| 123 |
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
|
|
|
| 7 |
import unittest
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from axolotl.cli import load_datasets
|
| 11 |
from axolotl.common.cli import TrainerCliArgs
|
| 12 |
from axolotl.train import train
|
|
|
|
| 58 |
"save_steps": 10,
|
| 59 |
"eval_steps": 10,
|
| 60 |
"sample_packing": True,
|
| 61 |
+
"bf16": "auto",
|
| 62 |
}
|
| 63 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
normalize_config(cfg)
|
| 65 |
cli_args = TrainerCliArgs()
|
| 66 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
|
|
| 96 |
"save_steps": 10,
|
| 97 |
"eval_steps": 10,
|
| 98 |
"sample_packing": True,
|
| 99 |
+
"bf16": "auto",
|
| 100 |
}
|
| 101 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
normalize_config(cfg)
|
| 103 |
cli_args = TrainerCliArgs()
|
| 104 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 105 |
|
| 106 |
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 107 |
assert (
|
| 108 |
+
"MixtralFlashAttention2"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
in model.model.layers[0].self_attn.__class__.__name__
|
| 110 |
)
|
| 111 |
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
tests/e2e/patched/test_model_patches.py
CHANGED
|
@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
|
|
| 52 |
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
| 53 |
|
| 54 |
assert (
|
| 55 |
-
"
|
| 56 |
-
in model.model.layers[0].self_attn.__class__.__module__
|
| 57 |
-
)
|
| 58 |
-
assert (
|
| 59 |
-
"MixtralMultipackFlashAttention2"
|
| 60 |
in model.model.layers[0].self_attn.__class__.__name__
|
| 61 |
)
|
| 62 |
|
|
|
|
| 52 |
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
| 53 |
|
| 54 |
assert (
|
| 55 |
+
"MixtralFlashAttention2"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
in model.model.layers[0].self_attn.__class__.__name__
|
| 57 |
)
|
| 58 |
|
tests/monkeypatch/test_llama_attn_hijack_flash.py
CHANGED
|
@@ -5,7 +5,12 @@ import unittest
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
-
from axolotl.monkeypatch.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class TestMonkeyPatchUtils(unittest.TestCase):
|
|
@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
|
| 25 |
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
| 26 |
)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
if __name__ == "__main__":
|
| 30 |
unittest.main()
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
from axolotl.monkeypatch.utils import (
|
| 9 |
+
get_cu_seqlens,
|
| 10 |
+
get_cu_seqlens_from_pos_ids,
|
| 11 |
+
get_max_seqlen_in_batch,
|
| 12 |
+
get_unpad_data,
|
| 13 |
+
)
|
| 14 |
|
| 15 |
|
| 16 |
class TestMonkeyPatchUtils(unittest.TestCase):
|
|
|
|
| 30 |
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
| 31 |
)
|
| 32 |
|
| 33 |
+
def test_get_max_seqlen_in_batch(self):
|
| 34 |
+
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
| 35 |
+
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
| 36 |
+
self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
|
| 37 |
+
|
| 38 |
+
def test_get_unpad_data(self):
|
| 39 |
+
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
| 40 |
+
target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
|
| 41 |
+
target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
|
| 42 |
+
target_max_seqlen_in_batch = 5
|
| 43 |
+
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
| 44 |
+
self.assertTrue(torch.allclose(target_indices, indices))
|
| 45 |
+
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
| 46 |
+
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
| 47 |
+
|
| 48 |
+
attn_mask = torch.tensor(
|
| 49 |
+
[
|
| 50 |
+
[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
|
| 51 |
+
[1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
|
| 52 |
+
]
|
| 53 |
+
)
|
| 54 |
+
target_indices = torch.tensor(
|
| 55 |
+
[
|
| 56 |
+
0,
|
| 57 |
+
1,
|
| 58 |
+
2,
|
| 59 |
+
3,
|
| 60 |
+
4,
|
| 61 |
+
5,
|
| 62 |
+
6,
|
| 63 |
+
7,
|
| 64 |
+
8,
|
| 65 |
+
9,
|
| 66 |
+
10,
|
| 67 |
+
11,
|
| 68 |
+
12,
|
| 69 |
+
13,
|
| 70 |
+
16,
|
| 71 |
+
17,
|
| 72 |
+
18,
|
| 73 |
+
19,
|
| 74 |
+
20,
|
| 75 |
+
21,
|
| 76 |
+
22,
|
| 77 |
+
23,
|
| 78 |
+
24,
|
| 79 |
+
25,
|
| 80 |
+
26,
|
| 81 |
+
27,
|
| 82 |
+
28,
|
| 83 |
+
29,
|
| 84 |
+
30,
|
| 85 |
+
31,
|
| 86 |
+
]
|
| 87 |
+
)
|
| 88 |
+
target_cu_seqlen = torch.tensor(
|
| 89 |
+
[0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
|
| 90 |
+
)
|
| 91 |
+
target_max_seqlen_in_batch = 5
|
| 92 |
+
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
| 93 |
+
self.assertTrue(torch.allclose(target_indices, indices))
|
| 94 |
+
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
| 95 |
+
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
| 96 |
+
|
| 97 |
|
| 98 |
if __name__ == "__main__":
|
| 99 |
unittest.main()
|