Buckets:
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import re | |
| import warnings | |
| from typing import Callable | |
| import torch | |
| from torch.distributed.tensor import DTensor, Partial, Shard | |
# Lower bound applied to amax values before they are turned into scales, to
# avoid division by zero when calculating scale (e.g. for an all-zero row).
EPS = 1e-12
| def get_splitk(t): | |
| # When tensor parallelism splits the operands along the reduction dim, it's | |
| # more natural (and efficient, and accurate) to do sub-row-wise scaling, so | |
| # that each rank can compute its own scales independently. | |
| if isinstance(t, DTensor) and t.placements == (Shard(dim=1),): | |
| return t.device_mesh.size() | |
| else: | |
| return 1 | |
def mul_tiled(a, *bs):
    """Multiply ``a`` tile-wise by each tensor in ``bs``, in order.

    For a factor of shape ``m x n``, ``a`` is viewed as an ``m x n`` grid of
    equally sized chunks (dim 0 split into ``m`` groups, the last dim split
    into ``n`` groups) and every chunk is multiplied by the matching element
    of the factor. The result has the same shape as ``a``. With no factors,
    ``a`` is returned unchanged.
    """
    for factor in bs:
        row_tiles, col_tiles = factor.shape[0], factor.shape[-1]
        # Expose the tile grid as separate dims: (m, rows/m, n, cols/n).
        tiled = a.unflatten(0, (row_tiles, -1)).unflatten(-1, (col_tiles, -1))
        # Broadcast one factor element over each tile.
        tiled = tiled * factor[:, None, :, None]
        # Collapse back to the original layout.
        a = tiled.flatten(end_dim=1).flatten(start_dim=-2)
    return a
| def apply_to_partial(fn, t, *args, **kwargs): | |
| # With tensor parallelism, _scaled_mm returns a "partial" result, but we do | |
| # manual (post-)scaling which we want to apply to each partial term | |
| # separately, thus we do this hack to "unpack" the DTensors. | |
| if isinstance(t, DTensor) and t.placements == (Partial(),): | |
| return torch.distributed.tensor.experimental.local_map(fn, [*t.placements])(t, *args, **kwargs) | |
| else: | |
| return fn(t, *args, **kwargs) | |
def scale(t, amax_t):
    """Quantize ``t`` to ``float8_e4m3fn`` using tile-wise scales derived from ``amax_t``.

    Args:
        t: tensor to quantize.
        amax_t: per-tile absolute maxima of ``t``; its shape defines the tile
            grid used by ``mul_tiled``.

    Returns:
        ``(t_fp8, scale_t)`` where ``scale_t`` is the float32 scale per tile
        (``amax / fp8_max``, with ``amax`` clamped to at least ``EPS``) and
        ``t_fp8`` is ``t`` multiplied tile-wise by ``1 / scale_t`` and cast to
        ``float8_e4m3fn``.
    """
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    # Clamp so that a zero amax does not produce a zero scale.
    scale_t = torch.clamp(amax_t.float(), min=EPS) / fp8_max
    inv_scale = scale_t.reciprocal()
    t_fp8 = mul_tiled(t, inv_scale).to(fp8_dtype)
    return t_fp8, scale_t
def matmul(first, amax_first, second_t, amax_second_t, bias, use_fast_accum):
    """Compute ``first @ second_t.T`` in FP8 and return a bfloat16 result.

    Both operands are quantized with ``scale`` and multiplied through
    ``torch._scaled_mm``.

    Args:
        first: left operand.
        amax_first: absolute maxima used to scale ``first``.
        second_t: right operand, stored transposed.
        amax_second_t: absolute maxima used to scale ``second_t``.
        bias: optional bias added to the product, or ``None``.
        use_fast_accum: forwarded to ``torch._scaled_mm``. When false, the
            scales and the bias are applied manually after the matmul.

    Returns:
        The product as a ``torch.bfloat16`` tensor.
    """
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)

    if use_fast_accum:
        # The kernel applies the scales and the bias itself.
        post_scales = []
        post_bias = None
    else:
        # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite
        # slow when fast_accum is disabled. Hence we fall back to an "unscaled"
        # matmul, which uses cuBLAS, and apply the scale manually afterwards.
        post_scales = [scale_first, scale_second_t.t()]
        scale_first = scale_first.new_ones((1, 1))
        scale_second_t = scale_second_t.t().new_ones((1, 1))
        post_bias, bias = bias, None

    res = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_first,
        scale_b=scale_second_t.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=use_fast_accum,
    )

    # With no post-scales this only re-casts the result to bfloat16.
    rescaled = apply_to_partial(mul_tiled, res, *post_scales)
    res = rescaled.to(torch.bfloat16)
    if post_bias is not None:
        res += post_bias
    return res
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer whose matmuls run in FP8.

    The forward product and the input gradient go through the FP8 ``matmul``
    helper; the weight gradient is computed with a regular matmul.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Compute ``a @ b_t.T (+ bias)`` with FP8 operands.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            a: 2-D input (``Fp8Linear`` passes the input with its leading
                dims flattened).
            b_t: weight, stored transposed (``Fp8Linear`` passes
                ``self.weight``).
            bias: optional bias, or ``None``.

        Returns:
            The bfloat16 output of ``matmul``.
        """
        # Absolute max per row, or per sub-row group when the reduction dim
        # is sharded (see get_splitk).
        amax_a = a.abs().unflatten(-1, (get_splitk(a), -1)).amax(dim=-1)
        amax_b_t = b_t.abs().unflatten(-1, (get_splitk(b_t), -1)).amax(dim=-1)
        out = matmul(a, amax_a, b_t, amax_b_t, bias, use_fast_accum=True)

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False

        ctx.save_for_backward(a, b_t, amax_b_t)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``(grad_a, grad_b, grad_bias)``; entries not requiring grad are ``None``."""
        a, b_t, amax_b_t = ctx.saved_tensors

        # Workaround for https://github.com/pytorch/pytorch/issues/141881.
        # The partitioner would pre-compute the transposed scaling of the weight
        # in the forward (as it's most efficient, but it actually uses too much
        # memory). We prevent that by making the scaling depend on the gradient
        # in a way that has no effect and will be optimized away later.
        # Care is needed to support tensor parallelism and circumvent bugs.
        b_t = b_t + grad_out[:1, :, None].squeeze(0) * 0

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = (
                grad_out.abs().unflatten(-1, (get_splitk(grad_out), -1)).amax(dim=-1)
            )
            # Derive the amax of the transposed weight from the amax saved in
            # the forward instead of re-reducing the full weight, then repeat
            # it so that there is one row of values per row of ``b``.
            amax_b = amax_b_t.t().unflatten(-1, (get_splitk(b), -1)).amax(dim=-1)
            amax_b = amax_b.repeat_interleave(
                b.shape[0] // amax_b.shape[0], dim=0, output_size=b.shape[0]
            )
            grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None, use_fast_accum=False)
        else:
            grad_a = None

        if ctx.b_requires_grad:
            # The weight gradient uses a regular (non-FP8) matmul.
            grad_b = grad_out.t() @ a
        else:
            grad_b = None

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias
class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` variant whose forward pass goes through ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``input``, preserving all of its leading dims."""
        leading_shape = input.shape[:-1]
        # Fp8LinearFn works on a 2-D input: collapse every leading dim.
        flat_input = input.flatten(end_dim=-2)
        flat_out = Fp8LinearFn.apply(flat_input, self.weight, self.bias)
        # Restore the original leading dims on the output.
        return flat_out.unflatten(0, leading_shape)
def named_replace(fn: Callable[[torch.nn.Module, str], torch.nn.Module], module: torch.nn.Module, name="") -> torch.nn.Module:
    """Rebuild a module tree bottom-up by passing every module through ``fn``.

    Children are processed before their parent; each child is re-assigned on
    its parent with whatever ``fn`` returned for it.

    Args:
        fn: called as ``fn(module, qualified_name)``; returns the module to
            keep in that position (possibly the same object).
        module: root of the (sub)tree to process.
        name: dotted name of ``module`` relative to the original root; empty
            for the root itself.

    Returns:
        The result of ``fn`` applied to ``module`` after its children have
        been processed.
    """
    prefix = f"{name}." if name else ""
    # Snapshot the children since we re-assign attributes while iterating.
    for child_name, child in list(module.named_children()):
        replaced = named_replace(fn, child, prefix + child_name)
        setattr(module, child_name, replaced)
    return fn(module, name)
def convert_linears_to_fp8(root_module: torch.nn.Module, recipe: str, filter: str) -> torch.nn.Module:
    """Replace the matching ``torch.nn.Linear`` modules under ``root_module`` with FP8 ones.

    Args:
        root_module: root of the module tree to convert (modified in place).
        recipe: float8 recipe to use; only ``"rowwise"`` is supported.
        filter: regular expression searched against each module's dotted
            name; only ``Linear`` modules whose name matches are converted.

    Returns:
        The converted root module.

    Raises:
        RuntimeError: if ``recipe`` is not a known recipe.
        AssertionError: if a matching module is a subclass of
            ``torch.nn.Linear`` rather than exactly ``torch.nn.Linear``.

    Side effects:
        Sets ``torch._inductor.config.triton.multi_kernel`` and resets the
        Dynamo code caches and CUDA graph trees.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        # Explicit raises instead of ``assert False`` so the checks are not
        # stripped when Python runs with -O.
        if type(module) is not torch.nn.Linear:
            raise AssertionError(str(type(module)))
        if recipe != "rowwise":
            raise AssertionError(recipe)
        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters rather than the freshly created ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()

    return out
| # We need some upstream PyTorch fixes which are only present in v2.7+ or in | |
| # nightlies starting from January 7, 2025. For earlier versions, we copy-pasted | |
| # the relevant pieces of code below. | |
| if torch.__version__ < "2.7.0.dev20250107": | |
| from torch.distributed.device_mesh import DeviceMesh | |
| from torch.distributed.tensor._dtensor_spec import DTensorSpec | |
| from torch.distributed.tensor._op_schema import ( | |
| OpSchema, | |
| OpStrategy, | |
| PlacementStrategy, | |
| RuntimeSchemaInfo, | |
| ) | |
| from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies | |
| from torch.distributed.tensor._ops._math_ops import ( | |
| _infer_reduction_dims, | |
| common_reduction_strategy, | |
| ) | |
| from torch.distributed.tensor._ops.utils import ( | |
| generate_redistribute_costs, | |
| is_tensor_shardable, | |
| prod, | |
| register_op_strategy, | |
| ) | |
| from torch.distributed.tensor.placement_types import Replicate | |
| # Cherry-pick of https://github.com/pytorch/pytorch/pull/143747 | |
| LINEAR_REDUCTION_OP_MAP = { | |
| torch.ops.aten.amax.default: "max", | |
| torch.ops.aten.amin.default: "min", | |
| } | |
| def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: | |
| args_schema = op_schema.args_schema | |
| input_strategy = args_schema[0] | |
| assert isinstance(input_strategy, OpStrategy) | |
| dims = None | |
| if len(op_schema.args_schema) > 1: | |
| dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) | |
| reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims | |
| keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) | |
| reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] | |
| return common_reduction_strategy( | |
| mesh, | |
| input_strategy, | |
| reduce_dims, | |
| keep_dim=keep_dim, | |
| reduction_linear=True, | |
| reduction_op=reduction_op, | |
| ) | |
| # Cherry-pick of https://github.com/pytorch/pytorch/pull/143760 | |
| def _mm_like_strategy( | |
| mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema | |
| ) -> OpStrategy: | |
| ( | |
| self_strategy, | |
| mat2_strategy, | |
| scale_self_strategy, | |
| scale_mat2_strategy, | |
| bias_strategy, | |
| scale_result_strategy, | |
| *_, | |
| ) = op_schema.args_schema | |
| assert isinstance(self_strategy, OpStrategy) | |
| assert isinstance(mat2_strategy, OpStrategy) | |
| assert isinstance(scale_self_strategy, OpStrategy) | |
| assert isinstance(scale_mat2_strategy, OpStrategy) | |
| assert bias_strategy is None | |
| assert scale_result_strategy is None | |
| # generate all possible strategies for mm | |
| mm_strategy = gen_einsum_strategies(mm_equation, mesh) | |
| assert isinstance(mm_strategy, OpStrategy) | |
| # filter out invalid strategies and associate costs | |
| strategies = mm_strategy.strategies | |
| filtered_strategies = [] | |
| for strtg in strategies: | |
| assert isinstance(strtg, PlacementStrategy) | |
| assert strtg.input_specs is not None | |
| self_spec = strtg.input_specs[0] | |
| mat2_spec = strtg.input_specs[1] | |
| assert isinstance(self_spec, DTensorSpec) | |
| assert isinstance(mat2_spec, DTensorSpec) | |
| scale_self_spec = ( | |
| DTensorSpec(self_spec.mesh, (Replicate(),)) | |
| if prod(scale_self_strategy.shape) == 1 | |
| else self_spec | |
| ) | |
| scale_mat2_spec = ( | |
| DTensorSpec(mat2_spec.mesh, (Replicate(),)) | |
| if prod(scale_mat2_strategy.shape) == 1 | |
| else mat2_spec | |
| ) | |
| strtg.input_specs.extend([scale_self_spec, scale_mat2_spec]) | |
| if ( | |
| is_tensor_shardable(self_strategy.shape, self_spec) | |
| and is_tensor_shardable(mat2_strategy.shape, mat2_spec) | |
| and is_tensor_shardable(scale_self_strategy.shape, scale_self_spec) | |
| and is_tensor_shardable(scale_mat2_strategy.shape, scale_mat2_spec) | |
| ): | |
| redistribute_cost = [ | |
| generate_redistribute_costs(self_strategy, self_spec), | |
| generate_redistribute_costs(mat2_strategy, mat2_spec), | |
| generate_redistribute_costs(scale_self_strategy, scale_self_spec), | |
| generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec), | |
| ] | |
| strtg.redistribute_cost = redistribute_cost | |
| filtered_strategies.append(strtg) | |
| mm_strategy.strategies = filtered_strategies | |
| return mm_strategy | |
| def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: | |
| return _mm_like_strategy("mk,kn->mn", mesh, op_schema) | |
Xet Storage Details
- Size:
- 12.3 kB
- Xet hash:
- b3c9560817e879098c94d96694a16307603ffbb27b8690b3430904301cebbe39
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.