|
|
|
|
|
|
|
|
|
|
| import re
|
|
|
| import torch
|
|
|
| from dinov3.layers.attention import LinearKMaskedBias
|
| from dinov3.utils import named_replace
|
|
|
|
|
# Lower bound applied to the absolute maximum in `scale`, so an all-zero row
# still yields a strictly positive scale (no division by zero).
EPS: float = 1e-12
|
|
|
|
|
def scale(t, amax_t):
    """Quantize ``t`` to ``float8_e4m3fn`` using the provided absolute maxima.

    The scale is ``max(amax_t, EPS) / fp8_max`` (computed in float32), and the
    quantized tensor is ``t / scale`` cast to fp8, so that ``t`` is
    approximately ``t_fp8 * scale``.

    Args:
        t: Tensor to quantize.
        amax_t: Absolute maxima that must broadcast against ``t``; callers in
            this module pass per-row maxima computed with ``keepdim=True``.

    Returns:
        A ``(t_fp8, scale_t)`` tuple; ``scale_t`` is float32.
    """
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    # Clamp before dividing so all-zero rows do not produce a zero scale.
    scale_t = amax_t.float().clamp(min=EPS).div(fp8_max)
    return t.div(scale_t).to(fp8_dtype), scale_t
|
|
|
|
|
def matmul(first, amax_first, second_t, amax_second_t, bias):
    """Compute ``first @ second_t.t() (+ bias)`` using fp8 operands.

    Both operands are quantized with ``scale`` using the given absolute maxima.
    ``torch._scaled_mm`` is called with unit scales, and the per-row scales are
    multiplied back onto the bfloat16 result afterwards.

    Args:
        first: Left operand, 2-D ``(M, K)``.
        amax_first: Absolute maxima for ``first``; the rescale below
            broadcasts it against the ``(M, N)`` output, so it is expected to
            be a column vector ``(M, 1)``.
        second_t: Right operand given transposed, 2-D ``(N, K)``.
        amax_second_t: Absolute maxima for ``second_t``, expected ``(N, 1)``
            (it is transposed to ``(1, N)`` for the rescale).
        bias: Optional tensor added to the bfloat16 result, or ``None``.

    Returns:
        The ``(M, N)`` product in bfloat16, plus ``bias`` if given. The result
        dtype may be promoted if ``bias`` has a wider dtype.
    """
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)

    # torch._scaled_mm requires a row-major first operand and a column-major
    # second operand (i.e. a row-major ``second_t``). The quantized tensors
    # inherit the strides of their inputs, so a non-contiguous input (e.g. a
    # transposed gradient) would be rejected. This is a no-op when the inputs
    # are already contiguous.
    first_fp8 = first_fp8.contiguous()
    second_t_fp8 = second_t_fp8.contiguous()

    # Run the fp8 kernel with unit scales, then apply the row-wise scales
    # manually on the result.
    output = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_first.new_ones((1, 1)),
        scale_b=scale_second_t.t().new_ones((1, 1)),
        bias=None,
        out_dtype=torch.bfloat16,
        use_fast_accum=False,
    )
    # (M, N) * (M, 1) * (1, N): undo the quantization scaling of both operands.
    output = (output * scale_first * scale_second_t.t()).to(torch.bfloat16)
    if bias is not None:
        output = output + bias
    return output
|
|
|
|
|
@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd op computing ``a @ b_t.t() + bias`` with fp8 matmuls.

    * forward: fp8 ``matmul`` with per-row scales for both ``a`` and ``b_t``.
    * gradient w.r.t. ``a``: fp8 ``matmul`` of ``grad_out`` against the weight,
      with per-row scales for ``grad_out`` and a single global scale (the
      overall max of ``|b_t|``) for every row of the transposed weight.
    * gradient w.r.t. ``b_t``: a regular (non-fp8) matmul ``grad_out.t() @ a``.
    * gradient w.r.t. ``bias``: ``grad_out`` summed over rows.

    ``a`` and ``grad_out`` are treated as 2-D (rows x features).
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        # Per-row absolute maxima (kept as column vectors) drive the fp8 scales.
        amax_rows_a = a.abs().amax(dim=-1, keepdim=True)
        amax_rows_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        result = matmul(a, amax_rows_a, b_t, amax_rows_b_t, bias)

        # Record which inputs need gradients so backward can skip the others.
        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias is not None and bias.requires_grad

        # Only the global max of |b_t| is kept; backward reuses it as the scale
        # for every row of the transposed weight.
        ctx.save_for_backward(a, b_t, amax_rows_b_t.max())
        return result

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t, global_amax_b = ctx.saved_tensors
        grad_a = grad_b = grad_bias = None

        if ctx.a_requires_grad:
            # grad_a = grad_out @ b_t; matmul computes first @ second_t.t(),
            # so hand it the (contiguous) transpose of b_t as ``second_t``.
            b = b_t.t().contiguous()
            grad_a = matmul(
                grad_out,
                grad_out.abs().amax(dim=-1, keepdim=True),
                b,
                global_amax_b.repeat(b.shape[0], 1),
                None,
            )

        if ctx.b_requires_grad:
            grad_b = grad_out.t() @ a

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)

        return grad_a, grad_b, grad_bias
|
|
|
|
|
class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward pass runs through ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Fp8LinearFn works on 2-D inputs: merge all leading dims, then restore them.
        leading_shape = input.shape[:-1]
        flat_out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
        return flat_out.unflatten(0, leading_shape)
|
|
|
|
|
class Fp8LinearKMaskedBias(LinearKMaskedBias):
    """fp8 variant of ``LinearKMaskedBias``.

    The bias (if any) is multiplied element-wise by ``self.bias_mask`` before
    being passed to ``Fp8LinearFn``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.bias is None:
            effective_bias = None
        else:
            effective_bias = self.bias * self.bias_mask
        # Fp8LinearFn works on 2-D inputs: merge all leading dims, then restore them.
        leading_shape = input.shape[:-1]
        flat_out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, effective_bias)
        return flat_out.unflatten(0, leading_shape)
|
|
|
|
|
def convert_linears_to_fp8(root_module: torch.nn.Module, *, filter: str) -> torch.nn.Module:
    """Replace matching linear layers of ``root_module`` with fp8 counterparts.

    Every submodule that is exactly a ``torch.nn.Linear`` or exactly a
    ``LinearKMaskedBias``, and whose name matches the regular expression
    ``filter`` (``re.search`` semantics), is replaced by ``Fp8Linear`` /
    ``Fp8LinearKMaskedBias`` respectively.

    The replacement:
    * shares the original ``weight``/``bias`` Parameter objects;
    * shares the original module's direct buffers;
    * keeps the original train/eval mode.

    Args:
        root_module: Module tree to convert (traversed with ``named_replace``).
        filter: Regular expression matched against submodule names.

    Returns:
        The module returned by ``named_replace``.

    Raises:
        TypeError: If a matching module is a ``torch.nn.Linear`` subclass with
            no fp8 counterpart.
        RuntimeError: If a matching layer's in/out features are not multiples
            of 64, or if no layer was converted at all.
    """
    filter_re = re.compile(filter)
    # Exact-type mapping; other Linear subclasses are rejected in ``replace``.
    fp8_cls_for = {
        torch.nn.Linear: Fp8Linear,
        LinearKMaskedBias: Fp8LinearKMaskedBias,
    }
    total_count = 0

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        nonlocal total_count
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        new_cls = fp8_cls_for.get(type(module))
        if new_cls is None:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise TypeError(f"fp8: unsupported Linear subclass {type(module)}")
        if module.in_features % 64 != 0 or module.out_features % 64 != 0:
            raise RuntimeError(
                "fp8 requires all dimensions to be multiples of 64 " "(consider using ffn_layer=swiglu64 or higher)"
            )
        new_module = new_cls(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the original Parameter objects rather than copying their data.
        new_module.weight = module.weight
        new_module.bias = module.bias
        # Share the original module's own buffers too. Fp8LinearKMaskedBias.forward
        # reads self.bias_mask; if that is a buffer (presumably registered by
        # LinearKMaskedBias.__init__), the freshly constructed module would
        # otherwise hold a new copy instead of the one already set on the
        # original layer.
        for buffer_name, buffer in module.named_buffers(recurse=False):
            setattr(new_module, buffer_name, buffer)
        # A newly constructed module defaults to training mode; mirror the original.
        new_module.train(module.training)
        total_count += 1
        return new_module

    out = named_replace(replace, root_module)
    if total_count == 0:
        # Explicit raise (not ``assert``) so a filter matching nothing is never silent.
        raise RuntimeError("fp8: no layer found to convert")

    # Reset torch.compile code caches and CUDA-graph trees, presumably so that
    # artifacts built before the module swap are not reused.
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return out
|
|
|