# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

"""FP8 (float8_e4m3fn) linear layers and a helper to swap them into a model."""

import re

import torch

from dinov3.layers.attention import LinearKMaskedBias
from dinov3.utils import named_replace

# avoid division by zero when calculating scale
EPS = 1e-12


def scale(t, amax_t):
    """Quantize ``t`` to ``float8_e4m3fn`` given its absolute-max value(s).

    Args:
        t: Tensor to quantize.
        amax_t: Absolute maximum of ``t``; must broadcast against ``t``
            (callers in this file pass either a per-row ``(rows, 1)`` tensor
            or a per-tensor value repeated to that shape).

    Returns:
        A tuple ``(t_fp8, scale_t)`` where ``t_fp8`` is ``t / scale_t`` cast
        to ``float8_e4m3fn`` and ``scale_t`` is the float32 scale such that
        ``t`` is approximately ``t_fp8 * scale_t``.
    """
    max_v = torch.finfo(torch.float8_e4m3fn).max
    # Clamp so that an all-zero row does not produce a zero scale (and thus a
    # division by zero below).
    scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
    t_fp8 = (t / scale_t).to(torch.float8_e4m3fn)
    return t_fp8, scale_t


def matmul(first, amax_first, second_t, amax_second_t, bias):
    """Compute ``first @ second_t.T (+ bias)`` with an FP8 matmul.

    Both operands are quantized with :func:`scale`, multiplied with
    ``torch._scaled_mm`` using unit scales, and the real scales are applied to
    the result afterwards.

    Args:
        first: Left operand.
        amax_first: Absolute-max value(s) used to quantize ``first``.
        second_t: Right operand, stored transposed (it is transposed again
            before being handed to the kernel).
        amax_second_t: Absolute-max value(s) used to quantize ``second_t``.
        bias: Optional tensor added to the rescaled output, or ``None``.

    Returns:
        The product, cast to ``torch.bfloat16`` before the optional bias is
        added.
    """
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)
    # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite
    # slow. Hence we fall back to an "unscaled" matmul, which uses cuBLAS, and
    # apply the scale manually afterwards.
    output = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_first.new_ones((1, 1)),
        scale_b=scale_second_t.t().new_ones((1, 1)),
        bias=None,
        out_dtype=torch.bfloat16,
        use_fast_accum=False,
    )
    output = (output * scale_first * scale_second_t.t()).to(torch.bfloat16)
    if bias is not None:
        output = output + bias
    return output


@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer whose matmuls run in FP8.

    The forward pass and the input-gradient matmul go through :func:`matmul`;
    the weight gradient is computed with a regular ``@``.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Return ``a @ b_t.T (+ bias)``.

        Args:
            ctx: Autograd context.
            a: Input; callers in this file pass it flattened to 2-D.
            b_t: Weight, laid out like ``torch.nn.Linear.weight``.
            bias: Optional bias tensor, or ``None``.
        """
        # Row-wise absolute maxima: one scale per row of each operand.
        amax_a = a.abs().amax(dim=-1, keepdim=True)
        amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        out = matmul(a, amax_a, b_t, amax_b_t, bias)

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False

        # Only the tensor-wide maximum of the weight is kept: the backward
        # pass quantizes the transposed weight, for which the row-wise maxima
        # computed above are not the right ones.
        ctx.save_for_backward(a, b_t, amax_b_t.max())

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(a, b_t, bias)``; ``None`` where not needed."""
        a, b_t, amax_b = ctx.saved_tensors

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
            # Expand the per-tensor maximum to one (identical) value per row
            # of ``b`` so that it has the layout ``matmul`` expects.
            amax_b = amax_b.repeat(b.shape[0], 1)
            grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None)
        else:
            grad_a = None

        if ctx.b_requires_grad:
            grad_b = grad_out.t() @ a
        else:
            grad_b = None

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias


class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward pass goes through :class:`Fp8LinearFn`."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Collapse all leading dims into one for the 2-D matmul, then restore
        # them on the output.
        out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
        out = out.unflatten(0, input.shape[:-1])
        return out


class Fp8LinearKMaskedBias(LinearKMaskedBias):
    """``LinearKMaskedBias`` whose forward pass goes through :class:`Fp8LinearFn`.

    The bias is multiplied by ``self.bias_mask`` before being applied.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        masked_bias = self.bias * self.bias_mask if self.bias is not None else None
        out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, masked_bias)
        out = out.unflatten(0, input.shape[:-1])
        return out


def convert_linears_to_fp8(root_module: torch.nn.Module, *, filter: str) -> torch.nn.Module:
    """Replace matching linear layers of ``root_module`` with FP8 variants.

    Every ``torch.nn.Linear`` (or ``LinearKMaskedBias``) whose name matches
    ``filter`` is swapped for :class:`Fp8Linear` (or
    :class:`Fp8LinearKMaskedBias`), reusing the original parameters.

    Args:
        root_module: Module tree to convert.
        filter: Regular expression searched (``re.search``) in each module
            name; only matching linear layers are converted.

    Returns:
        Whatever ``named_replace`` returns for ``root_module``.

    Raises:
        AssertionError: If a matching module is a ``torch.nn.Linear`` subclass
            other than the two supported types, or if no layer was converted.
        RuntimeError: If a matching layer has ``in_features`` or
            ``out_features`` that is not a multiple of 64.
    """
    filter_re = re.compile(filter)
    total_count = 0

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        nonlocal total_count
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module

        # Exact type match on purpose: an unknown subclass may override
        # forward() in ways the FP8 replacements would not reproduce.
        if type(module) is torch.nn.Linear:
            new_cls = Fp8Linear
        elif type(module) is LinearKMaskedBias:
            new_cls = Fp8LinearKMaskedBias
        else:
            # Raised explicitly rather than via ``assert False`` so that the
            # check survives ``python -O`` (where it would otherwise vanish
            # and leave ``new_cls`` unbound).
            raise AssertionError(str(type(module)))

        if module.in_features % 64 != 0 or module.out_features % 64 != 0:
            # This is not a strict requirement, but H100 TensorCores for fp8
            # operate on tiles of 64 elements anyways, and Inductor sometimes
            # pads inner dims to become multiples of 64. Also, if one day we
            # switch back to cuBLAS, it artificially requires dims to be
            # multiples of 16.
            raise RuntimeError(
                "fp8 requires all dimensions to be multiples of 64 "
                "(consider using ffn_layer=swiglu64 or higher)"
            )

        new_module = new_cls(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share (do not copy) the original parameters.
        new_module.weight = module.weight
        new_module.bias = module.bias
        # Carry over the bias mask as well when the source module has one, so
        # that converting an already-initialized module keeps its state
        # instead of whatever the freshly constructed module starts with.
        # NOTE(review): LinearKMaskedBias is defined elsewhere; this assumes
        # ``bias_mask`` is a plain tensor attribute/buffer -- confirm.
        if isinstance(module, LinearKMaskedBias) and hasattr(module, "bias_mask"):
            new_module.bias_mask = module.bias_mask

        total_count += 1
        return new_module

    out = named_replace(replace, root_module)

    # Explicit check instead of ``assert`` so that a filter matching nothing
    # is still reported under ``python -O``.
    if total_count == 0:
        raise AssertionError("fp8: no layer found to convert")

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return out