|
|
"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
|
|
|
weight_dtype = cls.weight.dtype
|
|
|
cls.to(origin_dtype)
|
|
|
|
|
|
|
|
|
inputs = [input.to(origin_dtype) for input in inputs]
|
|
|
out = cls.original_forward(*inputs, **kwargs)
|
|
|
|
|
|
cls.to(weight_dtype)
|
|
|
return out
|
|
|
|
|
|
def replace_parameters_by_name(module, name_keywords, device):
|
|
|
from torch import nn
|
|
|
for name, param in list(module.named_parameters(recurse=False)):
|
|
|
if any(keyword in name for keyword in name_keywords):
|
|
|
if isinstance(param, nn.Parameter):
|
|
|
tensor = param.data
|
|
|
delattr(module, name)
|
|
|
setattr(module, name, tensor.to(device=device))
|
|
|
for child_name, child_module in module.named_children():
|
|
|
replace_parameters_by_name(child_module, name_keywords, device)
|
|
|
|
|
|
def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
|
|
|
for name, module in model.named_modules():
|
|
|
flag = False
|
|
|
for _exclude_module_name in exclude_module_name:
|
|
|
if _exclude_module_name in name:
|
|
|
flag = True
|
|
|
if flag:
|
|
|
continue
|
|
|
for param_name, param in module.named_parameters():
|
|
|
flag = False
|
|
|
for _exclude_module_name in exclude_module_name:
|
|
|
if _exclude_module_name in param_name:
|
|
|
flag = True
|
|
|
if flag:
|
|
|
continue
|
|
|
param.data = param.data.to(torch.float8_e4m3fn)
|
|
|
|
|
|
def convert_weight_dtype_wrapper(module, origin_dtype):
|
|
|
for name, module in module.named_modules():
|
|
|
if name == "" or "embed_tokens" in name:
|
|
|
continue
|
|
|
original_forward = module.forward
|
|
|
if hasattr(module, "weight") and module.weight is not None:
|
|
|
setattr(module, "original_forward", original_forward)
|
|
|
setattr(
|
|
|
module,
|
|
|
"forward",
|
|
|
lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
|
|
|
)
|
|
|
|