"""Weight casting utilities for efficient model loading."""
from src.Device import Device
import torch
import logging

logger = logging.getLogger(__name__)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    """Cast a weight tensor to specified dtype and device."""
    if device is None or weight.device == device:
        if not copy and (dtype is None or weight.dtype == dtype):
            return weight
        return weight.to(dtype=dtype, copy=copy)
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    return r
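
# Usage sketch for cast_to (illustrative tensors, CPU-only so it runs anywhere):
#
#   w = torch.randn(8, 8)                   # fp32 on CPU
#   assert cast_to(w) is w                  # fast path: nothing to change, no copy
#   w16 = cast_to(w, dtype=torch.float16)   # dtype-only cast via Tensor.to
#   wc = cast_to(w, copy=True)              # copy=True always yields a fresh tensor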


def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast weight tensor to match input tensor's dtype and device."""
    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
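
# E.g. (illustrative): with x = torch.randn(2, 4, dtype=torch.float16), calling
# cast_to_input(w, x) returns w as float16 on x.device; copy defaults to True here,
# so callers get a fresh tensor rather than a reference to the stored weight.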


def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Cast module's bias and weight to match input tensor."""
    if input is not None:
        dtype = dtype or input.dtype
        bias_dtype = bias_dtype or dtype
        device = device or input.device

    non_blocking = Device.device_supports_non_blocking(device)
    bias = None
    if s.bias is not None:
        has_fn = s.bias_function is not None
        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_fn)
        if has_fn:
            bias = s.bias_function(bias)

    has_fn = s.weight_function is not None
    weight = cast_to(s.weight, None, device, non_blocking=non_blocking, copy=has_fn)
    
    # Handle NVFP4 dequantization before the final dtype cast
    if getattr(s, "quant_format", None) == "nvfp4":
        from src.Utilities.Quantization import dequantize_nvfp4
        weight = dequantize_nvfp4(
            weight,
            s.weight_scale_2,
            s.weight_scale,
            s.original_shape,
        )
    weight = weight.to(dtype)

    if has_fn:
        weight = s.weight_function(weight)
    return weight, bias
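
# Sketch: the weight_function / bias_function hooks let callers post-process the
# cast tensors, e.g. to apply a patch on the fly. The lambda below is hypothetical:
#
#   layer = manual_cast.Linear(16, 16)
#   layer.weight_function = lambda w: w * 0.5  # runs on a forced copy, so the
#                                              # stored parameter stays untouched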


class CastWeightBiasOp:
    """Mixin for cast weight/bias operations."""
    comfy_cast_weights = False
    weight_function = None
    bias_function = None


class disable_weight_init:
    """Module wrappers with disabled weight initialization."""

    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input) if self.weight is not None else (None, None)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input, output_size=None):
            output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self): return None
        def forward_comfy_cast_weights(self, input, output_size=None):
            output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation)
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
        def forward(self, *args, **kwargs):
            return self.forward_comfy_cast_weights(*args, **kwargs) if self.comfy_cast_weights else super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            # Embedding has no bias; define the attribute so cast_bias_weight's
            # uniform s.bias check works for this wrapper too.
            self.bias = None
            return None
        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # For fp16/bf16 weights, look up embeddings in the weight's own dtype
            # and cast the result afterwards instead of casting the whole table.
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                out_dtype = None
            weight, _ = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            kwargs.pop("out_dtype", None)
            return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(cls, dims, *args, **kwargs):
        """Create Conv2d/Conv3d based on dimensions."""
        if dims == 2: return cls.Conv2d(*args, **kwargs)
        if dims == 3: return cls.Conv3d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")


class manual_cast(disable_weight_init):
    """Module wrappers with manual casting enabled by default."""
    class Linear(disable_weight_init.Linear): comfy_cast_weights = True
    class Conv1d(disable_weight_init.Conv1d): comfy_cast_weights = True
    class Conv2d(disable_weight_init.Conv2d): comfy_cast_weights = True
    class Conv3d(disable_weight_init.Conv3d): comfy_cast_weights = True
    class GroupNorm(disable_weight_init.GroupNorm): comfy_cast_weights = True
    class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True
    class ConvTranspose2d(disable_weight_init.ConvTranspose2d): comfy_cast_weights = True
    class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True
    class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True
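

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumes the repo's src.Device import resolves and
    # that Device.device_supports_non_blocking accepts CPU devices; shapes, dtypes,
    # and values are illustrative only.
    layer = manual_cast.Linear(4, 2, dtype=torch.float16)
    with torch.no_grad():
        # reset_parameters is a no-op in these wrappers, so initialize explicitly.
        layer.weight.fill_(0.5)
        layer.bias.zero_()
    x = torch.ones(3, 4, dtype=torch.float32)
    out = layer(x)  # fp16 weights are cast to fp32 on the fly to match x
    print(out.shape, out.dtype)  # expected: torch.Size([3, 2]) torch.float32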