import torch
import logging
from typing import Tuple, Dict

import comfy.float

_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}

def register_layout_op(torch_op, layout_type):
    """
    Decorator to register a layout-specific operation handler.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
        layout_type: Layout name string (e.g., "TensorCoreFP8Layout"), matching
            the _layout_type stored on the QuantizedTensor

    Example:
        @register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
        def fp8_linear(func, args, kwargs):
            # FP8-specific linear implementation
            ...
    """
    def decorator(handler_func):
        if torch_op not in _LAYOUT_REGISTRY:
            _LAYOUT_REGISTRY[torch_op] = {}
        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
        return handler_func
    return decorator

def register_generic_util(torch_op):
    """
    Decorator to register a generic utility that works for all layouts.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):
            # Works for any layout
            ...
    """
    def decorator(handler_func):
        _GENERIC_UTILS[torch_op] = handler_func
        return handler_func
    return decorator

def _get_layout_from_args(args):
    for arg in args:
        if isinstance(arg, QuantizedTensor):
            return arg._layout_type
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                if isinstance(item, QuantizedTensor):
                    return item._layout_type
    return None

def _move_layout_params_to_device(params, device):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.to(device=device)
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params(params):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.clone()
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params_inplace(src, dst, non_blocking=False):
    for k, v in src.items():
        if isinstance(v, torch.Tensor):
            dst[k].copy_(v, non_blocking=non_blocking)
        else:
            dst[k] = v

class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting the raw tensors needed for
    computation. New quantization formats should subclass this and implement the
    required methods.
    """
    @classmethod
    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
        raise NotImplementedError(f"{cls.__name__} must implement quantize()")

    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
        raise NotImplementedError("QuantizedLayout subclasses must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")

class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
    The layout type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout name string (e.g., "TensorCoreFP8Layout"), a key into LAYOUTS
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout name string (key into LAYOUTS)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __repr__(self):
        layout_name = self._layout_type
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type
    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.
        """
        inner_tensors = ["_qdata"]
        ctx = {
            "layout_type": self._layout_type,
        }
        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():
            if isinstance(v, torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v
        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params
        # Expose tensor-valued layout params as attributes so PyTorch can track
        # them as inner tensors during flatten/unflatten
        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)
        return inner_tensors, ctx
    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """
        Tensor unflattening protocol for proper device movement.
        Reconstructs the QuantizedTensor after device movement.
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])
        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]
        return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fall back to dequantizing every quantized argument
        if isinstance(args[0] if args else None, QuantizedTensor):
            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
        return cls._dequant_and_fallback(func, args, kwargs)
    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize()
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            elif isinstance(arg, dict):
                return {k: dequant_arg(v) for k, v in arg.items()}
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)
    def data_ptr(self):
        return self._qdata.data_ptr()

    def is_pinned(self):
        return self._qdata.is_pinned()

    def is_contiguous(self, *args, **kwargs):
        return self._qdata.is_contiguous(*args, **kwargs)
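
# Illustrative usage sketch (defined but never called in this module): quantize
# a weight with the FP8 layout registered below, exercise the generic utils, and
# dequantize back. Shapes and dtypes here are arbitrary assumptions.
def _example_quantized_tensor_roundtrip():
    w = torch.randn(128, 64, dtype=torch.bfloat16)
    qw = QuantizedTensor.from_float(w, "TensorCoreFP8Layout")
    # detach/clone are intercepted by __torch_dispatch__ and keep the wrapper intact
    qw = qw.detach().clone()
    # dequantize() reconstructs a plain tensor in the original dtype
    w_approx = qw.dequantize()
    assert w_approx.shape == w.shape and w_approx.dtype == torch.bfloat16
    return w_approx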

# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================

def _create_transformed_qtensor(qt, transform_fn):
    new_data = transform_fn(qt._qdata)
    new_params = _copy_layout_params(qt._layout_params)
    return QuantizedTensor(new_data, qt._layout_type, new_params)


def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
    if target_dtype is not None and target_dtype != qt.dtype:
        logging.warning(
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )
    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
        # Normalize devices for comparison
        if isinstance(target_device, str):
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)
        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt

@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
    """Detach operation - creates a detached copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.detach())
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
    """Clone operation - creates a deep copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.clone())
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
    """Device/dtype transfer operation - handles .to(device) calls."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            op_name="_to_copy"
        )
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
    """Handle .to(device) calls using the dtype_layout variant."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            target_layout=kwargs.get('layout', None),
            op_name="to"
        )
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]
    non_blocking = args[2] if len(args) > 2 else False
    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
            qt_dest._layout_type = src._layout_type
            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
        else:
            # Copy from a regular tensor - just copy the raw data
            qt_dest._qdata.copy_(src)
        return qt_dest
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
    """Handle .to(dtype) calls - dtype conversion only."""
    src = args[0]
    if isinstance(src, QuantizedTensor):
        # For a dtype-only conversion, just relabel orig_dtype; no real cast is needed
        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
        src._layout_params["orig_dtype"] = target_dtype
        return src
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
    return True

@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
    """empty_like operation - creates an empty tensor with the same quantized structure."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        # Create an empty tensor with the same shape and storage dtype as the quantized data
        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
        new_qdata = torch.empty_like(qt._qdata, **kwargs)

        # Move layout params to the target device if one was requested
        target_device = kwargs.get('device', new_qdata.device)
        new_params = _move_layout_params_to_device(qt._layout_params, target_device)

        # Record the requested high-precision dtype as the new orig_dtype
        new_params['orig_dtype'] = hp_dtype
        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
    return func(*args, **kwargs)

# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================

class TensorCoreFP8Layout(QuantizedLayout):
    """
    Storage format:
      - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
      - scale: Scalar tensor (float32) used for dequantization
      - orig_dtype: Original dtype before quantization (for casting back)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
        orig_dtype = tensor.dtype
        if scale is None:
            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        if inplace_ops:
            tensor *= (1.0 / scale).to(tensor.dtype)
        else:
            tensor = tensor * (1.0 / scale).to(tensor.dtype)

        if stochastic_rounding > 0:
            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
        else:
            lp_amax = torch.finfo(dtype).max
            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

        layout_params = {
            'scale': scale,
            'orig_dtype': orig_dtype
        }
        return tensor, layout_params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
        plain_tensor.mul_(scale)
        return plain_tensor

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params['scale']
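
# Worked example of the scaling above: torch.finfo(torch.float8_e4m3fn).max is
# 448.0, so a tensor whose largest magnitude is 8.96 gets
# scale = 8.96 / 448.0 = 0.02; quantize() stores tensor / 0.02 in FP8 and
# dequantize() multiplies the FP8 values back by 0.02.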

QUANT_ALGOS = {
    "float8_e4m3fn": {
        "storage_t": torch.float8_e4m3fn,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8Layout",
    },
}

LAYOUTS = {
    "TensorCoreFP8Layout": TensorCoreFP8Layout,
}
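
# Extending support is a matter of adding parallel entries to both tables; a
# sketch, assuming the hypothetical SymmetricInt8Layout from the comment near
# QuantizedLayout above (names are illustrative, not part of this module):
#
#     QUANT_ALGOS["int8_symmetric"] = {
#         "storage_t": torch.int8,
#         "parameters": {"weight_scale"},
#         "comfy_tensor_layout": "SymmetricInt8Layout",
#     }
#     LAYOUTS["SymmetricInt8Layout"] = SymmetricInt8Layout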

@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    # Case 1: both operands quantized - run the scaled matmul directly
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
        if len(plain_input.shape) == 2:
            tensor_2d = True
            plain_input = plain_input.unsqueeze(1)

        input_shape = plain_input.shape
        if len(input_shape) != 3:
            return None

        try:
            output = torch._scaled_mm(
                plain_input.reshape(-1, input_shape[2]).contiguous(),
                weight_t,
                bias=bias,
                scale_a=scale_a,
                scale_b=scale_b,
                out_dtype=out_dtype,
            )
            if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
                output = output[0]
            if not tensor_2d:
                output = output.reshape((-1, input_shape[1], weight.shape[0]))

            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                output_scale = scale_a * scale_b
                output_params = {
                    'scale': output_scale,
                    'orig_dtype': input_tensor._layout_params['orig_dtype']
                }
                return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
            else:
                return output
        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed for quantized linear: {e}") from e

    # Case 2: dequantize whichever operand is quantized and fall back
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight, bias)

def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
    if out_dtype is None:
        out_dtype = input_tensor._layout_params['orig_dtype']
    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
    output = torch._scaled_mm(
        plain_input.contiguous(),
        plain_weight,
        bias=bias,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=out_dtype,
    )
    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
        output = output[0]
    return output

@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
    # aten.addmm(bias, mat1, mat2): mat2 arrives already transposed for linear
    bias = args[0]
    input_tensor = args[1]
    weight = args[2]
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    if isinstance(args[2], QuantizedTensor):
        a[2] = args[2].dequantize()
    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    return func(*a, **kwargs)

# Binding this handler to aten.t.default is an assumption based on its behavior:
# it applies func to the raw FP8 data and rewraps the result with the unchanged
# per-tensor scale, which is only valid for layout-preserving ops like transpose.
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        ar = list(args)
        ar[0] = plain_input
        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)
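
# Illustrative end-to-end sketch (defined but never called in this module).
# Whether dispatch arrives as aten.linear, aten.addmm, or aten.mm depends on the
# PyTorch version; handlers for each are registered above. float8 _scaled_mm
# support also depends on hardware, so treat this as an assumption, not a
# guarantee.
def _example_fp8_linear(device="cuda"):
    x = torch.randn(4, 64, dtype=torch.bfloat16, device=device)
    w = torch.randn(32, 64, dtype=torch.bfloat16, device=device)
    qx = QuantizedTensor.from_float(x, "TensorCoreFP8Layout")
    qw = QuantizedTensor.from_float(w, "TensorCoreFP8Layout")
    # With both operands quantized, the registered FP8 handlers run the scaled
    # matmul directly; otherwise the wrapper falls back to dequantization.
    return torch.nn.functional.linear(qx, qw)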