btlm and falcon monkey patches for flash attn (#566)
Browse files
examples/cerebras/btlm-ft.yml
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
base_model: cerebras/btlm-3b-8k-base
|
| 2 |
+
base_model_config: cerebras/btlm-3b-8k-base
|
| 3 |
+
model_type: AutoModelForCausalLM
|
| 4 |
+
tokenizer_type: GPT2Tokenizer
|
| 5 |
+
trust_remote_code: true
|
| 6 |
+
tokenizer_use_fast: true
|
| 7 |
+
tokenizer_legacy: true
|
| 8 |
+
|
| 9 |
+
load_in_8bit: false
|
| 10 |
+
load_in_4bit: false
|
| 11 |
+
strict: false
|
| 12 |
+
push_dataset_to_hub:
|
| 13 |
+
hf_use_auth_token: true
|
| 14 |
+
datasets:
|
| 15 |
+
- path: mhenrichsen/alpaca_2k_test
|
| 16 |
+
type: alpaca
|
| 17 |
+
dataset_prepared_path: last_prepared_run
|
| 18 |
+
val_set_size: 0.01
|
| 19 |
+
|
| 20 |
+
adapter:
|
| 21 |
+
lora_model_dir:
|
| 22 |
+
sequence_len: 2048
|
| 23 |
+
max_packed_sequence_len:
|
| 24 |
+
sample_packing: false
|
| 25 |
+
sample_packing_eff_est:
|
| 26 |
+
sample_packing_seq_len_multiplier:
|
| 27 |
+
total_num_tokens:
|
| 28 |
+
|
| 29 |
+
lora_r:
|
| 30 |
+
lora_alpha:
|
| 31 |
+
lora_dropout:
|
| 32 |
+
lora_target_modules:
|
| 33 |
+
lora_target_linear:
|
| 34 |
+
lora_fan_in_fan_out:
|
| 35 |
+
|
| 36 |
+
wandb_project:
|
| 37 |
+
wandb_entity:
|
| 38 |
+
wandb_watch:
|
| 39 |
+
wandb_run_id:
|
| 40 |
+
wandb_log_model:
|
| 41 |
+
|
| 42 |
+
output_dir: btlm-out
|
| 43 |
+
gradient_accumulation_steps: 1
|
| 44 |
+
micro_batch_size: 1
|
| 45 |
+
num_epochs: 1
|
| 46 |
+
optimizer: adamw_torch
|
| 47 |
+
adam_beta2: 0.95
|
| 48 |
+
adam_eps: 0.000000001
|
| 49 |
+
max_grad_norm: 1.0
|
| 50 |
+
|
| 51 |
+
torchdistx_path:
|
| 52 |
+
lr_scheduler: cosine
|
| 53 |
+
lr_quadratic_warmup: true
|
| 54 |
+
learning_rate: 0.000085
|
| 55 |
+
train_on_inputs: true
|
| 56 |
+
group_by_length: false
|
| 57 |
+
bf16: true
|
| 58 |
+
fp16: false
|
| 59 |
+
tf32: true
|
| 60 |
+
|
| 61 |
+
gradient_checkpointing: false
|
| 62 |
+
early_stopping_patience:
|
| 63 |
+
resume_from_checkpoint:
|
| 64 |
+
local_rank:
|
| 65 |
+
logging_steps: 1
|
| 66 |
+
|
| 67 |
+
xformers_attention:
|
| 68 |
+
flash_attention: true
|
| 69 |
+
sdp_attention:
|
| 70 |
+
flash_optimum:
|
| 71 |
+
|
| 72 |
+
gptq_groupsize:
|
| 73 |
+
gptq_model_v1:
|
| 74 |
+
|
| 75 |
+
warmup_steps: 32
|
| 76 |
+
eval_steps:
|
| 77 |
+
save_steps:
|
| 78 |
+
save_total_limit:
|
| 79 |
+
|
| 80 |
+
debug:
|
| 81 |
+
deepspeed:
|
| 82 |
+
weight_decay: 0.1
|
| 83 |
+
special_tokens:
|
| 84 |
+
pad_token: "<|endoftext|>"
|
| 85 |
+
fsdp:
|
| 86 |
+
# - full_shard
|
| 87 |
+
# - auto_wrap
|
| 88 |
+
fsdp_config:
|
| 89 |
+
# fsdp_state_dict_type: FULL_STATE_DICT
|
| 90 |
+
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
|
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flash attention monkey patch for cerebras btlm model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import importlib
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
| 11 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 12 |
+
|
| 13 |
+
LOG = logging.getLogger("axolotl")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
    """
    Point ``BTLMAttention._attn`` of the remotely loaded BTLM code at
    ``flashattn_attn``.

    The BTLM modeling code is fetched via ``trust_remote_code`` rather than
    shipped with transformers, so its module path is derived from the module
    of the loaded config class.

    :param model_name: identifier passed to both ``from_pretrained`` calls
    """
    # the config class lives next to the modeling module, so its module path
    # is the starting point for locating modeling_btlm
    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # the model has to be loaded once for modeling_btlm to become importable
    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    config_module_path = type(btlm_config).__module__
    modeling_module_path = config_module_path.replace(
        ".configuration_btlm", ".modeling_btlm"
    )
    btlm_modeling = importlib.import_module(modeling_module_path)

    attention_cls = btlm_modeling.BTLMAttention
    attention_cls._attn = flashattn_attn  # pylint: disable=protected-access
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def flashattn_attn(
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    value: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    head_mask: Optional[torch.Tensor] = None,
    position_bias: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Flash-attention stand-in for ``BTLMAttention._attn``.

    ``attention_mask`` and ``position_bias`` are accepted for signature
    compatibility but are ignored; masking is limited to the causal flag
    derived from ``self.is_cross_attention``. Attention dropout is not applied.

    :param query: query tensor; swapped on dims 1 and 2 before the kernel call
        (assumes (batch, heads, seq, head_dim) on entry -- TODO confirm against
        the model's head-splitting code)
    :param key: key tensor in the same layout as ``query``; required
    :param value: value tensor in the same layout as ``query``; required
    :param head_mask: optional multiplier applied to the output after it is
        back in the input layout (presumably broadcastable over the head axis,
        e.g. (1, heads, 1, 1) -- verify against caller)
    :return: ``(attn_output, None)``; attention weights are never materialised
    :raises ValueError: if ``key`` or ``value`` is missing
    """
    if key is None or value is None:
        raise ValueError("flashattn_attn requires both `key` and `value` tensors")

    # flash_attn_func substitutes 1/sqrt(head_dim) when softmax_scale is None,
    # so "do not scale" has to be requested explicitly with 1.0
    if self.scale_attn_weights:
        softmax_scale = 1 / (key.size(-1) ** self.attn_scale_power)
    else:
        softmax_scale = 1.0

    # flash_attn_func works on (batch, seq, heads, head_dim)
    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    attn_output = flash_attn_func(
        query,
        key,
        value,
        # NOTE(review): dropout is disabled unconditionally here; confirm
        # whether the model's attention dropout should be honoured in training
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=not self.is_cross_attention,
        return_attn_probs=False,
    )

    # restore the caller's layout before touching the head axis
    attn_output = attn_output.permute(0, 2, 1, 3)

    if head_mask is not None:
        # Scaling a head's output is equivalent to scaling its attention
        # weights. Done out of place and in the caller's layout so the mask
        # lines up with the head axis and the kernel output is not mutated.
        attn_output = attn_output * head_mask

    return attn_output, None  # flash attention exposes no explicit attn_weights
|
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flash Attention monkey patch for Falcon
|
| 3 |
+
|
| 4 |
+
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import transformers
|
| 11 |
+
from flash_attn import flash_attn_func
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def forward(
    self,
    hidden_states: torch.Tensor,
    alibi: Optional[torch.Tensor],
    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    use_cache: bool = False,
    output_attentions: bool = False,  # pylint: disable=unused-argument
):
    """
    Flash-attention replacement for ``FalconAttention.forward``.

    ``attention_mask``, ``head_mask`` and ``output_attentions`` are accepted
    for signature compatibility but ignored; attention is always causal.

    :param hidden_states: input to the fused query/key/value projection
    :param alibi: must be ``None``; alibi biases are not supported here
    :param layer_past: optional cached ``(key, value)`` pair that is
        concatenated in front of the new keys/values along dim 1
    :param use_cache: when true, the (possibly extended) key/value pair is
        returned as ``present``
    :return: ``(output_tensor, present)`` where ``present`` is ``None`` unless
        ``use_cache`` is set
    :raises ValueError: if ``alibi`` is provided
    """
    # fail fast: nothing below can honour an alibi bias
    if alibi is not None:
        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

    fused_qkv = self.query_key_value(
        hidden_states
    )  # [batch_size, seq_length, 3 x hidden_size]
    num_kv_heads = (
        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
    )
    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (
        query_layer,
        key_layer,
        value_layer,
    ) = self._split_heads(  # pylint: disable=protected-access
        fused_qkv
    )

    batch_size, query_length, _, _ = query_layer.shape

    # fold the head axis into the batch axis for the rotary embedding call
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads, query_length, self.head_dim
    )
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads,
        query_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_kv_heads, query_length, self.head_dim
    )

    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        #  - key: [batch_size * self.num_heads, kv_length, head_dim]
        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=1)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    present = (key_layer, value_layer) if use_cache else None

    # flash attention only runs in half precision; keep the projection's own
    # dtype when it already is fp16/bf16 instead of forcing bf16 (which would
    # downcast fp16 runs and feed a mismatched dtype into self.dense)
    flash_dtype = (
        fused_qkv.dtype
        if fused_qkv.dtype in (torch.float16, torch.bfloat16)
        else torch.bfloat16
    )

    # unfold the heads again and move to (batch_size, seqlen, nheads, headdim)
    query_layer_ = (
        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(flash_dtype)
    )
    key_layer_ = (
        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(flash_dtype)
    )
    value_layer_ = (
        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        .transpose(1, 2)
        .to(flash_dtype)
    )

    # below output will have shape (batch_size, seqlen, nheads, headdim)
    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
    attn_output = attn_output.reshape(
        batch_size, query_length, self.num_heads * self.head_dim
    )
    output_tensor = self.dense(attn_output)
    return output_tensor, present
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def replace_falcon_attn_with_flash_attn():
    """
    Install this module's ``forward`` as ``FalconAttention.forward`` in
    transformers' Falcon modeling module.
    """
    falcon_modeling = transformers.models.falcon.modeling_falcon
    falcon_modeling.FalconAttention.forward = forward
|
src/axolotl/utils/models.py
CHANGED
|
@@ -100,10 +100,31 @@ def load_model(
|
|
| 100 |
base_model = cfg.base_model
|
| 101 |
base_model_config = cfg.base_model_config
|
| 102 |
model_type = cfg.model_type
|
|
|
|
| 103 |
|
| 104 |
# TODO refactor as a kwarg
|
| 105 |
load_in_8bit = cfg.load_in_8bit
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if cfg.is_llama_derived_model and cfg.flash_attention:
|
| 108 |
if cfg.device not in ["mps", "cpu"] and not inference:
|
| 109 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
|
@@ -338,6 +359,9 @@ def load_model(
|
|
| 338 |
for name, module in model.named_modules():
|
| 339 |
if "norm" in name:
|
| 340 |
module.to(torch.float32)
|
|
|
|
|
|
|
|
|
|
| 341 |
if "lm_head" in name or "embed_tokens" in name:
|
| 342 |
if hasattr(module, "weight"):
|
| 343 |
module.to(torch.float32)
|
|
|
|
| 100 |
base_model = cfg.base_model
|
| 101 |
base_model_config = cfg.base_model_config
|
| 102 |
model_type = cfg.model_type
|
| 103 |
+
model_config = load_model_config(cfg)
|
| 104 |
|
| 105 |
# TODO refactor as a kwarg
|
| 106 |
load_in_8bit = cfg.load_in_8bit
|
| 107 |
|
| 108 |
+
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
|
| 109 |
+
if cfg.flash_attention:
|
| 110 |
+
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
| 111 |
+
replace_btlm_attn_with_flash_attn,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
| 115 |
+
|
| 116 |
+
if hasattr(model_config, "model_type") and model_config.model_type in [
|
| 117 |
+
"falcon",
|
| 118 |
+
"RefinedWebModel",
|
| 119 |
+
"RefinedWeb",
|
| 120 |
+
]:
|
| 121 |
+
if cfg.flash_attention:
|
| 122 |
+
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
|
| 123 |
+
replace_falcon_attn_with_flash_attn,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
replace_falcon_attn_with_flash_attn()
|
| 127 |
+
|
| 128 |
if cfg.is_llama_derived_model and cfg.flash_attention:
|
| 129 |
if cfg.device not in ["mps", "cpu"] and not inference:
|
| 130 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
|
|
|
| 359 |
for name, module in model.named_modules():
|
| 360 |
if "norm" in name:
|
| 361 |
module.to(torch.float32)
|
| 362 |
+
if model_config.model_type == "btlm":
|
| 363 |
+
# don't upcast lm_head for btlm
|
| 364 |
+
continue
|
| 365 |
if "lm_head" in name or "embed_tokens" in name:
|
| 366 |
if hasattr(module, "weight"):
|
| 367 |
module.to(torch.float32)
|