import os
import types
from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
    from ..hparams import TrainingArguments


logger = logging.get_logger(__name__)


def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
    """Create FP8 recipe kwargs for training with HuggingFace Accelerate.

    Args:
        training_args: Training arguments containing the FP8 configuration.

    Returns:
        A single-element list of recipe kwargs (FP8RecipeKwargs for the TE backend,
        AORecipeKwargs otherwise) if FP8 is enabled and supported, an empty list otherwise.
    """
    if not training_args.fp8:
        return []

    backend = getattr(training_args, "fp8_backend", "auto")
    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")

    try:
        # Transformer Engine backend: delegate FP8 scaling entirely to TE.
        if backend == "te":
            from accelerate.utils import FP8RecipeKwargs

            logger.info_rank0("Using Transformer Engine FP8 backend")
            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]

        # Default path: TorchAO backend via Accelerate's AORecipeKwargs.
        from accelerate.utils import AORecipeKwargs

        config = None
        if backend in ("torchao", "auto"):
            from torchao.float8 import Float8LinearConfig

            # Rowwise scaling offers a good accuracy/performance trade-off.
            config = Float8LinearConfig.from_recipe_name("rowwise")

            # Enable amax initialization and pre/post-forward hooks when the
            # installed torchao version exposes these options.
            if hasattr(config, "enable_amax_init"):
                config.enable_amax_init = True
            if hasattr(config, "enable_pre_and_post_forward"):
                config.enable_pre_and_post_forward = True

        def module_filter_func(module, layer_name):
            # Skip embedding, output, and classifier layers, which are the most
            # sensitive to reduced precision.
            skip_layers = ["embed", "lm_head", "output", "classifier"]
            if any(skip_name in layer_name.lower() for skip_name in skip_layers):
                return False

            # Only convert modules with a 2D weight matrix (i.e., linear layers).
            if not (hasattr(module, "weight") and len(module.weight.shape) == 2):
                return False

            weight = module.weight
            in_features, out_features = weight.shape[1], weight.shape[0]

            # FP8 kernels require both dimensions to be divisible by 16.
            if in_features % 16 != 0 or out_features % 16 != 0:
                logger.debug(
                    f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
                )
                return False

            return True

        if (
            hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
            and training_args.fp8_enable_fsdp_float8_all_gather
        ):
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
    except Exception as e:
        logger.info_rank0(f"Failed to create FP8 configuration: {e}")
        return []
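

# Example (a sketch, not part of this module's API): the handlers returned by
# create_fp8_kwargs() are meant to be forwarded to accelerate.Accelerator via
# kwargs_handlers; constructing the Accelerator directly here is an assumption
# about the caller.
#
#     from accelerate import Accelerator
#
#     handlers = create_fp8_kwargs(training_args)
#     accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=handlers)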


def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
    """Get the mixed precision setting for Accelerate when using FP8.

    Args:
        training_args: Training arguments containing the FP8 configuration.

    Returns:
        "fp8" if FP8 is enabled, None otherwise.
    """
    return "fp8" if training_args.fp8 else None
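

# Caller-side sketch (the "bf16" fallback is an assumption, not something this
# module decides):
#
#     mp = get_fp8_mixed_precision(training_args)
#     accelerator = Accelerator(mixed_precision=mp or "bf16")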


def configure_fp8_environment(training_args: "TrainingArguments") -> None:
    """Configure the FP8 environment for HuggingFace Accelerate.

    FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
    DeepSpeed or FSDP is used for distributed training. This function sets up the environment
    variables and validates the FP8 configuration.

    Args:
        training_args: Training arguments containing the FP8 configuration.
    """
    if not training_args.fp8:
        return

    # Tell Accelerate to run in FP8 mixed precision.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Propagate an explicit backend choice to the environment.
    backend = getattr(training_args, "fp8_backend", "auto")
    if backend != "auto":
        os.environ["FP8_BACKEND"] = backend
        logger.info_rank0(f"Set FP8_BACKEND={backend}")

    # Build the recipe kwargs once to validate the configuration early.
    fp8_kwargs = create_fp8_kwargs(training_args)
    logger.info_rank0(f"FP8 recipe kwargs created: {len(fp8_kwargs)} items")

    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
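

# Intended call order (a sketch inferred from this module, not a documented
# contract): configure the environment before the Trainer builds its
# Accelerator, then verify once the model has been prepared.
#
#     configure_fp8_environment(training_args)  # sets ACCELERATE_MIXED_PRECISION=fp8
#     trainer = Trainer(model=model, args=training_args, ...)
#     verify_fp8_status(trainer.accelerator, training_args)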


def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
    """Verify that FP8 training is actually working after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance.
        training_args: Training arguments containing the FP8 configuration.
    """
    if not training_args.fp8:
        return

    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

    backend = getattr(training_args, "fp8_backend", "auto")
    if backend in ("torchao", "auto"):
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
            "ensure model layer dimensions are mostly divisible by 16. "
            "If you encounter issues, try fp8_backend='te' with Transformer Engine."
        )
    else:
        logger.info_rank0(f"FP8 training enabled with {backend} backend.")

    logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")

    if not fp8_enabled:
        logger.info_rank0(
            "WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working."
        )


def patch_accelerator_for_fp8() -> None:
    """Patch Accelerator to inject FP8 recipe kwargs.

    This is needed because the HuggingFace Trainer doesn't pass kwargs_handlers to the
    Accelerator. We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force
    mixed_precision="fp8".
    """
    import transformer_engine.pytorch as te
    from accelerate import Accelerator

    # Apply the patch only once per process.
    if getattr(Accelerator, "_te_fp8_patched", False):
        return

    # Some Transformer Engine versions do not expose the `fp8` submodule that
    # downstream checks expect; provide a minimal stub so they fail gracefully.
    if not hasattr(te, "fp8"):
        te.fp8 = types.ModuleType("fp8")
        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")

    # Newer Accelerate versions use TERecipeKwargs; fall back to the deprecated
    # FP8RecipeKwargs (which takes an explicit backend) on older versions.
    try:
        from accelerate.utils import TERecipeKwargs as FP8Recipe

        use_te_recipe = True
    except ImportError:
        from accelerate.utils import FP8RecipeKwargs as FP8Recipe

        use_te_recipe = False

    original_init = Accelerator.__init__

    def patched_init(self, *args, **kwargs):
        # Only inject the FP8 recipe when the caller supplied no handlers.
        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
            if use_te_recipe:
                kwargs["kwargs_handlers"] = [
                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
                ]
            else:
                kwargs["kwargs_handlers"] = [
                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
                ]

        kwargs["mixed_precision"] = "fp8"
        return original_init(self, *args, **kwargs)

    Accelerator.__init__ = patched_init
    Accelerator._te_fp8_patched = True
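

# Usage sketch (assumed; the Trainer creates its Accelerator internally, so the
# patch must be applied before the Trainer is constructed):
#
#     patch_accelerator_for_fp8()
#     trainer = Trainer(model=model, args=training_args, train_dataset=dataset)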