| |
| |
| |
| |
| @@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained( |
| ) |
| ``` |
| |
| +### Fine-Tuning with torch.compile and Padding-Free Data Collation |
| + |
| +In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead. |
| + |
| +Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator: |
| + |
| +```python |
| +#################### IMPORTS ################### |
| + |
| +import math |
| +import datasets |
| +import dataclasses |
| +from transformers import ( |
| + AutoModelForCausalLM, |
| + AutoTokenizer, |
| + TrainingArguments |
| +) |
| +from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM |
| + |
| +#################### MODEL LOADING WITH FLASH ATTENTION ################### |
| + |
| +model_name = "meta-llama/Llama-3.2-1B" |
| +model = AutoModelForCausalLM.from_pretrained( |
| + model_name, |
| + attn_implementation="flash_attention_2" # Enables FlashAttention-2 |
| +) |
| +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) |
| + |
| +#################### DATA PREPROCESSING (PADDING-FREE) ################### |
| + |
| +response_template = "\n### Label:" |
| +response_template_ids = tokenizer.encode( |
| + response_template, add_special_tokens=False |
| +)[2:] # Exclude special tokens |
| + |
| +data_collator = DataCollatorForCompletionOnlyLM( |
| + response_template_ids=response_template_ids, |
| + tokenizer=tokenizer, |
| + ignore_index=-100, |
| + padding_free=True # Enables padding-free collation |
| +) |
| + |
| +def format_dataset(example): |
| + return { |
| + "output": example["output"] + tokenizer.eos_token |
| + } |
| + |
| +data_files = {"train": "path/to/dataset"} # Replace with your dataset path |
| +json_dataset = datasets.load_dataset("json", data_files=data_files) |
| +formatted_train_dataset = json_dataset["train"].map(format_dataset) |
| + |
| +################# TRAINING CONFIGURATION ############################ |
| + |
| +train_args = TrainingArguments( |
| + num_train_epochs=5, |
| + per_device_train_batch_size=4, |
| + per_device_eval_batch_size=4, |
| + gradient_accumulation_steps=4, |
| + learning_rate=1e-5, |
| + weight_decay=0.0, |
| + warmup_ratio=0.03, |
| + lr_scheduler_type="cosine", |
| + logging_steps=1, |
| + include_tokens_per_second=True, |
| + save_strategy="epoch", |
| + output_dir="output", |
| + torch_compile=True, # Enables torch.compile |
| + torch_compile_backend="inductor", |
| + torch_compile_mode="default" |
| +) |
| + |
| +# Convert TrainingArguments to SFTConfig |
| +transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)] |
| +transformer_kwargs = { |
| + k: v |
| + for k, v in train_args.to_dict().items() |
| + if k in transformer_train_arg_fields |
| +} |
| +training_args = SFTConfig(**transformer_kwargs) |
| + |
| +####################### FINE-TUNING ##################### |
| + |
| +trainer = SFTTrainer( |
| + model=model, |
| + tokenizer=tokenizer, |
| + train_dataset=formatted_train_dataset, |
| + data_collator=data_collator, |
| + dataset_text_field="output", |
| + args=training_args, |
| +) |
| +trainer.train() |
| +``` |
| + |
| ### PyTorch scaled dot product attention |
| |
| Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. |
| |
| |
| |
| |
| @@ -15,7 +15,7 @@ |
| |
| import inspect |
| import os |
| -from typing import Optional, Tuple |
| +from typing import Optional, Tuple, TypedDict |
| |
| import torch |
| import torch.nn.functional as F |
| @@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids): |
| return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) |
| |
| |
| +flash_241 = is_flash_attn_greater_or_equal("2.4.1") |
| +deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" |
| + |
| + |
| def _flash_attention_forward( |
| query_states: torch.Tensor, |
| key_states: torch.Tensor, |
| @@ -194,6 +198,10 @@ def _flash_attention_forward( |
| use_top_left_mask: bool = False, |
| softcap: Optional[float] = None, |
| deterministic: bool = None, |
| + cu_seq_lens_q: Optional[torch.LongTensor] = None, |
| + cu_seq_lens_k: Optional[torch.LongTensor] = None, |
| + max_length_q: Optional[int] = None, |
| + max_length_k: Optional[int] = None, |
| ): |
| """ |
| Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token |
| @@ -232,9 +240,9 @@ def _flash_attention_forward( |
| ) |
| flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} |
| |
| - if is_flash_attn_greater_or_equal("2.4.1"): |
| + if flash_241: |
| if deterministic is None: |
| - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" |
| + deterministic = deterministic_g |
| flash_kwargs["deterministic"] = deterministic |
| |
| if softcap is not None: |
| @@ -267,24 +275,32 @@ def _flash_attention_forward( |
| # If position_ids is provided, check that the examples do not each contain only one sequence: if the tensor is increasing |
| # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. |
| # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach |
| - # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always) |
| - elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): |
| + elif position_ids is not None and ( |
| + max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) |
| + ): |
| batch_size = query_states.size(0) |
| - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( |
| - query_states, key_states, value_states, position_ids |
| - ) |
| |
| - cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
| - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
| + if cu_seq_lens_q is None or cu_seq_lens_k is None: |
| + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( |
| + prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) |
| + ) |
| + |
| + cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens |
| + max_length_q, max_length_k = max_seq_lens |
| + |
| + else: |
| + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) |
| + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) |
| + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) |
| |
| attn_output = flash_attn_varlen_func( |
| query_states, |
| key_states, |
| value_states, |
| - cu_seqlens_q=cu_seqlens_q, |
| - cu_seqlens_k=cu_seqlens_k, |
| - max_seqlen_q=max_seqlen_in_batch_q, |
| - max_seqlen_k=max_seqlen_in_batch_k, |
| + cu_seqlens_q=cu_seq_lens_q, |
| + cu_seqlens_k=cu_seq_lens_k, |
| + max_seqlen_q=max_length_q, |
| + max_seqlen_k=max_length_k, |
| dropout_p=dropout, |
| softmax_scale=softmax_scale, |
| causal=causal, |
| @@ -299,3 +315,24 @@ def _flash_attention_forward( |
| ) |
| |
| return attn_output |
| + |
| + |
| +class FlashAttentionKwargs(TypedDict, total=False): |
| + """ |
| + Keyword arguments for Flash Attention with Compile. |
| + |
| + Attributes: |
| + cu_seq_lens_q (`torch.LongTensor`, *optional*): |
| + Gets cumulative sequence length for query state. |
| + cu_seq_lens_k (`torch.LongTensor`, *optional*): |
| + Gets cumulative sequence length for key state. |
| + max_length_q (`int`, *optional*): |
| + Maximum sequence length for query state. |
| + max_length_k (`int`, *optional*): |
| + Maximum sequence length for key state. |
| + """ |
| + |
| + cu_seq_lens_q: Optional[torch.LongTensor] |
| + cu_seq_lens_k: Optional[torch.LongTensor] |
| + max_length_q: Optional[int] |
| + max_length_k: Optional[int] |
| |
| |
| |
| |
| @@ -33,12 +33,14 @@ |
| from ...cache_utils import Cache, DynamicCache, StaticCache |
| from ...generation import GenerationMixin |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| +from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| from ...modeling_utils import PreTrainedModel |
| +from ...processing_utils import Unpack |
| from ...pytorch_utils import ALL_LAYERNORM_LAYERS |
| from ...utils import ( |
| add_start_docstrings, |
| @@ -832,6 +834,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| @@ -913,6 +916,7 @@ def forward( |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| + **flash_attn_kwargs, |
| ) |
| |
| hidden_states = layer_outputs[0] |
| |
| |
| |
| |
| @@ -38,6 +38,7 @@ |
| ) |
| from ...modeling_utils import PreTrainedModel |
| from ...utils import ( |
| + add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| @@ -51,7 +52,11 @@ |
| if is_flash_attn_2_available(): |
| from ...modeling_flash_attention_utils import _flash_attention_forward |
| |
| -from ...modeling_flash_attention_utils import _flash_attention_forward |
| +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward |
| +from ...processing_utils import Unpack |
| + |
| + |
| +_CHECKPOINT_FOR_DOC = "dummy" |
| |
| |
| class GlmRMSNorm(nn.Module): |
| @@ -736,6 +741,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| @@ -817,6 +823,7 @@ def forward( |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| + **flash_attn_kwargs, |
| ) |
| |
| hidden_states = layer_outputs[0] |
| @@ -1222,6 +1229,11 @@ def set_input_embeddings(self, value): |
| self.model.embed_tokens = value |
| |
| @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) |
| + @add_code_sample_docstrings( |
| + checkpoint=_CHECKPOINT_FOR_DOC, |
| + output_type=TokenClassifierOutput, |
| + config_class=_CONFIG_FOR_DOC, |
| + ) |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| |
| |
| |
| |
| @@ -46,6 +46,8 @@ |
| |
| logger = logging.get_logger(__name__) |
| |
| +_CHECKPOINT_FOR_DOC = "dummy" |
| + |
| |
| class GlmRMSNorm(Phi3RMSNorm): |
| pass |
| |
| |
| |
| |
| @@ -29,7 +29,7 @@ |
| from ...cache_utils import Cache, DynamicCache, StaticCache |
| from ...generation import GenerationMixin |
| from ...modeling_attn_mask_utils import AttentionMaskConverter |
| -from ...modeling_flash_attention_utils import _flash_attention_forward |
| +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward |
| from ...modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| @@ -39,8 +39,10 @@ |
| ) |
| from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| from ...modeling_utils import PreTrainedModel |
| +from ...processing_utils import Unpack |
| from ...pytorch_utils import ALL_LAYERNORM_LAYERS |
| from ...utils import ( |
| + LossKwargs, |
| add_code_sample_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| @@ -422,6 +424,7 @@ def forward( |
| use_cache: bool = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 |
| + **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| if isinstance(past_key_value, StaticCache): |
| raise ValueError( |
| @@ -506,6 +509,7 @@ def forward( |
| sliding_window=getattr(self, "sliding_window", None), |
| use_top_left_mask=self._flash_attn_uses_top_left_mask, |
| is_causal=self.is_causal, |
| + **kwargs, |
| ) |
| |
| attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
| @@ -870,6 +874,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| @@ -951,6 +956,7 @@ def forward( |
| use_cache=use_cache, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| + **flash_attn_kwargs, |
| ) |
| |
| hidden_states = layer_outputs[0] |
| @@ -1102,6 +1108,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( |
| return causal_mask |
| |
| |
| +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... |
| + |
| + |
| class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
| |
| @@ -1148,7 +1157,7 @@ def forward( |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| num_logits_to_keep: int = 0, |
| - **loss_kwargs, |
| + **kwargs: Unpack[KwargsForCausalLM], |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| r""" |
| Args: |
| @@ -1198,6 +1207,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| + **kwargs, |
| ) |
| |
| hidden_states = outputs[0] |
| @@ -1211,7 +1221,7 @@ def forward( |
| |
| loss = None |
| if labels is not None: |
| - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) |
| + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) |
| |
| if not return_dict: |
| output = (logits,) + outputs[1:] |
| |
| |
| |
| |
| @@ -815,7 +815,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": |
| # Otherwise it passes the casts down and casts the LongTensor containing the token idxs |
| # into a HalfTensor |
| if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): |
| - self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)} |
| + self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()} |
| else: |
| logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") |
| return self |
| |
| |
| |
| |
| @@ -37,6 +37,7 @@ |
| from .generic import ( |
| ContextManagers, |
| ExplicitEnum, |
| + LossKwargs, |
| ModelOutput, |
| PaddingStrategy, |
| TensorType, |
| |
| |
| |
| |
| @@ -24,7 +24,7 @@ |
| from dataclasses import fields, is_dataclass |
| from enum import Enum |
| from functools import partial, wraps |
| -from typing import Any, ContextManager, Iterable, List, Optional, Tuple |
| +from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict |
| |
| import numpy as np |
| from packaging import version |
| @@ -854,3 +854,16 @@ def wrapper(*args, **kwargs): |
| return wrapper |
| |
| return decorator |
| + |
| + |
| +class LossKwargs(TypedDict, total=False): |
| + """ |
| + Keyword arguments to be passed to the loss function |
| + |
| + Attributes: |
| + num_items_in_batch (`int`, *optional*): |
| + Number of items in the batch. It is recommended to pass it when |
| + you are doing gradient accumulation. |
| + """ |
| + |
| + num_items_in_batch: Optional[int] |
|
|