"""Weight casting utilities for efficient model loading."""
from src.Device import Device
import torch
import logging
logger = logging.getLogger(__name__)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    """Cast a weight tensor to the specified dtype and device."""
    if device is None or weight.device == device:
        if not copy and (dtype is None or weight.dtype == dtype):
            return weight
        return weight.to(dtype=dtype, copy=copy)
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    return r
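# Illustrative behavior (not part of the module API; the device name below is an example):
#   w = torch.ones(4, dtype=torch.float32)
#   cast_to(w)                                        # fast path: returns w itself, no copy
#   cast_to(w, dtype=torch.float16)                   # dtype-only change via Tensor.to()
#   cast_to(w, torch.float16, torch.device("cuda"))   # if CUDA is available: allocates on the
#                                                     # target device, then copies the data in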
def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast a weight tensor to match the input tensor's dtype and device."""
    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Cast a module's bias and weight to match the input tensor's dtype and device."""
    if input is not None:
        dtype = dtype or input.dtype
        bias_dtype = bias_dtype or dtype
        device = device or input.device
    non_blocking = Device.device_supports_non_blocking(device)
    bias = None
    if s.bias is not None:
        has_fn = s.bias_function is not None
        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_fn)
        if has_fn:
            bias = s.bias_function(bias)
    has_fn = s.weight_function is not None
    weight = cast_to(s.weight, None, device, non_blocking=non_blocking, copy=has_fn)
    # Handle NVFP4 dequantization before casting to the compute dtype
    if getattr(s, "quant_format", None) == "nvfp4":
        from src.Utilities.Quantization import dequantize_nvfp4
        weight = dequantize_nvfp4(
            weight,
            s.weight_scale_2,
            s.weight_scale,
            s.original_shape
        )
    weight = weight.to(dtype)
    if has_fn:
        weight = s.weight_function(weight)
    return weight, bias
class CastWeightBiasOp:
    """Mixin for cast weight/bias operations."""
    comfy_cast_weights = False
    weight_function = None
    bias_function = None
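# Illustrative use of the hooks above (hypothetical caller code, not part of this module):
# assigning a callable to weight_function/bias_function makes the wrapped ops transform the
# freshly cast parameter on every forward without touching the stored tensor, e.g.
#   layer.weight_function = lambda w: w * 1.1   # applied after the cast in cast_bias_weight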
class disable_weight_init:
    """Module wrappers with disabled weight initialization."""
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input) if self.weight is not None else (None, None)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input, output_size=None):
            output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input, output_size=None):
            output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)
    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None
        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                out_dtype = None
            weight, _ = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            kwargs.pop("out_dtype", None)
            return super().forward(*args, **kwargs)
    @classmethod
    def conv_nd(cls, dims, *args, **kwargs):
        """Create Conv2d/Conv3d based on dimensions."""
        if dims == 2: return cls.Conv2d(*args, **kwargs)
        if dims == 3: return cls.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init):
    """Module wrappers with manual casting enabled by default."""
    class Linear(disable_weight_init.Linear): comfy_cast_weights = True
    class Conv1d(disable_weight_init.Conv1d): comfy_cast_weights = True
    class Conv2d(disable_weight_init.Conv2d): comfy_cast_weights = True
    class Conv3d(disable_weight_init.Conv3d): comfy_cast_weights = True
    class GroupNorm(disable_weight_init.GroupNorm): comfy_cast_weights = True
    class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True
    class ConvTranspose2d(disable_weight_init.ConvTranspose2d): comfy_cast_weights = True
    class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True
    class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True
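# Minimal usage sketch (assumptions: the src.Device import resolves and a CPU-only run
# is enough; the layer sizes below are arbitrary). manual_cast ops keep parameters in
# whatever dtype they were loaded in and cast them to the input's dtype/device on every
# forward, while disable_weight_init ops behave like the stock torch.nn layers until
# comfy_cast_weights is set to True.
#
#   fp16_linear = manual_cast.Linear(16, 8, dtype=torch.float16)
#   x = torch.randn(4, 16, dtype=torch.float32)
#   y = fp16_linear(x)                                  # weight/bias cast to float32 to match x
#   conv = manual_cast.conv_nd(2, 3, 8, 3, padding=1)   # dims=2 -> Conv2d, dims=3 -> Conv3d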