|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .utils import log |
|
|
|
|
|
|
|
|
def fp8_linear_forward(cls, base_dtype, input): |
|
|
weight_dtype = cls.weight.dtype |
|
|
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: |
|
|
if len(input.shape) == 3: |
|
|
input_shape = input.shape |
|
|
|
|
|
scale_weight = getattr(cls, 'scale_weight', None) |
|
|
if scale_weight is None: |
|
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32) |
|
|
else: |
|
|
scale_weight = scale_weight.to(input.device).squeeze() |
|
|
|
|
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32) |
|
|
|
|
|
input = torch.clamp(input, min=-448, max=448, out=input) |
|
|
inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() |
|
|
|
|
|
bias = cls.bias.to(base_dtype) if cls.bias is not None else None |
|
|
|
|
|
o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) |
|
|
|
|
|
return o.reshape((-1, input_shape[1], cls.weight.shape[0])) |
|
|
else: |
|
|
return cls.original_forward(input.to(base_dtype)) |
|
|
else: |
|
|
return cls.original_forward(input) |
|
|
|
|
|
|
|
|
@torch.compiler.disable() |
|
|
def apply_lora(weight, lora, step=None): |
|
|
for lora_diff, lora_strength in zip(lora[0], lora[1]): |
|
|
if isinstance(lora_strength, list): |
|
|
lora_strength = lora_strength[step] |
|
|
if lora_strength == 0.0: |
|
|
continue |
|
|
elif lora_strength == 0.0: |
|
|
continue |
|
|
patch_diff = torch.mm( |
|
|
lora_diff[0].flatten(start_dim=1).to(weight.device), |
|
|
lora_diff[1].flatten(start_dim=1).to(weight.device) |
|
|
).reshape(weight.shape) |
|
|
alpha = lora_diff[2] / lora_diff[1].shape[0] if lora_diff[2] is not None else 1.0 |
|
|
scale = lora_strength * alpha |
|
|
weight = weight.add(patch_diff, alpha=scale) |
|
|
return weight |
|
|
|
|
|
def convert_fp8_linear(module, base_dtype, params_to_keep=None, scale_weight_keys=None):
    """Patch the ``nn.Linear`` layers under ``module`` to use ``fp8_linear_forward``.

    For every selected layer the current ``forward`` is saved as
    ``original_forward`` and ``forward`` is replaced by a wrapper that calls
    ``fp8_linear_forward(layer, base_dtype, input)``.

    Calling this function more than once on the same module is safe: the
    unpatched forward is captured only the first time, while the wrapper is
    refreshed so a new ``base_dtype`` takes effect.

    Args:
        module: Root module; all entries of ``module.named_modules()`` are
            scanned.
        base_dtype: dtype passed through to ``fp8_linear_forward``.
        params_to_keep: Iterable of substrings. A submodule whose qualified
            name contains any of them is left untouched. ``None`` (default)
            excludes nothing.
        scale_weight_keys: Optional mapping from ``"<name>.scale_weight"`` to
            a tensor. A matching entry is stored on the layer as
            ``scale_weight`` after ``.float()``.
    """
    if params_to_keep is None:
        params_to_keep = ()

    log.info("FP8 matmul enabled")

    for name, submodule in module.named_modules():
        if any(keyword in name for keyword in params_to_keep):
            continue
        if not isinstance(submodule, nn.Linear):
            continue

        if scale_weight_keys is not None:
            scale_key = f"{name}.scale_weight"
            if scale_key in scale_weight_keys:
                submodule.scale_weight = scale_weight_keys[scale_key].float()

        # Capture the unpatched forward only once. On a repeated conversion
        # ``submodule.forward`` is already the FP8 wrapper; storing that as
        # ``original_forward`` would make the fallback path call itself
        # until the recursion limit is hit.
        if getattr(submodule, "original_forward", None) is None:
            submodule.original_forward = submodule.forward

        # ``m=submodule`` binds the layer at definition time so each wrapper
        # targets its own layer rather than the loop's last one.
        submodule.forward = lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input)
|
|
|
|
|
|