import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import (
    ModelConverter,
    register_model_converter,
)
from torchtitan.tools.logging import logger


def _is_sm89_or_later():
    # float8 is only supported on SM89-or-later GPUs (compute capability >= 8.9,
    # e.g. Ada / Hopper).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Converter(ModelConverter):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.enabled = False

        float8_config = job_config.float8
        if not _is_sm89_or_later():
            logger.warning(
                "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
            )
            return
        try:
            from torchao.float8 import Float8LinearConfig
        except ImportError as e:
            raise ImportError(
                "torchao is not installed. Please install it to use float8 linear layers."
            ) from e

        if float8_config.recipe_name is not None and not hasattr(
            Float8LinearConfig, "from_recipe_name"
        ):
            logger.warning(
                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
                "is too old; please install torchao v0.9.0 or later and try again",
            )
            return

        self.enabled = True
        self.filter_fqns = float8_config.filter_fqns

        if float8_config.recipe_name is not None:
            assert (
                not float8_config.enable_fsdp_float8_all_gather
            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
            assert (
                not float8_config.force_recompute_fp8_weight_in_bwd
            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
            self.precompute_scale = False
            logger.info(
                f"Float8 training active with recipe {float8_config.recipe_name}"
            )

        else:
            # Default tensorwise dynamic scaling; only enable float8 all-gather
            # when FSDP weight sharding is actually in use.
            enable_fsdp_float8_all_gather = (
                parallel_dims.dp_shard_enabled
                and float8_config.enable_fsdp_float8_all_gather
            )
            self.config = Float8LinearConfig(
                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
            )
            # Whether to precompute dynamic scales after each optimizer step;
            # see precompute_float8_dynamic_scale_for_fsdp below.
            self.precompute_scale = (
                enable_fsdp_float8_all_gather
                and float8_config.precompute_float8_dynamic_scale_for_fsdp
            )
            logger.info("Float8 tensorwise scaled training active")

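    # A sketch of the two configuration paths handled above, written as the
    # `[float8]` fields this module reads from `job_config.float8` (the TOML
    # layout itself is an assumption about the surrounding config system):
    #
    #   [float8]
    #   recipe_name = "rowwise"  # recipe path, requires torchao >= 0.9.0
    #
    #   [float8]  # or the default tensorwise path:
    #   enable_fsdp_float8_all_gather = true
    #   precompute_float8_dynamic_scale_for_fsdp = true
    #   filter_fqns = ["output"]
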
    def convert(self, model: nn.Module):
        return self.convert_to_float8_training(model)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
        return self.precompute_float8_dynamic_scale_for_fsdp(model)

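    # Presumably the training loop invokes `post_optimizer_hook` after each
    # `optimizer.step()`, so that dynamic float8 scales are recomputed from
    # the freshly updated weights before the next forward pass.
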
    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Note that today, only dynamic tensor scaling (the default) is supported.
        This will mutate the model in place.
        """
        if not self.enabled:
            return

        from torchao.float8 import convert_to_float8_training

        # Mutates the model in place, replacing instances of nn.Linear with
        # Float8Linear (subject to _module_filter_fn below).
        convert_to_float8_training(
            model,
            config=self.config,
            module_filter_fn=self._module_filter_fn,
        )
        logger.info(
            "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
            f"{self.config.enable_fsdp_float8_all_gather}"
        )

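    # For reference, a minimal standalone sketch of the underlying torchao
    # call, using only the API imported above (`TwoLayerNet` is a hypothetical
    # model):
    #
    #   from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    #
    #   model = TwoLayerNet()
    #   convert_to_float8_training(
    #       model,
    #       config=Float8LinearConfig(),
    #       module_filter_fn=lambda mod, fqn: isinstance(mod, nn.Linear),
    #   )
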
    def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
        if not isinstance(mod, nn.Linear):
            return False

        # Both weight dims must be divisible by 16 to satisfy float8 hardware
        # (tensor core) requirements.
        dims_multiples_of_16 = (
            mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
        )

        # Skip any module whose fully qualified name contains a filtered fqn.
        is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)

        return dims_multiples_of_16 and not is_filtered_fqn

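    # Example: with `filter_fqns = ["output"]`, a module named "model.output"
    # is left as nn.Linear, since filtering matches by substring on the fqn.
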
    def precompute_float8_dynamic_scale_for_fsdp(
        self, model: nn.Module | list[nn.Module]
    ):
        if not self.enabled:
            return

        if not self.precompute_scale:
            return

        from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

        models = [model] if isinstance(model, nn.Module) else model
        for m in models:
            precompute_float8_dynamic_scale_for_fsdp(m)


register_model_converter(Float8Converter, "float8")
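
# With the registration above, the converter can be selected by name in the
# job config. A sketch (the exact CLI/TOML spelling depends on the torchtitan
# version):
#
#   [model]
#   converters = ["float8"]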