import logging

import gguf
import torch

import comfy.ops
import comfy.lora
import comfy.model_management

from .dequant import dequantize_tensor, is_quantized


def chained_hasattr(obj, chained_attr):
    # Probe a dotted attribute path (e.g. "compiler.disable") one level at a time.
    probe = obj
    for attr in chained_attr.split('.'):
        if hasattr(probe, attr):
            probe = getattr(probe, attr)
        else:
            return False
    return True
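
# Example (commented out): probe nested attributes without raising.
#
#   chained_hasattr(torch, "compiler.disable")         # True on recent torch
#   chained_hasattr(torch, "compiler.no_such_attr")    # False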


def get_torch_compiler_disable_decorator():
    # Fallback decorator factory that returns a no-op wrapper.
    def dummy_decorator(*args, **kwargs):
        def noop(x):
            return x
        return noop

    # Local import; only needed for the version comparison below.
    from packaging import version

    if not chained_hasattr(torch, "compiler.disable"):
        logging.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        return dummy_decorator
    elif version.parse(torch.__version__) >= version.parse("2.8"):
        logging.info("ComfyUI-GGUF: Allowing full torch compile")
        return dummy_decorator
    elif chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        logging.info("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        return dummy_decorator
    else:
        logging.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
        return torch.compiler.disable


torch_compiler_disable = get_torch_compiler_disable_decorator()
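
# Illustrative usage (commented out): the factory is called first, then applied
# as a decorator.
#
#   @torch_compiler_disable()
#   def quantized_only_path(x):
#       ...
#
# On builds where tensor-subclass tracing is unsupported this resolves to
# torch.compiler.disable, keeping dynamo from tracing GGMLTensor code paths.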


class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """
    def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        # Default to a fresh list to avoid sharing one mutable default across instances.
        self.patches = patches if patches is not None else []

    def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        # Carry the quantization metadata over to the moved/cast copy.
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # Quantized storage can't accept arbitrary source tensors; ignore failures.
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            logging.warning(f"ignoring 'copy_' on tensor: {e}")

    def new_empty(self, size, *args, **kwargs):
        # Keep the subclass (and its metadata) when torch allocates a fresh tensor.
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type=getattr(self, "tensor_type", None),
            tensor_shape=size,
            patches=getattr(self, "patches", []).copy(),
        )

    @property
    def shape(self):
        # Report the logical (dequantized) shape rather than the packed storage shape.
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
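
# Illustrative construction (commented out; names and shapes are hypothetical):
#
#   raw = torch.frombuffer(packed_q4_bytes, dtype=torch.uint8)
#   t = GGMLTensor(raw, tensor_type=gguf.GGMLQuantizationType.Q4_0,
#                  tensor_shape=torch.Size([4096, 4096]))
#   t.shape   # logical shape [4096, 4096], while storage holds packed bytes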


class GGMLLayer(torch.nn.Module):
    """
    Responsible for de-quantizing weights on the fly
    """
    comfy_cast_weights = True
    dequant_dtype = None
    patch_dtype = None
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # Linear layers are created uninitialized (see GGMLOps.Linear), so they
        # always take the custom load path, quantized or not.
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        # Route very large embeddings (vocab >= 64k) through the custom loader as well.
        if isinstance(self, torch.nn.Embedding) and self.weight.shape[0] >= (64 * 1024):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # Zero-filled placeholder for Linear layers missing from the checkpoint;
        # nn.Linear stores its weight as (out_features, in_features).
        if self.weight is None and isinstance(self, torch.nn.Linear):
            v = torch.zeros(self.out_features, self.in_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # A weight tagged as the model's largest triggers the extra scratch-space
        # accounting in ggml_save_to_state_dict below.
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        # Fake state dict on the meta device, used only for VRAM estimation.
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Account for the scratch space needed to dequantize the largest tensor.
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

        return

        # Unreachable reference path kept for documentation: this would emit the
        # fully dequantized weights instead (fp16 chosen as a representative dtype).
        destination[prefix + "weight"] = self.get_weight(self.weight, torch.float16)
        if self.bias is not None:
            destination[prefix + "bias"] = self.get_weight(self.bias, torch.float16)

    def get_weight(self, tensor, dtype):
        if tensor is None:
            return

        # Consolidate LoRA-style patches and start moving them to the target device.
        patch_list = []
        device = tensor.device
        for patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # Dequantize the tensor while the patches transfer.
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # Prevent the custom tensor subclass from propagating downstream.
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)

        # Apply any patches on top of the dequantized weight.
        if len(patch_list) > 0:
            if self.patch_dtype is None:
                weight = comfy.lora.calculate_weight(patch_list, weight, key)
            else:
                # "target" means: compute patches in the requested compute dtype.
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype)
        return weight

    @torch_compiler_disable()
    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        # Fill in device/dtype defaults from the input tensor when available.
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = comfy.model_management.device_supports_non_blocking(device)
        if s.bias is not None:
            bias = s.get_weight(s.bias.to(device), dtype)
            bias = comfy.ops.cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = s.get_weight(s.weight.to(device), dtype)
        weight = comfy.ops.cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            out = super().forward_comfy_cast_weights(input, *args, **kwargs)

        # Non-quantized weights may still be the custom subclass; unwrap here.
        if isinstance(out, GGMLTensor):
            out = torch.Tensor(out)
        return out

    def forward_ggml_cast_weights(self, input):
        raise NotImplementedError


class GGMLOps(comfy.ops.manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """
    class Linear(GGMLLayer, comfy.ops.manual_cast.Linear):
        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            torch.nn.Module.__init__(self)
            # Skip the parent __init__ so no dense weight is allocated up front;
            # the real (quantized) weights are attached in ggml_load_from_state_dict.
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(GGMLLayer, comfy.ops.manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGMLLayer, comfy.ops.manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # Half-precision tables can be gathered directly; cast the output afterwards.
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            # `input` holds integer indices, so device/dtype are supplied explicitly
            # rather than inferred from it.
            weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGMLLayer, comfy.ops.manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            # Elementwise-affine-free norms have no weight to dequantize.
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGMLLayer, comfy.ops.manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
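
# Illustrative wiring (commented out; how the ops class reaches the model loader
# depends on the surrounding ComfyUI code and is assumed here):
#
#   ops = GGMLOps
#   ops.Linear.dequant_dtype = torch.float16   # optional tuning knob
#   model = comfy.sd.load_diffusion_model_state_dict(
#       sd, model_options={"custom_operations": ops})
#
# Layers instantiated through `ops` then dequantize their GGUF weights on the
# fly inside forward_ggml_cast_weights.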


def move_patch_to_device(item, device):
    # Recursively move any tensors nested inside tuples/lists to the target device.
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    else:
        return item
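
# Example (commented out): nested patch structures move as a unit, while
# non-tensor entries pass through untouched.
#
#   patches = [(torch.randn(4, 4), ("lora_up", [torch.randn(4, 4)]))]
#   moved = move_patch_to_device(patches, "cpu")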


def manual_resize_tensor(tensor, target_shape):
    """
    Manually trims a tensor to match a target shape.
    Handles n-dimensional slicing automatically.
    """
    if tensor.shape == target_shape:
        return tensor

    # One slice per target dimension, keeping the leading elements of each axis.
    slices = tuple(slice(0, dim) for dim in target_shape)
    return tensor[slices]
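
# Example (commented out): trimming a 5x5 tensor down to 3x4.
#
#   t = torch.arange(25).reshape(5, 5)
#   manual_resize_tensor(t, torch.Size([3, 4])).shape   # torch.Size([3, 4])
#
# Note this helper only trims; target dims larger than the source are not padded.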