harness / diffs /33932.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 16be638498df..0a6a7e15bea0 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
+### Fine-Tuning with torch.compile and Padding-Free Data Collation
+
+In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
+
+Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
+
+```
+#################### IMPORTS ###################
+
+import math
+import datasets
+import dataclasses
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ TrainingArguments
+)
+from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
+
+#################### MODEL LOADING WITH FLASH ATTENTION ###################
+
+model_name = "meta-llama/Llama-3.2-1B"
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ attn_implementation="flash_attention_2" # Enables FlashAttention-2
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+#################### DATA PREPROCESSING (PADDING-FREE) ###################
+
+response_template = "\n### Label:"
+response_template_ids = tokenizer.encode(
+ response_template, add_special_tokens=False
+)[2:] # Exclude special tokens
+
+data_collator = DataCollatorForCompletionOnlyLM(
+ response_template_ids=response_template_ids,
+ tokenizer=tokenizer,
+ ignore_index=-100,
+ padding_free=True # Enables padding-free collation
+)
+
+def format_dataset(example):
+ return {
+ "output": example["output"] + tokenizer.eos_token
+ }
+
+data_files = {"train": "path/to/dataset"} # Replace with your dataset path
+json_dataset = datasets.load_dataset("json", data_files=data_files)
+formatted_train_dataset = json_dataset["train"].map(format_dataset)
+
+################# TRAINING CONFIGURATION ############################
+
+train_args = TrainingArguments(
+ num_train_epochs=5,
+ per_device_train_batch_size=4,
+ per_device_eval_batch_size=4,
+ gradient_accumulation_steps=4,
+ learning_rate=1e-5,
+ weight_decay=0.0,
+ warmup_ratio=0.03,
+ lr_scheduler_type="cosine",
+ logging_steps=1,
+ include_tokens_per_second=True,
+ save_strategy="epoch",
+ output_dir="output",
+ torch_compile=True, # Enables torch.compile
+ torch_compile_backend="inductor",
+ torch_compile_mode="default"
+)
+
+# Convert TrainingArguments to SFTConfig
+transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
+transformer_kwargs = {
+ k: v
+ for k, v in train_args.to_dict().items()
+ if k in transformer_train_arg_fields
+}
+training_args = SFTConfig(**transformer_kwargs)
+
+####################### FINE-TUNING #####################
+
+trainer = SFTTrainer(
+ model=model,
+ tokenizer=tokenizer,
+ train_dataset=formatted_train_dataset,
+ data_collator=data_collator,
+ dataset_text_field="output",
+ args=training_args,
+)
+trainer.train()
+```
+
### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index da961c6060e4..045d2f6d6460 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -15,7 +15,7 @@
import inspect
import os
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+flash_241 = is_flash_attn_greater_or_equal("2.4.1")
+deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+
+
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -194,6 +198,10 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
+ cu_seq_lens_q: Optional[torch.LongTensor] = None,
+ cu_seq_lens_k: Optional[torch.LongTensor] = None,
+ max_length_q: Optional[int] = None,
+ max_length_k: Optional[int] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +240,9 @@ def _flash_attention_forward(
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
- if is_flash_attn_greater_or_equal("2.4.1"):
+ if flash_241:
if deterministic is None:
- deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic
if softcap is not None:
@@ -267,24 +275,32 @@ def _flash_attention_forward(
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
- # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
- elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
+ elif position_ids is not None and (
+ max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+ ):
batch_size = query_states.size(0)
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
- query_states, key_states, value_states, position_ids
- )
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+ if cu_seq_lens_q is None or cu_seq_lens_k is None:
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
+ prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
+ )
+
+ cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
+ max_length_q, max_length_k = max_seq_lens
+
+ else:
+ query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
+ key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
+ value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
@@ -299,3 +315,24 @@ def _flash_attention_forward(
)
return attn_output
+
+
+class FlashAttentionKwargs(TypedDict, total=False):
+ """
+ Keyword arguments for Flash Attention with Compile.
+
+ Attributes:
+ cu_seq_lens_q (`torch.LongTensor`, *optional*)
+ Gets cumlative sequence length for query state.
+ cu_seq_lens_k (`torch.LongTensor`, *optional*)
+ Gets cumlative sequence length for key state.
+ max_length_q (`int`, *optional*):
+ Maximum sequence length for query state.
+ max_length_k (`int`, *optional*):
+ Maximum sequence length for key state.
+ """
+
+ cu_seq_lens_q: Optional[torch.LongTensor]
+ cu_seq_lens_k: Optional[torch.LongTensor]
+ max_length_q: Optional[int]
+ max_length_k: Optional[int]
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 9aa588be4310..b215fb6561bf 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -33,12 +33,14 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@@ -832,6 +834,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -913,6 +916,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index aad4da282b78..6354e20e33fe 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -38,6 +38,7 @@
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
+ add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
@@ -51,7 +52,11 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
+from ...processing_utils import Unpack
+
+
+_CHECKPOINT_FOR_DOC = "dummy"
class GlmRMSNorm(nn.Module):
@@ -736,6 +741,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -817,6 +823,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@@ -1222,6 +1229,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index 55bf89d1c56b..c26477fdc173 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -46,6 +46,8 @@
logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "dummy"
+
class GlmRMSNorm(Phi3RMSNorm):
pass
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 617ef38e4ae3..4d95f01849d6 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -39,8 +39,10 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
+ LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -422,6 +424,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -506,6 +509,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ **kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -870,6 +874,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -951,6 +956,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@@ -1102,6 +1108,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@@ -1148,7 +1157,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
- **loss_kwargs,
+ **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1198,6 +1207,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
+ **kwargs,
)
hidden_states = outputs[0]
@@ -1211,7 +1221,7 @@ def forward(
loss = None
if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 16c05a14028e..4f3187d510fa 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -815,7 +815,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
- self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
+ self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index a781389c2fbd..2a10bcaa3c94 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -37,6 +37,7 @@
from .generic import (
ContextManagers,
ExplicitEnum,
+ LossKwargs,
ModelOutput,
PaddingStrategy,
TensorType,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index a5f01fa2e0df..26ec82b20fd4 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -24,7 +24,7 @@
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
-from typing import Any, ContextManager, Iterable, List, Optional, Tuple
+from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
import numpy as np
from packaging import version
@@ -854,3 +854,16 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator
+
+
+class LossKwargs(TypedDict, total=False):
+ """
+ Keyword arguments to be passed to the loss function
+
+ Attributes:
+ num_items_in_batch (`int`, *optional*):
+ Number of items in the batch. It is recommended to pass it when
+ you are doing gradient accumulation.
+ """
+
+ num_items_in_batch: Optional[int]