# ComfyUI-WanVideoWrapper / fp8_optimization.py
# Mirrored (by aliensmn) from https://github.com/kijai/ComfyUI-WanVideoWrapper
# at commit cf812a0.
import torch
import torch.nn as nn
from .utils import log
# Based on ComfyUI's and MinusZoneAI's fp8_linear optimization.
def fp8_linear_forward(cls, base_dtype, input):
weight_dtype = cls.weight.dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if len(input.shape) == 3:
input_shape = input.shape
scale_weight = getattr(cls, 'scale_weight', None)
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device).squeeze()
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
inn = input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous() #always e4m3fn because e5m2 * e5m2 is not supported
bias = cls.bias.to(base_dtype) if cls.bias is not None else None
o = torch._scaled_mm(inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
else:
return cls.original_forward(input.to(base_dtype))
else:
return cls.original_forward(input)
@torch.compiler.disable()
def apply_lora(weight, lora, step=None):
for lora_diff, lora_strength in zip(lora[0], lora[1]):
if isinstance(lora_strength, list):
lora_strength = lora_strength[step]
if lora_strength == 0.0:
continue
elif lora_strength == 0.0:
continue
patch_diff = torch.mm(
lora_diff[0].flatten(start_dim=1).to(weight.device),
lora_diff[1].flatten(start_dim=1).to(weight.device)
).reshape(weight.shape)
alpha = lora_diff[2] / lora_diff[1].shape[0] if lora_diff[2] is not None else 1.0
scale = lora_strength * alpha
weight = weight.add(patch_diff, alpha=scale)
return weight
def convert_fp8_linear(module, base_dtype, params_to_keep=None, scale_weight_keys=None):
    """Patch every ``nn.Linear`` in ``module`` to run through ``fp8_linear_forward``.

    Args:
        module: Root module. Every entry of ``module.named_modules()`` is visited.
        base_dtype: Dtype passed to ``fp8_linear_forward`` for each patched layer.
        params_to_keep: Iterable of substrings. A submodule whose qualified
            name contains any of them is left unpatched. ``None`` (the default)
            excludes nothing.
        scale_weight_keys: Optional mapping. If it contains
            ``"<name>.scale_weight"``, that value (converted with ``.float()``)
            is stored on the submodule as ``scale_weight``.

    Safe to call more than once on the same module. The pre-patch forward is
    captured only on the first call. Later calls re-bind ``forward`` (e.g. with
    a new ``base_dtype``) instead of storing the previous patched forward in
    ``original_forward``. Storing it would make the fallback path of
    ``fp8_linear_forward`` call itself without end.
    """
    log.info("FP8 matmul enabled")
    # None instead of a shared mutable {} default; any iterable of substrings works.
    keep = () if params_to_keep is None else params_to_keep
    for name, submodule in module.named_modules():
        if any(keyword in name for keyword in keep):
            continue
        if not isinstance(submodule, nn.Linear):
            continue

        if scale_weight_keys is not None:
            scale_key = f"{name}.scale_weight"
            if scale_key in scale_weight_keys:
                submodule.scale_weight = scale_weight_keys[scale_key].float()

        # Capture the real forward only once, so re-conversion stays idempotent.
        if not hasattr(submodule, "original_forward"):
            submodule.original_forward = submodule.forward
        # Bind the submodule as a default arg so each lambda keeps its own layer.
        submodule.forward = lambda input, m=submodule: fp8_linear_forward(m, base_dtype, input)