# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# [Note] Getting the 'torchao' package:
# This script requires the 'torchao' package to function correctly.
# Please ensure you have this package installed from the appropriate repository.
# You can obtain it from https://github.com/pytorch/ao by following the
# installation instructions.
# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance.
import torch
import torch.nn as nn
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import (
ModelConverter,
register_model_converter,
)
from torchtitan.tools.logging import logger
def _is_sm89_or_later():
    # Float8 requires CUDA compute capability SM89 or later
    # (Ada- and Hopper-generation GPUs, e.g. RTX 4090 / L40S / H100).
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
class Float8Converter(ModelConverter):
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False
float8_config = job_config.float8
if not _is_sm89_or_later():
logger.warning(
"Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
)
return
try:
from torchao.float8 import Float8LinearConfig
except ImportError as e:
raise ImportError(
"torchao is not installed. Please install it to use float8 linear layers."
) from e
if float8_config.recipe_name is not None and not hasattr(
Float8LinearConfig, "from_recipe_name"
):
            logger.warning(
                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
                "is too old; please install torchao v0.9.0 or later and try again",
            )
return
self.enabled = True
self.filter_fqns = float8_config.filter_fqns
if float8_config.recipe_name is not None:
assert (
not float8_config.enable_fsdp_float8_all_gather
), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
assert (
not float8_config.force_recompute_fp8_weight_in_bwd
), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
self.precompute_scale = False
logger.info(
f"Float8 training active with recipe {float8_config.recipe_name}"
)
else:
            # Tensorwise scaling path: build the config here; the actual swap of
            # nn.Linear for Float8Linear happens later in convert_to_float8_training.
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
)
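            # Per torchao's Float8LinearConfig, force_recompute_fp8_weight_in_bwd
            # recomputes the fp8 weight cast in the backward pass instead of saving
            # it from the forward, trading compute for activation memory.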
# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)
logger.info("Float8 tensorwise scaled training active")
def convert(self, model: nn.Module):
return self.convert_to_float8_training(model)
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
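        # Called once per optimizer step; a no-op unless FSDP float8 all-gather
        # with dynamic-scale precompute is enabled (see self.precompute_scale).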
return self.precompute_float8_dynamic_scale_for_fsdp(model)
    def convert_to_float8_training(self, model: nn.Module):
        """
        This function converts the linear layers of `model` to `Float8Linear`.
        Scaling follows `self.config`: dynamic tensorwise scaling by default,
        or the recipe configured in __init__. This will mutate the model in place.
        """
if not self.enabled:
return
from torchao.float8 import convert_to_float8_training
        # Mutates the model in place, replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=self._module_filter_fn,
)
logger.info(
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
f"{self.config.enable_fsdp_float8_all_gather}"
)
def _module_filter_fn(self, mod: nn.Module, fqn: str) -> bool:
if not isinstance(mod, nn.Linear):
return False
# All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
dims_multiples_of_16 = (
mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
)
# If the fqn matches any filtered fqn, then we should not convert this module.
is_filtered_fqn = any(filtered_fqn in fqn for filtered_fqn in self.filter_fqns)
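        # Example: with filter_fqns=["output"], fqns "output" and "layers.3.output"
        # are both skipped, since matching is by substring containment.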
return dims_multiples_of_16 and not is_filtered_fqn
def precompute_float8_dynamic_scale_for_fsdp(
self, model: nn.Module | list[nn.Module]
):
if not self.enabled:
return
if not self.precompute_scale:
return
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
models = [model] if isinstance(model, nn.Module) else model
for m in models:
precompute_float8_dynamic_scale_for_fsdp(m)
register_model_converter(Float8Converter, "float8")
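# Usage sketch (illustrative only, assuming a torchtitan run where `job_config`,
# `parallel_dims`, and `model` are already built; the converter is selected from
# the CLI/TOML via `--model.converters "float8"` per torchtitan's float8 docs):
#
#   converter = Float8Converter(job_config, parallel_dims)
#   converter.convert(model)              # swaps eligible nn.Linear -> Float8Linear
#   # ... forward/backward and optimizer.step() ...
#   converter.post_optimizer_hook(model)  # precompute fp8 scales for FSDP all-gather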