mihaiciorobitca commited on
Commit
c7c90f2
·
verified ·
1 Parent(s): babf923

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. ComfyUI/comfy/image_encoders/dino2.py +141 -0
  2. ComfyUI/comfy/k_diffusion/sa_solver.py +121 -0
  3. ComfyUI/comfy/k_diffusion/sampling.py +1761 -0
  4. ComfyUI/comfy/ldm/common_dit.py +16 -0
  5. ComfyUI/comfy/model_detection.py +910 -0
  6. ComfyUI/comfy/model_patcher.py +1215 -0
  7. ComfyUI/comfy/ops.py +441 -0
  8. ComfyUI/comfy/patcher_extension.py +157 -0
  9. ComfyUI/comfy/sample.py +52 -0
  10. ComfyUI/comfy/samplers.py +1143 -0
  11. ComfyUI/comfy/sd1_clip.py +687 -0
  12. ComfyUI/comfy/sd1_clip_config.json +25 -0
  13. ComfyUI/comfy/sd1_tokenizer/merges.txt +0 -0
  14. ComfyUI/comfy/sd1_tokenizer/tokenizer_config.json +34 -0
  15. ComfyUI/comfy/sd1_tokenizer/vocab.json +0 -0
  16. ComfyUI/comfy/supported_models.py +1235 -0
  17. ComfyUI/comfy/supported_models_base.py +119 -0
  18. ComfyUI/comfy/t2i_adapter/adapter.py +299 -0
  19. ComfyUI/comfy/taesd/taesd.py +79 -0
  20. ComfyUI/comfy/text_encoders/ace.py +153 -0
  21. ComfyUI/comfy/text_encoders/ace_text_cleaners.py +395 -0
  22. ComfyUI/comfy/text_encoders/aura_t5.py +22 -0
  23. ComfyUI/comfy/text_encoders/bert.py +143 -0
  24. ComfyUI/comfy/text_encoders/cosmos.py +42 -0
  25. ComfyUI/comfy/text_encoders/flux.py +70 -0
  26. ComfyUI/comfy/text_encoders/genmo.py +38 -0
  27. ComfyUI/comfy/text_encoders/hidream.py +155 -0
  28. ComfyUI/comfy/text_encoders/hunyuan_video.py +159 -0
  29. ComfyUI/comfy/text_encoders/hydit.py +81 -0
  30. ComfyUI/comfy/text_encoders/hydit_clip.json +35 -0
  31. ComfyUI/comfy/text_encoders/llama.py +358 -0
  32. ComfyUI/comfy/text_encoders/long_clipl.py +27 -0
  33. ComfyUI/comfy/text_encoders/lt.py +18 -0
  34. ComfyUI/comfy/text_encoders/lumina2.py +39 -0
  35. ComfyUI/comfy/text_encoders/mt5_config_xl.json +22 -0
  36. ComfyUI/comfy/text_encoders/omnigen2.py +44 -0
  37. ComfyUI/comfy/text_encoders/pixart_t5.py +42 -0
  38. ComfyUI/comfy/text_encoders/sd2_clip.py +23 -0
  39. ComfyUI/comfy/text_encoders/sd2_clip_config.json +23 -0
  40. ComfyUI/comfy/text_encoders/sd3_clip.py +166 -0
  41. ComfyUI/comfy/text_encoders/t5.py +249 -0
  42. ComfyUI/comfy/text_encoders/t5_config_base.json +22 -0
  43. ComfyUI/comfy/text_encoders/t5_config_xxl.json +22 -0
  44. ComfyUI/comfy/text_encoders/t5_old_config_xxl.json +22 -0
  45. ComfyUI/comfy/text_encoders/umt5_config_base.json +22 -0
  46. ComfyUI/comfy/utils.py +1104 -0
  47. ComfyUI/comfy/weight_adapter/__init__.py +34 -0
  48. ComfyUI/comfy/weight_adapter/boft.py +115 -0
  49. ComfyUI/comfy_api/feature_flags.py +69 -0
  50. ComfyUI/comfy_api_nodes/README.md +65 -0
ComfyUI/comfy/image_encoders/dino2.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from comfy.text_encoders.bert import BertAttention
3
+ import comfy.model_management
4
+ from comfy.ldm.modules.attention import optimized_attention_for_device
5
+
6
+
7
class Dino2AttentionOutput(torch.nn.Module):
    """Output projection applied after attention in a DINOv2 block.

    layer_norm_eps is accepted for interface compatibility but unused here.
    """

    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        # Plain linear projection; no activation or normalization.
        out = self.dense(x)
        return out
14
+
15
+
16
class Dino2AttentionBlock(torch.nn.Module):
    """Self-attention followed by the output projection."""

    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        attn_out = self.attention(x, mask, optimized_attention)
        return self.output(attn_out)
24
+
25
+
26
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling of the residual branch.

    dtype/device/operations mirror the other modules; operations is unused.
    """

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        scale = comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
        return x * scale
33
+
34
+
35
class SwiGLUFFN(torch.nn.Module):
    """SwiGLU feed-forward network.

    Hidden size is 4*dim scaled by 2/3 (keeping parameter count close to a
    plain 4x MLP) and rounded up to a multiple of 8.
    """

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        in_features = out_features = dim
        hidden_features = int(dim * 4)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        # Split the fused projection into gate and value halves.
        gate, value = self.weights_in(x).chunk(2, dim=-1)
        return self.weights_out(torch.nn.functional.silu(gate) * value)
50
+
51
+
52
class Dino2Block(torch.nn.Module):
    """Pre-norm transformer block: attention and SwiGLU MLP, each residual
    branch gated by its own LayerScale."""

    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention):
        attn_branch = self.attention(self.norm1(x), None, optimized_attention)
        x = x + self.layer_scale1(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.layer_scale2(mlp_branch)
66
+
67
+
68
class Dino2Encoder(torch.nn.Module):
    """Stack of Dino2Block layers with optional capture of one layer's output."""

    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])

    def forward(self, x, intermediate_output=None):
        # Choose the attention kernel once for the whole stack.
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                # Negative indices count from the end, like Python list indexing.
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layer):
            x = l(x, optimized_attention)
            if i == intermediate_output:
                # Clone so later layers don't mutate the captured activation.
                intermediate = x.clone()
        return x, intermediate
86
+
87
+
88
class Dino2PatchEmbeddings(torch.nn.Module):
    """Turn an image into a sequence of patch embeddings via a strided conv.

    image_size is accepted for interface compatibility but unused here.
    """

    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device,
        )

    def forward(self, pixel_values):
        embeddings = self.projection(pixel_values)
        # (B, C, H', W') -> (B, H'*W', C)
        return embeddings.flatten(2).transpose(1, 2)
103
+
104
+
105
class Dino2Embeddings(torch.nn.Module):
    """Patch embeddings plus a prepended class token and learned position embeddings."""

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        patch_size = 14
        image_size = 518

        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))

    def forward(self, pixel_values):
        x = self.patch_embeddings(pixel_values)
        # TODO: mask_token?
        cls = self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim=1)
        return x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
122
+
123
+
124
class Dinov2Model(torch.nn.Module):
    """DINOv2 vision transformer built from an HF-style config dict.

    forward returns (last_hidden_state, intermediate, pooled_output, None),
    where pooled_output is the first token after the final layernorm.
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        # attention_mask is accepted for interface compatibility but unused.
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]  # first (class-token) position
        return x, i, pooled_output, None
ComfyUI/comfy/k_diffusion/sa_solver.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
2
+ # Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
3
+ # Codebase ref: https://github.com/scxue/SA-Solver
4
+
5
+ import math
6
+ from typing import Union, Callable
7
+ import torch
8
+
9
+
10
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t,
    with exp((1 + tau^2) * t) factored out, using integration by parts.

    The recursion
        I_p = product_terms[p] - (p / (1 + tau^2)) * I_{p-1},
        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2),
    (base case: I_0 = product_terms[0]) is unrolled into a lower-triangular
    coefficient matrix applied to the boundary terms, giving all integrals
    for p = 0 .. solver_order - 1 at once.

    Args:
        s: Start time s.
        t: End time t.
        solver_order: Current order of the solver.
        tau_t: Stochastic strength parameter in the SDE.

    Returns:
        Coefficients used by SA-Solver in data-prediction mode, with
        exp((1 + tau^2) * t) factored out, ordered p=0..solver_order-1,
        shape (solver_order,).
    """
    tau_mul = 1 + tau_t ** 2
    h = t - s
    powers = torch.arange(solver_order, dtype=s.dtype, device=s.device)

    # Boundary terms with exp((1 + tau^2) * t) factored out; the outer
    # (1 + tau^2) factor cancels the 1/(1 + tau^2) inside product_terms.
    boundary = t ** powers - s ** powers * (-tau_mul * h).exp()

    # depth[i, j] = i - j: number of recursion steps from order i down to j.
    depth = powers.unsqueeze(1) - powers.unsqueeze(0)
    log_fact = (powers + 1).lgamma()
    log_coeff = log_fact.unsqueeze(1) - log_fact.unsqueeze(0)
    if tau_t > 0:
        # Each recursion step contributes a 1/(1 + tau^2) factor.
        log_coeff = log_coeff - depth * math.log(tau_mul)
    signs = torch.where(depth % 2 == 0, 1.0, -1.0)
    coeff_matrix = (log_coeff.exp() * signs).tril()

    return coeff_matrix @ boundary
51
+
52
+
53
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the simple order-2 b coefficients from the SA-Solver paper
    (Appendix D, Implementation Details)."""
    tau_mul = 1 + tau_t ** 2
    h = lambda_t - lambda_s
    alpha_t = sigma_next * lambda_t.exp()
    # 1 - exp(-(1 + tau^2) * h), shared by both branches.
    exp_term = (-h * tau_mul).expm1().neg()
    if is_corrector_step:
        # Simplified 1-step (order-2) corrector.
        b_1 = alpha_t * (0.5 * tau_mul * h)
        b_2 = alpha_t * exp_term - b_1
    else:
        # Simplified 2-step predictor.
        b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
        b_1 = alpha_t * exp_term - b_2
    return torch.stack([b_2, b_1])
67
+
68
+
69
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).

    The solver order equals the number of supplied lambdas (half-logSNR points).

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda points for the Lagrange basis, shape (N,).
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter in the SDE.
        simple_order_2: Whether to enable the simple order-2 scheme.
        is_corrector_step: Flag for corrector step in simple order-2 mode.

    Returns:
        b_i coefficients, shape (N,), where N is the solver order.
    """
    order = curr_lambdas.shape[0]

    if simple_order_2 and order == 2:
        return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)

    # Lagrange-basis interpolation: solve V^T c = exponential integrals.
    integrals = compute_exponential_coeffs(lambda_s, lambda_t, order, tau_t)
    vandermonde_T = torch.vander(curr_lambdas, order, increasing=True).T
    lagrange_integrals = torch.linalg.solve(vandermonde_T, integrals)

    # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
    # = sigma_t * exp(lambda_t) = alpha_t
    # exp((1 + tau^2) * lambda_t) was factored out of the integral above.
    alpha_t = sigma_next * lambda_t.exp()
    return alpha_t * lagrange_integrals
101
+
102
+
103
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
    """Return a function mapping sigma to the SA-Solver stochastic strength tau.

    With eta = 0 the SA-Solver runs as an ODE. The official implementation
    selects the SDE interval by time t; here sigma is used instead.

    See:
    https://github.com/scxue/SA-Solver/blob/main/README.md
    """

    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
        if eta <= 0:
            return 0.0  # ODE
        sigma_val = sigma.item() if isinstance(sigma, torch.Tensor) else sigma
        # Stochastic only inside the [end_sigma, start_sigma] band.
        return eta if start_sigma >= sigma_val >= end_sigma else 0.0

    return tau_func
ComfyUI/comfy/k_diffusion/sampling.py ADDED
@@ -0,0 +1,1761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+
4
+ from scipy import integrate
5
+ import torch
6
+ from torch import nn
7
+ import torchsde
8
+ from tqdm.auto import trange, tqdm
9
+
10
+ from . import utils
11
+ from . import deis
12
+ from . import sa_solver
13
+ import comfy.model_patcher
14
+ import comfy.model_sampling
15
+
16
def append_zero(x):
    """Return x with a single zero appended (same dtype/device as x)."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
18
+
19
+
20
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = torch.linspace(0, 1, n, device=device)
    max_inv_rho = sigma_max ** (1 / rho)
    min_inv_rho = sigma_min ** (1 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then map back.
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
27
+
28
+
29
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    # Linear in log-sigma from sigma_max down to sigma_min.
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
33
+
34
+
35
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a noise schedule that is polynomial in log sigma."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_span = math.log(sigma_max) - math.log(sigma_min)
    sigmas = torch.exp(ramp * log_span + math.log(sigma_min))
    return append_zero(sigmas)
40
+
41
+
42
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    # Variance of the VP forward process at time t.
    var = torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t)
    return append_zero(var.sqrt())
47
+
48
+
49
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Constructs the noise schedule proposed by Tiankai et al. (2024)."""
    epsilon = 1e-5  # avoid log(0)
    x = torch.linspace(0, 1, n, device=device)
    # Laplace(mu, beta) inverse CDF evaluated on a uniform grid.
    lmb = mu - beta * torch.sign(0.5 - x) * torch.log(1 - 2 * torch.abs(0.5 - x) + epsilon)
    # NOTE: unlike the other schedules here, no trailing zero is appended.
    return torch.clamp(torch.exp(lmb), min=sigma_min, max=sigma_max)
57
+
58
+
59
+
60
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    # append_dims broadcasts sigma over the trailing dimensions of x.
    return (x - denoised) / utils.append_dims(sigma, x.ndim)
63
+
64
+
65
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # Deterministic step: no noise injected.
        return sigma_to, 0.
    # Cap the injected noise so the deterministic part stays real-valued.
    max_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_up = min(sigma_to, eta * max_up)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
73
+
74
+
75
def default_noise_sampler(x, seed=None):
    """Return a (sigma, sigma_next) -> noise sampler shaped like x.

    When seed is given, a dedicated generator makes the sequence reproducible.
    """
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)

    def sampler(sigma, sigma_next):
        # sigma arguments are ignored; noise is unit Gaussian at every step.
        return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)

    return sampler
83
+
84
+
85
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # When cpu_tree is True the trees are built on CPU and samples are
        # moved to x's device on demand.
        self.cpu_tree = True
        if "cpu" in kwargs:
            self.cpu_tree = kwargs.pop("cpu")
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds: one BrownianTree per batch item.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # A single int seed: one tree shared by the whole batch.
            seed = [seed]
            self.batched = False
        if self.cpu_tree:
            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
        else:
            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        # Return (min, max, sign); sign flips the increment when the
        # endpoints were given in descending order.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        if self.cpu_tree:
            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
        else:
            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)

        return w if self.batched else w[0]
120
+
121
+
122
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance per step.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
145
+
146
+
147
def sigma_to_half_log_snr(sigma, model_sampling):
    """Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # alpha_t = 1 - sigma here, so log((1 - sigma) / sigma) = -logit(sigma).
        return sigma.logit().neg()
    # Otherwise alpha_t = 1: half-logSNR is just -log(sigma).
    return sigma.log().neg()
153
+
154
+
155
def half_log_snr_to_sigma(half_log_snr, model_sampling):
    """Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
    neg_log_snr = half_log_snr.neg()
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # sigma = 1 / (1 + exp(half_log_snr)) = sigmoid(-half_log_snr)
        return neg_log_snr.sigmoid()
    return neg_log_snr.exp()
161
+
162
+
163
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
    """Adjust the first sigma to avoid an invalid logSNR."""
    if len(sigmas) <= 1:
        return sigmas
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        if sigmas[0] >= 1:
            # sigma >= 1 would make logit(sigma) blow up; nudge slightly below.
            sigmas = sigmas.clone()
            sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
    return sigmas
172
+
173
+
174
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])  # broadcasts sigma across the batch
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            # "Churn": temporarily raise the noise level inside [s_tmin, s_tmax].
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        if gamma > 0:
            # Add the noise needed to reach the raised sigma_hat level.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x
198
+
199
+
200
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Dispatches to sample_euler_ancestral_RF when the model uses CONST
    (flow) sampling.
    """
    # Fix: the docstring used to sit after this dispatch as a bare string
    # statement, so it never became the function's __doc__; moved to the top.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigma_down == 0:
            # Final step: take the model's clean prediction directly.
            x = denoised
        else:
            d = to_d(x, sigmas[i], denoised)
            # Euler method, then re-inject sigma_up worth of fresh noise.
            dt = sigma_down - sigmas[i]
            x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
223
+
224
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigmas[i + 1] == 0:
            # Final step: take the clean prediction directly.
            x = denoised
        else:
            # Flow parameterization with alpha = 1 - sigma.
            downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
            sigma_down = sigmas[i + 1] * downstep_ratio
            alpha_ip1 = 1 - sigmas[i + 1]
            alpha_down = 1 - sigma_down
            # Noise magnitude that restores sigma_{i+1} after the alpha rescale below.
            renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
            # Euler method
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
            if eta > 0:
                x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
251
+
252
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            # "Churn": temporarily raise the noise level inside [s_tmin, s_tmax].
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        # (A redundant second `sigma_hat = sigmas[i] * (gamma + 1)` that
        # followed the conditional was removed; both branches above already
        # assign the identical value.)
        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method: average the slopes at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
285
+
286
+
287
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            # "Churn": temporarily raise the noise level inside [s_tmin, s_tmax].
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2
            # Midpoint chosen in log-sigma space.
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x
321
+
322
+
323
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Args:
        model: Denoiser callable ``model(x, sigma, **extra_args) -> denoised``.
        x: Initial noisy latent tensor.
        sigmas: Decreasing noise schedule; one step per adjacent pair.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional per-step progress callable.
        disable: Passed to the progress bar to disable it.
        eta: Ancestral noise amount (0 = deterministic).
        s_noise: Scale applied to the injected noise.
        noise_sampler: Optional custom noise source; defaults to seeded Gaussian.

    Returns:
        The sampled tensor after the last step.
    """
    # NOTE: this string used to sit *below* the dispatch statements, where it was a
    # discarded expression rather than the function docstring; it is now a real docstring.
    # Rectified-flow style models (CONST sampling) need the dedicated RF variant.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # Split the step into a deterministic part (down to sigma_down) and re-noising (sigma_up).
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2: midpoint (in log-sigma) derivative for the deterministic part.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        # Ancestral re-noising back up to sigma_{i+1}.
        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
354
+
355
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps (rectified-flow variant).

    Same scheme as ``sample_dpm_2_ancestral`` but with the down-step / re-noising
    coefficients derived for rectified-flow models where alpha = 1 - sigma.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # eta interpolates between a full step to sigma_{i+1} (eta=0: ratio 1)
        # and a deeper deterministic step compensated by re-noising below.
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        # Noise needed so the re-noised sample matches the sigma_{i+1} marginal.
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2: log-sigma midpoint derivative.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        # Rescale to the alpha of sigma_{i+1} and re-noise.
        x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
388
+
389
def linear_multistep_coeff(order, t, i, j):
    """Return the LMS (Adams-Bashforth-style) coefficient for history point *j*.

    Integrates the j-th Lagrange basis polynomial over the interpolation nodes
    ``t[i], t[i-1], ..., t[i-order+1]`` across the interval ``[t[i], t[i+1]]``.

    Raises:
        ValueError: If fewer than ``order`` history points are available at step *i*.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        # Product over all nodes except the j-th one.
        value = 1.
        for k in range(order):
            if k == j:
                continue
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
400
+
401
+
402
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep (LMS) sampler using integrated Lagrange-basis coefficients."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Coefficients are computed on CPU from the raw schedule values.
    sigmas_cpu = sigmas.detach().cpu().numpy()
    history = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        history.append(to_d(x, sigmas[i], denoised))
        if len(history) > order:
            # Keep only the most recent `order` derivatives.
            del history[0]
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step: jump straight to the model's clean prediction.
            x = denoised
        else:
            # Early steps use as many history points as exist so far.
            cur_order = min(i + 1, order)
            coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
            x = x + sum(c * d for c, d in zip(coeffs, reversed(history)))
    return x
424
+
425
+
426
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control.

    Maintains a short history of inverse error estimates and multiplies the
    current step size ``h`` by a PID-derived factor after each proposal.
    """

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains folded into three per-history-slot exponents, scaled by solver order.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        # Smoothly clamp the growth factor so a tiny error can't explode the step size.
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        """Update ``h`` from *error* and return whether the step is accepted."""
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # First call: seed the whole history with the current measurement.
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        factor = self.limiter(self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3)
        accept = factor >= self.accept_safety
        if accept:
            # Shift the history only on accepted steps.
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept
453
+
454
+
455
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927.

    Works in t = -log(sigma) space and caches model evaluations (as eps
    predictions) per step so higher-order methods can share them.
    """

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback  # called once per fresh model evaluation (progress)
        self.info_callback = info_callback  # called once per solver step with diagnostics

    def t(self, sigma):
        """Map sigma to solver time t = -log(sigma)."""
        return -sigma.log()

    def sigma(self, t):
        """Map solver time t back to sigma = exp(-t)."""
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Return the eps prediction at (x, t), reusing `eps_cache[key]` if present."""
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        # Convert the model's denoised output to an eps (noise) prediction.
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order DPM-Solver step from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order DPM-Solver step; r1 places the intermediate evaluation."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order DPM-Solver step with two intermediate evaluations at r1, r2."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step solve from t_start to t_end using `nfe` model evaluations.

        Steps are mostly third-order, with the tail adjusted so the total
        evaluation budget works out to exactly `nfe`.
        """
        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral split: deterministic part to t_next_, noise `su` added after.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive-step solve: embedded (order-1, order) pair + PID step control.

        Returns (x, info) where info counts steps, model evaluations, accepts and rejects.
        """
        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            # Clamp the proposed step so we never overshoot t_end.
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Lower- and higher-order solutions share cached evaluations;
            # their difference is the local error estimate.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
601
+
602
+
603
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate the solver's t-space diagnostics into sigma-space for the caller.
            def relay(info):
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        return solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
613
+
614
+
615
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate the solver's t-space diagnostics into sigma-space for the caller.
            def relay(info):
                return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        x, info = solver.dpm_solver_adaptive(x, solver.t(torch.tensor(sigma_max)), solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x
628
+
629
+
630
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Args:
        model: Denoiser callable ``model(x, sigma, **extra_args) -> denoised``.
        x: Initial noisy latent tensor.
        sigmas: Decreasing noise schedule; one step per adjacent pair.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional per-step progress callable.
        disable: Passed to the progress bar to disable it.
        eta: Ancestral noise amount (0 = deterministic).
        s_noise: Scale applied to the injected noise.
        noise_sampler: Optional custom noise source; defaults to seeded Gaussian.

    Returns:
        The sampled tensor after the last step.
    """
    # NOTE: this string used to sit *below* the dispatch statements, where it was a
    # discarded expression rather than the function docstring; it is now a real docstring.
    # Rectified-flow style models (CONST sampling) need the dedicated RF variant.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # t = -log(sigma) half-log-SNR parameterization used by DPM-Solver++.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): single-step second-order update with a midpoint evaluation.
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
666
+
667
+
668
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps (rectified-flow variant).

    Uses the RF parameterization lambda = log((1 - sigma) / sigma); the update is
    expressed as a convex combination of x and the model prediction.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # lambda <-> sigma maps for rectified flow (alpha = 1 - sigma).
    sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
    lambda_fn = lambda sigma: ((1-sigma)/sigma).log()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # eta interpolates between a plain step (ratio 1) and a deeper
        # deterministic step compensated by re-noising below.
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        # Noise magnitude that restores the sigma_{i+1} marginal after the down-step.
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S)
            if sigmas[i] == 1.0:
                # lambda is undefined at sigma == 1 (log of 0); nudge the midpoint.
                sigma_s = 0.9999
            else:
                t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
                r = 1 / 2
                h = t_down - t_i
                s = t_i + r * h
                sigma_s = sigma_fn(s)
            # Midpoint evaluation, then the full step, both as lerps toward the prediction.
            sigma_s_i_ratio = sigma_s / sigmas[i]
            u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
            D_i = model(u, sigma_s * s_in, **extra_args)
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
        # Noise addition
        if sigmas[i + 1] > 0 and eta > 0:
            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
717
+
718
+
719
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Two model evaluations per step (midpoint at ratio ``r``), with ancestral-style
    noise injected after each sub-step. Uses Brownian-tree noise by default so
    results are reproducible for a given seed. Works in half-log-SNR space via
    the model's `model_sampling` object, so it supports non-VE schedules.
    """
    if len(sigmas) <= 1:
        return x

    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    seed = extra_args.get("seed", None)
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    # sigma <-> half-log-SNR conversions for the model's schedule.
    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)

            sigma_s_1 = sigma_fn(lambda_s_1)

            # alpha = sigma * exp(lambda) recovers the schedule's scaling factor.
            alpha_s = sigmas[i] * lambda_s.exp()
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # Step 1: advance to the midpoint, shortened by the ancestral split, then re-noise.
            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
            lambda_s_1_ = sd.log().neg()
            h_ = lambda_s_1_ - lambda_s
            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
            if eta > 0 and s_noise > 0:
                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2: full step from lambda_s using the extrapolated prediction.
            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
            lambda_t_ = sd.log().neg()
            h_ = lambda_t_ - lambda_s
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
            if eta > 0 and s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
    return x
774
+
775
+
776
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M): deterministic second-order multistep sampler."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # t = -log(sigma) parameterization and its inverse.
    def sigma_of(t):
        return t.neg().exp()

    def t_of(sigma):
        return sigma.log().neg()

    prev_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t_cur, t_next = t_of(sigmas[i]), t_of(sigmas[i + 1])
        h = t_next - t_cur
        if prev_denoised is None or sigmas[i + 1] == 0:
            # First step (no history) or final denoising step: first-order update.
            x = (sigma_of(t_next) / sigma_of(t_cur)) * x - (-h).expm1() * denoised
        else:
            # Second-order correction: extrapolate from the previous model output.
            h_last = t_cur - t_of(sigmas[i - 1])
            ratio = h_last / h
            extrapolated = (1 + 1 / (2 * ratio)) * denoised - (1 / (2 * ratio)) * prev_denoised
            x = (sigma_of(t_next) / sigma_of(t_cur)) * x - (-h).expm1() * extrapolated
        prev_denoised = denoised
    return x
800
+
801
+
802
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE.

    Multistep second-order SDE sampler; `solver_type` selects how the history
    correction is weighted ('heun' or 'midpoint'). Uses Brownian-tree noise by
    default and works in half-log-SNR space via the model's `model_sampling`.
    """
    if len(sigmas) <= 1:
        return x

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    old_denoised = None
    h, h_last = None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            # eta stretches the effective step in the deterministic part;
            # matching noise is injected below to keep the correct marginal.
            h_eta = h * (eta + 1)

            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # First-order (exponential-integrator) part of the update.
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

            if old_denoised is not None:
                # Second-order multistep correction from the previous model output.
                r = h_last / h
                if solver_type == 'heun':
                    x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta > 0 and s_noise > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
854
+
855
+
856
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE.

    Multistep SDE sampler that uses up to two previous model outputs for a
    third-order correction, falling back to 2M and first-order on early steps.
    """

    if len(sigmas) <= 1:
        return x

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    # History of the last two model outputs and step sizes.
    denoised_1, denoised_2 = None, None
    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)

            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # First-order (exponential-integrator) part of the update.
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # DPM-Solver++(3M) SDE: finite-difference estimates of the first
                # and second derivatives of the model output in lambda space.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
            elif h_1 is not None:
                # DPM-Solver++(2M) SDE fallback with a single history point.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + (alpha_t * phi_2) * d

            if eta > 0 and s_noise > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
916
+
917
+
918
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE with the Brownian noise tree kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
926
+
927
+
928
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE with the Brownian noise tree kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
936
+
937
+
938
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic) with the Brownian noise tree kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
946
+
947
+
948
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
    """Single reverse DDPM transition from `sigma` to `sigma_prev`.

    `noise` is the model's eps prediction at `sigma`; posterior noise from
    `noise_sampler` is added only when `sigma_prev > 0` (not at the final step).
    """
    # Cumulative alphas implied by the (VP-style) sigma values.
    a_bar = 1 / ((sigma * sigma) + 1)
    a_bar_prev = 1 / ((sigma_prev * sigma_prev) + 1)
    step_alpha = a_bar / a_bar_prev

    # Posterior mean of x_{t-1} given x_t and the eps prediction.
    mean = (1.0 / step_alpha).sqrt() * (x - (1 - step_alpha) * noise / (1 - a_bar).sqrt())
    if sigma_prev > 0:
        posterior_std = ((1 - step_alpha) * (1. - a_bar_prev) / (1. - a_bar)).sqrt()
        mean = mean + posterior_std * noise_sampler(sigma, sigma_prev)
    return mean
957
+
958
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
    """Drive an arbitrary per-step transition (e.g. ``DDPMSampler_step``).

    The latent is rescaled to the VP convention (divide by sqrt(1 + sigma^2))
    before each call to `step_function` and scaled back afterwards.
    """
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None))
    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Recover the eps prediction and hand the rescaled latent to the step function.
        eps_pred = (x - denoised) / sigmas[i]
        x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], eps_pred, noise_sampler)
        if sigmas[i + 1] != 0:
            x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
    return x
972
+
973
+
974
@torch.no_grad()
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Ancestral DDPM sampling via the generic step driver."""
    return generic_step_sampler(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, noise_sampler=noise_sampler, step_function=DDPMSampler_step)
977
+
978
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Latent Consistency Model sampling: take the model's prediction directly,
    then re-noise to the next sigma using the model's own noise scaling."""
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None))
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        prediction = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': prediction})

        x = prediction
        if sigmas[i + 1] > 0:
            # Re-noise with the schedule's own scaling for the next step.
            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
    return x
993
+
994
+
995
+
996
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun++ sampler: Heun's method extended with a third evaluation,
    blending the derivatives with sigma-dependent weights.

    # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    s_end = sigmas[-1]
    for i in trange(len(sigmas) - 1, disable=disable):
        # Karras-style churn: optionally raise the noise level before evaluating.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == s_end:
            # Euler method
            x = x + d * dt
        elif sigmas[i + 2] == s_end:
            # Heun's method with sigma-weighted averaging instead of the usual 1/2, 1/2.
            # (sigmas[i + 2] is safe here: this branch only runs when i+1 is not the last index.)
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)

            w = 2 * sigmas[0]
            w2 = sigmas[i+1]/w
            w1 = 1 - w2

            d_prime = d * w1 + d_2 * w2

            x = x + d_prime * dt

        else:
            # Heun++: a third derivative evaluation one step further ahead,
            # combined with weights proportional to the respective sigmas.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            dt_2 = sigmas[i + 2] - sigmas[i + 1]

            x_3 = x_2 + d_2 * dt_2
            denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
            d_3 = to_d(x_3, sigmas[i + 2], denoised_3)

            w = 3 * sigmas[0]
            w2 = sigmas[i + 1] / w
            w3 = sigmas[i + 2] / w
            w1 = 1 - w2 - w3

            d_prime = w1 * d + w2 * d_2 + w3 * d_3
            x = x + d_prime * dt
    return x
1051
+
1052
+
1053
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    """iPNDM sampler: Adams–Bashforth-style multistep solver reusing up to
    ``max_order - 1`` previous derivative evaluations (one model call/step).

    Fix: added ``@torch.no_grad()`` for consistency with every other sampler
    in this file (e.g. ``sample_deis``), so no autograd graph accumulates when
    the sampler is invoked outside an explicit no_grad context; history
    derivatives are stored detached, matching ``sample_ipndm_v``.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x

    buffer_model = []  # derivative history, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        # Warm up: usable order grows with the number of history points.
        order = min(max_order, i+1)
        if t_next == 0: # Denoising step
            x_next = denoised
        elif order == 1: # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2: # Use one history point.
            x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
        elif order == 3: # Use two history points.
            x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
        elif order == 4: # Use three history points.
            x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24

        # Slide the fixed-size history window.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
1094
+
1095
+
1096
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    """Variable-step iPNDM sampler: like ``sample_ipndm`` but the multistep
    coefficients account for non-uniform sigma spacing (h_n, h_{n-1}, ...).

    Fix: added ``@torch.no_grad()`` for consistency with the other samplers in
    this file so no autograd graph accumulates when called outside a no_grad
    context.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    buffer_model = []  # derivative history, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        # Warm up: usable order grows with the number of history points.
        order = min(max_order, i+1)
        if t_next == 0: # Denoising step
            x_next = denoised
        elif order == 1: # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2: # Use one history point.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            coeff1 = (2 + (h_n / h_n_1)) / 2
            coeff2 = -(h_n / h_n_1) / 2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
        elif order == 3: # Use two history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
            coeff3 = temp * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
        elif order == 4: # Use three history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            h_n_3 = (t_steps[i-2] - t_steps[i-3])
            temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
                   * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
            coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
            coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])

        # Slide the fixed-size history window.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
1160
+
1161
+
1162
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
    """DEIS (Diffusion Exponential Integrator Sampler) multistep solver.

    Per-step coefficients for the whole schedule are precomputed once by
    ``deis.get_deis_coeff_list``; each step then combines the current
    derivative with up to three buffered history derivatives.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)

    buffer_model = []  # derivative history, oldest first
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        # Warm up; fall back to first order on the final (sigma == 0) step.
        order = min(max_order, i+1)
        if t_next <= 0:
            order = 1

        if order == 1: # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2: # Use one history point.
            coeff_cur, coeff_prev1 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
        elif order == 3: # Use two history points.
            coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
        elif order == 4: # Use three history points.
            coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]

        # Slide the fixed-size history window.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
1211
+
1212
+
1213
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps (CFG++)."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)

    # CFG++ needs the unconditional prediction; capture it via a post-CFG hook.
    uncond_denoised = None

    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # alpha = sigma * exp(half_log_snr) converts between VP/VE scalings.
            alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
            alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
            d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise

            # DDIM stochastic sampling
            sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
            sigma_down = alpha_t * sigma_down

            # Euler method
            x = alpha_t * denoised + sigma_down * d
            if eta > 0 and s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
1255
+
1256
+
1257
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Euler method steps (CFG++)."""
    # Deterministic special case of the ancestral CFG++ sampler: with
    # eta = 0 and s_noise = 0 no stochastic noise is ever injected.
    return sample_euler_ancestral_cfg_pp(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        eta=0.0,
        s_noise=0.0,
        noise_sampler=None,
    )
1261
+
1262
+
1263
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    # temp[0] holds the latest unconditional prediction captured by the hook.
    temp = [0]
    def post_cfg_function(args):
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    # Half-log-SNR time change: t = -log(sigma), sigma = exp(-t).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], temp[0])
            x = denoised + d * sigma_down
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
1305
+
1306
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    t_fn = lambda sigma: sigma.log().neg()

    # Multistep state: previous step's unconditional prediction.
    old_uncond_denoised = None
    uncond_denoised = None
    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_uncond_denoised is None or sigmas[i + 1] == 0:
            # First step (no history) or final step: first-order update.
            denoised_mix = -torch.exp(-h) * uncond_denoised
        else:
            # Second-order multistep correction using the previous step.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
        x = denoised + denoised_mix + torch.exp(-h) * x
        old_uncond_denoised = uncond_denoised
    return x
1338
+
1339
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
    """Shared core for the res_multistep sampler family.

    First-order (Euler) on the first step or when sigma_down hits 0, otherwise
    a second-order exponential multistep update (https://arxiv.org/pdf/2308.02157).
    ``eta`` > 0 enables ancestral noise injection; ``cfg_pp`` switches the
    derivative/history to the unconditional prediction (CFG++).
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    # phi functions of the exponential integrator.
    phi1_fn = lambda t: torch.expm1(t) / t
    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t

    old_sigma_down = None
    old_denoised = None
    uncond_denoised = None
    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigma_down == 0 or old_denoised is None:
            # Euler method
            if cfg_pp:
                d = to_d(x, sigmas[i], uncond_denoised)
                x = denoised + d * sigma_down
            else:
                d = to_d(x, sigmas[i], denoised)
                dt = sigma_down - sigmas[i]
                x = x + d * dt
        else:
            # Second order multistep method in https://arxiv.org/pdf/2308.02157
            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
            h = t_next - t
            c2 = (t_prev - t_old) / h

            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
            # nan_to_num guards the c2 == 0 degenerate case.
            b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
            b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)

            if cfg_pp:
                x = x + (denoised - uncond_denoised)
                x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
            else:
                x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)

        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

        # Update multistep history for the next iteration.
        if cfg_pp:
            old_denoised = uncond_denoised
        else:
            old_denoised = denoised
        old_sigma_down = sigma_down
    return x
1402
+
1403
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
    """Deterministic res_multistep: no ancestral noise (eta=0), no CFG++."""
    shared = dict(extra_args=extra_args, callback=callback, disable=disable,
                  s_noise=s_noise, noise_sampler=noise_sampler)
    return res_multistep(model, x, sigmas, eta=0., cfg_pp=False, **shared)
1406
+
1407
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
    """Deterministic res_multistep (eta=0) with CFG++ guidance."""
    shared = dict(extra_args=extra_args, callback=callback, disable=disable,
                  s_noise=s_noise, noise_sampler=noise_sampler)
    return res_multistep(model, x, sigmas, eta=0., cfg_pp=True, **shared)
1410
+
1411
@torch.no_grad()
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral (noise-injecting) res_multistep without CFG++."""
    shared = dict(extra_args=extra_args, callback=callback, disable=disable,
                  s_noise=s_noise, noise_sampler=noise_sampler)
    return res_multistep(model, x, sigmas, eta=eta, cfg_pp=False, **shared)
1414
+
1415
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral (noise-injecting) res_multistep with CFG++ guidance."""
    shared = dict(extra_args=extra_args, callback=callback, disable=disable,
                  s_noise=s_noise, noise_sampler=noise_sampler)
    return res_multistep(model, x, sigmas, eta=eta, cfg_pp=True, **shared)
1418
+
1419
+
1420
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None  # previous step's derivative, used for the gradient estimate

    uncond_denoised = None
    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # CFG++ uses the unconditional prediction for the derivative.
        if cfg_pp:
            d = to_d(x, sigmas[i], uncond_denoised)
        else:
            d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # Euler method
            if cfg_pp:
                x = denoised + d * sigmas[i + 1]
            else:
                x = x + d * dt

            if i >= 1:
                # Gradient estimation: extrapolate using the change in d,
                # weighted by (ge_gamma - 1).
                d_bar = (ge_gamma - 1) * (d - old_d)
                x = x + d_bar * dt
        old_d = d
    return x
1462
+
1463
+
1464
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """CFG++ variant of the gradient-estimation sampler."""
    return sample_gradient_estimation(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        ge_gamma=ge_gamma,
        cfg_pp=True,
    )
1467
+
1468
+
1469
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    def default_er_sde_noise_scaler(x):
        return x * ((x ** 0.3).exp() + 10.0)

    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
    # Number of quadrature points for the stage-2/3 integral approximations.
    num_integration_points = 200.0
    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
    er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t

    old_denoised = None
    old_denoised_d = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Warm up: solver stage grows with the available history.
        stage_used = min(max_stage, i + 1)
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
            alpha_s = sigmas[i] / er_lambda_s
            alpha_t = sigmas[i + 1] / er_lambda_t
            r_alpha = alpha_t / alpha_s
            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)

            # Stage 1 Euler
            x = r_alpha * r * x + alpha_t * (1 - r) * denoised

            if stage_used >= 2:
                dt = er_lambda_t - er_lambda_s
                lambda_step_size = -dt / num_integration_points
                lambda_pos = er_lambda_t + point_indice * lambda_step_size
                scaled_pos = noise_scaler(lambda_pos)

                # Stage 2: first-order difference of successive predictions.
                s = torch.sum(1 / scaled_pos) * lambda_step_size
                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d

                if stage_used >= 3:
                    # Stage 3: second-order difference term.
                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
                old_denoised_d = denoised_d

            if s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
        old_denoised = denoised
    return x
1533
+
1534
+
1535
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
    arXiv: https://arxiv.org/abs/2305.14267
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            # Work in half-log-SNR (lambda) time; r places the intermediate stage.
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)
            sigma_s_1 = sigma_fn(lambda_s_1)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r < 1
                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])

            # Step 1
            x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
    return x
1589
+
1590
+
1591
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
    """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
    arXiv: https://arxiv.org/abs/2305.14267
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            x = denoised
        else:
            # Work in half-log-SNR (lambda) time; r_1, r_2 place the two
            # intermediate stages within the step.
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            lambda_s_1 = lambda_s + r_1 * h
            lambda_s_2 = lambda_s + r_2 * h
            sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r_1 < r_2 < 1
                noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
                noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])

            # Step 1
            x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2
            x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
            if inject_noise:
                x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
            denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)

            # Step 3
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
    return x
1652
+
1653
+
1654
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
    """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023)."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

    if tau_func is None:
        # Use default interval for stochastic sampling
        start_sigma = model_sampling.percent_to_sigma(0.2)
        end_sigma = model_sampling.percent_to_sigma(0.8)
        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)

    max_used_order = max(predictor_order, corrector_order)
    x_pred = x # x: current state, x_pred: predicted next state

    # Carried across iterations: step size in lambda, stochasticity level,
    # injected noise, and the model-prediction history.
    h = 0.0
    tau_t = 0.0
    noise = 0.0
    pred_list = []

    # Lower order near the end to improve stability
    lower_order_to_end = sigmas[-1].item() == 0

    for i in trange(len(sigmas) - 1, disable=disable):
        # Evaluation
        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        pred_list.append(denoised)
        pred_list = pred_list[-max_used_order:]

        predictor_order_used = min(predictor_order, len(pred_list))
        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
            corrector_order_used = 0
        else:
            corrector_order_used = min(corrector_order, len(pred_list))

        if lower_order_to_end:
            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)

        # Corrector
        if corrector_order_used == 0:
            # Update by the predicted state
            x = x_pred
        else:
            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i],
                curr_lambdas,
                lambdas[i - 1],
                lambdas[i],
                tau_t,
                simple_order_2,
                is_corrector_step=True,
            )
            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res

            if tau_t > 0 and s_noise > 0:
                # The noise from the previous predictor step
                x = x + noise

        if use_pece:
            # Evaluate the corrected state
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            pred_list[-1] = denoised

        # Predictor
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            tau_t = tau_func(sigmas[i + 1])
            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i + 1],
                curr_lambdas,
                lambdas[i],
                lambdas[i + 1],
                tau_t,
                simple_order_2,
                is_corrector_step=False,
            )
            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
            h = lambdas[i + 1] - lambdas[i]
            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res

            if tau_t > 0 and s_noise > 0:
                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                x_pred = x_pred + noise
    return x
1756
+
1757
+
1758
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
    """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
    # Same solver core, with the extra post-correction evaluation enabled.
    return sample_sa_solver(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        tau_func=tau_func,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        predictor_order=predictor_order,
        corrector_order=corrector_order,
        use_pece=True,
        simple_order_2=simple_order_2,
    )
ComfyUI/comfy/ldm/common_dit.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.rmsnorm
3
+
4
+
5
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the trailing spatial dimensions of ``img`` up to multiples of ``patch_size``.

    Args:
        img: tensor of shape (batch, channels, *spatial); only the spatial
            dims (index 2 onward) are padded, each at its high end.
        patch_size: per-spatial-dim divisor the padded sizes must reach.
        padding_mode: any mode accepted by ``torch.nn.functional.pad``.

    Returns:
        The padded tensor (unchanged sizes when already divisible).
    """
    # torch.jit tracing/scripting does not support circular padding, so fall
    # back to reflect padding there.
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"

    # F.pad takes (left, right) pairs ordered from the LAST dimension
    # backwards, hence the prepending below.
    pad = ()
    for i in range(img.ndim - 2):
        remainder = img.shape[i + 2] % patch_size[i]
        amount = (patch_size[i] - remainder) % patch_size[i]
        pad = (0, amount) + pad

    return torch.nn.functional.pad(img, pad, mode=padding_mode)
14
+
15
+
16
# Re-export so existing callers can keep importing rms_norm from this module;
# the implementation lives in comfy.rmsnorm.
rms_norm = comfy.rmsnorm.rms_norm
ComfyUI/comfy/model_detection.py ADDED
@@ -0,0 +1,910 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import comfy.supported_models
3
+ import comfy.supported_models_base
4
+ import comfy.utils
5
+ import math
6
+ import logging
7
+ import torch
8
+
9
def count_blocks(state_dict_keys, prefix_string):
    """Count consecutively numbered blocks present in a state dict.

    Args:
        state_dict_keys: iterable of state-dict key strings.
        prefix_string: format string with a single ``{}`` placeholder for the
            block index, e.g. ``"blocks.{}."``.

    Returns:
        The number of consecutive indices starting at 0 for which at least one
        key starts with the formatted prefix.
    """
    count = 0
    # Probe indices 0, 1, 2, ... and stop at the first index with no match.
    # any() replaces the original manual flag + break loop.
    while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
        count += 1
    return count
21
+
22
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Probe the spatial-transformer sub-module of one UNet block.

    Args:
        prefix: key prefix of the block, e.g. ``"input_blocks.1."``.
        state_dict_keys: list of all state-dict keys.
        state_dict: the state dict itself (weight shapes are inspected).

    Returns:
        (depth, context_dim, use_linear_in_transformer, time_stack,
        time_stack_cross) when the block contains transformer layers,
        otherwise None.
    """
    transformer_prefix = prefix + "1.transformer_blocks."
    if not any(k.startswith(transformer_prefix) for k in state_dict_keys):
        return None

    depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
    # Cross-attention key projection input size gives the context dimension.
    context_dim = state_dict[transformer_prefix + '0.attn2.to_k.weight'].shape[1]
    # A 2D proj_in weight means nn.Linear projections; 4D means 1x1 convs.
    use_linear_in_transformer = len(state_dict[prefix + '1.proj_in.weight'].shape) == 2
    # Temporal (video) layers can be named either time_stack or time_mix_blocks.
    temporal_names = ('time_stack', 'time_mix_blocks')
    time_stack = any(prefix + '1.' + name + '.0.attn1.to_q.weight' in state_dict for name in temporal_names)
    time_stack_cross = any(prefix + '1.' + name + '.0.attn2.to_q.weight' in state_dict for name in temporal_names)
    return depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
36
+
37
def detect_unet_config(state_dict, key_prefix, metadata=None):
    """Infer a model (unet/dit) config dict from the weights in a state dict.

    Each supported architecture is recognized by a signature key unique to it;
    the first matching branch builds and returns the config. Falls through to
    the generic LDM/SGM UNet layout detection at the bottom. Returns None when
    nothing matches (or for formats handled elsewhere, e.g. PixArt diffusers).

    Args:
        state_dict: the loaded weights.
        key_prefix: prefix under which the model weights live (e.g.
            "model.diffusion_model.").
        metadata: optional safetensors metadata; may carry a "config" JSON
            entry used by some formats (ltxv).
    """
    state_dict_keys = list(state_dict.keys())

    # ---- Architecture-specific detectors, keyed on signature weights ----

    if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
        unet_config = {}
        unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
        patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
        unet_config["patch_size"] = patch_size
        final_layer = '{}final_layer.linear.weight'.format(key_prefix)
        if final_layer in state_dict:
            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)

        unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
        unet_config["input_size"] = None
        y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
        if y_key in state_dict_keys:
            unet_config["adm_in_channels"] = state_dict[y_key].shape[1]

        context_key = '{}context_embedder.weight'.format(key_prefix)
        if context_key in state_dict_keys:
            in_features = state_dict[context_key].shape[1]
            out_features = state_dict[context_key].shape[0]
            unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
        num_patches_key = '{}pos_embed'.format(key_prefix)
        if num_patches_key in state_dict_keys:
            num_patches = state_dict[num_patches_key].shape[1]
            unet_config["num_patches"] = num_patches
            unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))

        rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
        if rms_qk in state_dict_keys:
            unet_config["qk_norm"] = "rms"

        unet_config["pos_embed_scaling_factor"] = None #unused for inference
        context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
        if context_processor in state_dict_keys:
            unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
        unet_config["x_block_self_attn_layers"] = []
        for key in state_dict_keys:
            if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
                # Extract the numeric layer index between the two fixed parts.
                layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
                unet_config["x_block_self_attn_layers"].append(int(layer))
        return unet_config

    if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
        unet_config = {}
        text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
        if text_mapper_name in state_dict_keys:
            unet_config['stable_cascade_stage'] = 'c'
            w = state_dict[text_mapper_name]
            if w.shape[0] == 1536: #stage c lite
                unet_config['c_cond'] = 1536
                unet_config['c_hidden'] = [1536, 1536]
                unet_config['nhead'] = [24, 24]
                unet_config['blocks'] = [[4, 12], [12, 4]]
            elif w.shape[0] == 2048: #stage c full
                unet_config['c_cond'] = 2048
        elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
            unet_config['stable_cascade_stage'] = 'b'
            w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
            if w.shape[-1] == 640:
                unet_config['c_hidden'] = [320, 640, 1280, 1280]
                unet_config['nhead'] = [-1, -1, 20, 20]
                unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
                unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
            elif w.shape[-1] == 576: #stage b lite
                unet_config['c_hidden'] = [320, 576, 1152, 1152]
                unet_config['nhead'] = [-1, 9, 18, 18]
                unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
                unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
        return unet_config

    if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
        unet_config = {}
        unet_config["audio_model"] = "dit1.0"
        return unet_config

    if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
        unet_config = {}
        unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
        unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
        double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
        single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
        unet_config["n_double_layers"] = double_layers
        unet_config["n_layers"] = double_layers + single_layers
        return unet_config

    if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
        unet_config = {}
        unet_config["image_model"] = "hydit"
        unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
        unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
        if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
            unet_config["mlp_ratio"] = 4.3637
        if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
            unet_config["size_cond"] = True
            unet_config["use_style_cond"] = True
            unet_config["image_model"] = "hydit1"
        return unet_config

    if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
        dit_config = {}
        dit_config["image_model"] = "hunyuan_video"
        dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
        dit_config["patch_size"] = [1, 2, 2]
        dit_config["out_channels"] = 16
        dit_config["vec_in_dim"] = 768
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 24
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 256
        dit_config["qkv_bias"] = True
        guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
        dit_config["guidance_embed"] = len(guidance_keys) > 0
        return dit_config

    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
        dit_config = {}
        dit_config["image_model"] = "flux"
        dit_config["in_channels"] = 16
        patch_size = 2
        dit_config["patch_size"] = patch_size
        in_key = "{}img_in.weight".format(key_prefix)
        if in_key in state_dict_keys:
            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
        dit_config["out_channels"] = 16
        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
        if vec_in_key in state_dict_keys:
            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 24
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 10000
        dit_config["qkv_bias"] = True
        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
            dit_config["image_model"] = "chroma"
            dit_config["in_channels"] = 64
            dit_config["out_channels"] = 64
            dit_config["in_dim"] = 64
            dit_config["out_dim"] = 3072
            dit_config["hidden_dim"] = 5120
            dit_config["n_layers"] = 5
        else:
            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
        dit_config = {}
        dit_config["image_model"] = "mochi_preview"
        dit_config["depth"] = 48
        dit_config["patch_size"] = 2
        dit_config["num_heads"] = 24
        dit_config["hidden_size_x"] = 3072
        dit_config["hidden_size_y"] = 1536
        dit_config["mlp_ratio_x"] = 4.0
        dit_config["mlp_ratio_y"] = 4.0
        dit_config["learn_sigma"] = False
        dit_config["in_channels"] = 12
        dit_config["qk_norm"] = True
        dit_config["qkv_bias"] = False
        dit_config["out_bias"] = True
        dit_config["attn_drop"] = 0.0
        dit_config["patch_embed_bias"] = True
        dit_config["posenc_preserve_area"] = True
        dit_config["timestep_mlp_bias"] = True
        dit_config["attend_to_padding"] = False
        dit_config["timestep_scale"] = 1000.0
        dit_config["use_t5"] = True
        dit_config["t5_feat_dim"] = 4096
        dit_config["t5_token_length"] = 256
        dit_config["rope_theta"] = 10000.0
        return dit_config

    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
        # PixArt diffusers
        return None

    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
        dit_config = {}
        dit_config["image_model"] = "ltxv"
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
        shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
        dit_config["attention_head_dim"] = shape[0] // 32
        dit_config["cross_attention_dim"] = shape[1]
        # The safetensors metadata may carry extra transformer settings.
        if metadata is not None and "config" in metadata:
            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
        return dit_config

    if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
        dit_config = {}
        dit_config["audio_model"] = "ace"
        dit_config["attention_head_dim"] = 128
        dit_config["in_channels"] = 8
        dit_config["inner_dim"] = 2560
        dit_config["max_height"] = 16
        dit_config["max_position"] = 32768
        dit_config["max_width"] = 32768
        dit_config["mlp_ratio"] = 2.5
        dit_config["num_attention_heads"] = 20
        dit_config["num_layers"] = 24
        dit_config["out_channels"] = 8
        dit_config["patch_size"] = [16, 1]
        dit_config["rope_theta"] = 1000000.0
        dit_config["speaker_embedding_dim"] = 512
        dit_config["text_embedding_dim"] = 768

        dit_config["ssl_encoder_depths"] = [8, 8]
        dit_config["ssl_latent_dims"] = [1024, 768]
        dit_config["ssl_names"] = ["mert", "m-hubert"]
        dit_config["lyric_encoder_vocab_size"] = 6693
        dit_config["lyric_hidden_size"] = 1024
        return dit_config

    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
        patch_size = 2
        dit_config = {}
        dit_config["num_heads"] = 16
        dit_config["patch_size"] = patch_size
        dit_config["hidden_size"] = 1152
        dit_config["in_channels"] = 4
        dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

        y_key = "{}y_embedder.y_embedding".format(key_prefix)
        if y_key in state_dict_keys:
            dit_config["model_max_length"] = state_dict[y_key].shape[0]

        pe_key = "{}pos_embed".format(key_prefix)
        if pe_key in state_dict_keys:
            dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
            dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess

        # Alpha has the micro-conditioning aspect-ratio embedder; Sigma does not.
        ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
        if ar_key in state_dict_keys:
            dit_config["image_model"] = "pixart_alpha"
            dit_config["micro_condition"] = True
        else:
            dit_config["image_model"] = "pixart_sigma"
            dit_config["micro_condition"] = False
        return dit_config

    if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
        dit_config = {}
        dit_config["image_model"] = "cosmos"
        dit_config["max_img_h"] = 240
        dit_config["max_img_w"] = 240
        dit_config["max_frames"] = 128
        concat_padding_mask = True
        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
        dit_config["out_channels"] = 16
        dit_config["patch_spatial"] = 2
        dit_config["patch_temporal"] = 1
        dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
        dit_config["block_config"] = "FA-CA-MLP"
        dit_config["concat_padding_mask"] = concat_padding_mask
        dit_config["pos_emb_cls"] = "rope3d"
        dit_config["pos_emb_learnable"] = False
        dit_config["pos_emb_interpolation"] = "crop"
        dit_config["block_x_format"] = "THWBD"
        dit_config["affline_emb_norm"] = True
        dit_config["use_adaln_lora"] = True
        dit_config["adaln_lora_dim"] = 256

        # Model size is distinguished by the attention channel count.
        if dit_config["model_channels"] == 4096:
            # 7B
            dit_config["num_blocks"] = 28
            dit_config["num_heads"] = 32
            dit_config["extra_per_block_abs_pos_emb"] = True
            dit_config["rope_h_extrapolation_ratio"] = 1.0
            dit_config["rope_w_extrapolation_ratio"] = 1.0
            dit_config["rope_t_extrapolation_ratio"] = 2.0
            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        else: # 5120
            # 14B
            dit_config["num_blocks"] = 36
            dit_config["num_heads"] = 40
            dit_config["extra_per_block_abs_pos_emb"] = True
            dit_config["rope_h_extrapolation_ratio"] = 2.0
            dit_config["rope_w_extrapolation_ratio"] = 2.0
            dit_config["rope_t_extrapolation_ratio"] = 2.0
            dit_config["extra_h_extrapolation_ratio"] = 2.0
            dit_config["extra_w_extrapolation_ratio"] = 2.0
            dit_config["extra_t_extrapolation_ratio"] = 2.0
            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        return dit_config

    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
        dit_config = {}
        dit_config["image_model"] = "lumina2"
        dit_config["patch_size"] = 2
        dit_config["in_channels"] = 16
        dit_config["dim"] = 2304
        dit_config["cap_feat_dim"] = 2304
        dit_config["n_layers"] = 26
        dit_config["n_heads"] = 24
        dit_config["n_kv_heads"] = 8
        dit_config["qk_norm"] = True
        dit_config["axes_dims"] = [32, 32, 32]
        dit_config["axes_lens"] = [300, 512, 512]
        return dit_config

    if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
        dit_config = {}
        dit_config["image_model"] = "wan2.1"
        dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
        out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
        dit_config["dim"] = dim
        dit_config["out_dim"] = out_dim
        dit_config["num_heads"] = dim // 128
        dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
        dit_config["patch_size"] = (1, 2, 2)
        dit_config["freq_dim"] = 256
        dit_config["window_size"] = (-1, -1)
        dit_config["qk_norm"] = True
        dit_config["cross_attn_norm"] = True
        dit_config["eps"] = 1e-6
        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
        # Model variant (vace / camera / i2v / t2v) is keyed on extra weights.
        if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "vace"
            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "camera"
        else:
            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                dit_config["model_type"] = "i2v"
            else:
                dit_config["model_type"] = "t2v"
        flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
        if flf_weight is not None:
            dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
        return dit_config

    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
        dit_config = {}
        dit_config["image_model"] = "hunyuan3d2"
        dit_config["in_channels"] = in_shape[1]
        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
        dit_config["hidden_size"] = in_shape[0]
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 16
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["qkv_bias"] = True
        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
        dit_config = {}
        dit_config["image_model"] = "hidream"
        dit_config["attention_head_dim"] = 128
        dit_config["axes_dims_rope"] = [64, 32, 32]
        dit_config["caption_channels"] = [4096, 4096]
        dit_config["max_resolution"] = [128, 128]
        dit_config["in_channels"] = 16
        dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
        dit_config["num_attention_heads"] = 20
        dit_config["num_routed_experts"] = 4
        dit_config["num_activated_experts"] = 2
        dit_config["num_layers"] = 16
        dit_config["num_single_layers"] = 32
        dit_config["out_channels"] = 16
        dit_config["patch_size"] = 2
        dit_config["text_emb_dim"] = 2048
        return dit_config

    if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
        dit_config = {}
        dit_config["image_model"] = "cosmos_predict2"
        dit_config["max_img_h"] = 240
        dit_config["max_img_w"] = 240
        dit_config["max_frames"] = 128
        concat_padding_mask = True
        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
        dit_config["out_channels"] = 16
        dit_config["patch_spatial"] = 2
        dit_config["patch_temporal"] = 1
        dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
        dit_config["concat_padding_mask"] = concat_padding_mask
        dit_config["crossattn_emb_channels"] = 1024
        dit_config["pos_emb_cls"] = "rope3d"
        dit_config["pos_emb_learnable"] = True
        dit_config["pos_emb_interpolation"] = "crop"
        dit_config["min_fps"] = 1
        dit_config["max_fps"] = 30

        dit_config["use_adaln_lora"] = True
        dit_config["adaln_lora_dim"] = 256
        if dit_config["model_channels"] == 2048:
            dit_config["num_blocks"] = 28
            dit_config["num_heads"] = 16
        elif dit_config["model_channels"] == 5120:
            dit_config["num_blocks"] = 36
            dit_config["num_heads"] = 40

        if dit_config["in_channels"] == 16:
            dit_config["extra_per_block_abs_pos_emb"] = False
            dit_config["rope_h_extrapolation_ratio"] = 4.0
            dit_config["rope_w_extrapolation_ratio"] = 4.0
            dit_config["rope_t_extrapolation_ratio"] = 1.0
        elif dit_config["in_channels"] == 17: # img to video
            if dit_config["model_channels"] == 2048:
                dit_config["extra_per_block_abs_pos_emb"] = False
                dit_config["rope_h_extrapolation_ratio"] = 3.0
                dit_config["rope_w_extrapolation_ratio"] = 3.0
                dit_config["rope_t_extrapolation_ratio"] = 1.0
            elif dit_config["model_channels"] == 5120:
                dit_config["rope_h_extrapolation_ratio"] = 2.0
                dit_config["rope_w_extrapolation_ratio"] = 2.0
                dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334

        dit_config["extra_h_extrapolation_ratio"] = 1.0
        dit_config["extra_w_extrapolation_ratio"] = 1.0
        dit_config["extra_t_extrapolation_ratio"] = 1.0
        dit_config["rope_enable_fps_modulation"] = False

        return dit_config

    if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
        dit_config = {}
        dit_config["image_model"] = "omnigen2"
        dit_config["axes_dim_rope"] = [40, 40, 40]
        dit_config["axes_lens"] = [1024, 1664, 1664]
        dit_config["ffn_dim_multiplier"] = None
        dit_config["hidden_size"] = 2520
        dit_config["in_channels"] = 16
        dit_config["multiple_of"] = 256
        dit_config["norm_eps"] = 1e-05
        dit_config["num_attention_heads"] = 21
        dit_config["num_kv_heads"] = 7
        dit_config["num_layers"] = 32
        dit_config["num_refiner_layers"] = 2
        dit_config["out_channels"] = None
        dit_config["patch_size"] = 2
        dit_config["text_feat_dim"] = 2048
        dit_config["timestep_scale"] = 1000.0
        return dit_config

    # ---- Generic LDM/SGM UNet fallback ----
    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None

    unet_config = {
        "use_checkpoint": False,
        "image_size": 32,
        "use_spatial_transformer": True,
        "legacy": False
    }

    y_input = '{}label_emb.0.0.weight'.format(key_prefix)
    if y_input in state_dict_keys:
        unet_config["num_classes"] = "sequential"
        unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
    else:
        unet_config["adm_in_channels"] = None

    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
    in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]

    out_key = '{}out.2.weight'.format(key_prefix)
    if out_key in state_dict:
        out_channels = state_dict[out_key].shape[0]
    else:
        out_channels = 4

    num_res_blocks = []
    channel_mult = []
    transformer_depth = []
    transformer_depth_output = []
    context_dim = None
    use_linear_in_transformer = False

    video_model = False
    video_model_cross = False

    current_res = 1
    count = 0

    last_res_blocks = 0
    last_channel_mult = 0

    # Walk the input blocks (and mirror output blocks) to reconstruct the
    # per-resolution-level layout: res blocks, channel mults, transformer depths.
    input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
    for count in range(input_block_count):
        prefix = '{}input_blocks.{}.'.format(key_prefix, count)
        prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)

        block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
        if len(block_keys) == 0:
            break

        block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))

        if "{}0.op.weight".format(prefix) in block_keys: #new layer
            # Downsample op marks the boundary of a resolution level.
            num_res_blocks.append(last_res_blocks)
            channel_mult.append(last_channel_mult)

            current_res *= 2
            last_res_blocks = 0
            last_channel_mult = 0
            out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
            if out is not None:
                transformer_depth_output.append(out[0])
            else:
                transformer_depth_output.append(0)
        else:
            res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
            if res_block_prefix in block_keys:
                last_res_blocks += 1
                last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels

                out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
                if out is not None:
                    transformer_depth.append(out[0])
                    if context_dim is None:
                        # First transformer block seen defines the global settings.
                        context_dim = out[1]
                        use_linear_in_transformer = out[2]
                        video_model = out[3]
                        video_model_cross = out[4]
                else:
                    transformer_depth.append(0)

            res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
            if res_block_prefix in block_keys_output:
                out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
                if out is not None:
                    transformer_depth_output.append(out[0])
                else:
                    transformer_depth_output.append(0)


    num_res_blocks.append(last_res_blocks)
    channel_mult.append(last_channel_mult)
    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = -1
    else:
        transformer_depth_middle = -2

    unet_config["in_channels"] = in_channels
    unet_config["out_channels"] = out_channels
    unet_config["model_channels"] = model_channels
    unet_config["num_res_blocks"] = num_res_blocks
    unet_config["transformer_depth"] = transformer_depth
    unet_config["transformer_depth_output"] = transformer_depth_output
    unet_config["channel_mult"] = channel_mult
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config['use_linear_in_transformer'] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim

    if video_model:
        unet_config["extra_ff_mix_layer"] = True
        unet_config["use_spatial_context"] = True
        unet_config["merge_strategy"] = "learned_with_images"
        unet_config["merge_factor"] = 0.0
        unet_config["video_kernel_size"] = [3, 1, 1]
        unet_config["use_temporal_resblock"] = True
        unet_config["use_temporal_attention"] = True
        unet_config["disable_temporal_crossattention"] = not video_model_cross
    else:
        unet_config["use_temporal_resblock"] = False
        unet_config["use_temporal_attention"] = False

    return unet_config
609
+
610
def model_config_from_unet_config(unet_config, state_dict=None):
    """Instantiate the first supported model config that matches unet_config.

    Returns None (after logging an error) when no registered model matches.
    """
    matches = (cfg for cfg in comfy.supported_models.models if cfg.matches(unet_config, state_dict))
    model_config = next(matches, None)
    if model_config is not None:
        return model_config(unet_config)

    logging.error("no match {}".format(unet_config))
    return None
617
+
618
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
    """Detect the model architecture in a state dict and build its model config.

    Args:
        state_dict: loaded weights (a present "scaled_fp8" marker tensor is
            popped out of it as a side effect).
        unet_key_prefix: prefix under which the diffusion model weights live.
        use_base_if_no_match: when True, fall back to the generic BASE config
            if no supported model matches.
        metadata: optional safetensors metadata forwarded to detection.

    Returns:
        The model config object, or None when detection/matching fails.
    """
    unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
    if unet_config is None:
        return None
    model_config = model_config_from_unet_config(unet_config, state_dict)
    if model_config is None and use_base_if_no_match:
        model_config = comfy.supported_models_base.BASE(unet_config)

    # Fix: without this guard, a checkpoint carrying a scaled_fp8 marker but
    # no matching model config crashed below with AttributeError on None.
    if model_config is None:
        return None

    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
    if scaled_fp8_key in state_dict:
        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
        model_config.scaled_fp8 = scaled_fp8_weight.dtype
        # A float32 marker means "scaled fp8 with the default e4m3fn format".
        if model_config.scaled_fp8 == torch.float32:
            model_config.scaled_fp8 = torch.float8_e4m3fn
        # A 2-element marker tensor disables the fp8 optimization path.
        if scaled_fp8_weight.nelement() == 2:
            model_config.optimizations["fp8"] = False
        else:
            model_config.optimizations["fp8"] = True

    return model_config
638
+
639
def unet_prefix_from_state_dict(state_dict):
    """Guess the state-dict key prefix holding the diffusion model weights.

    Tallies how many keys begin with each known prefix and returns the most
    common one when it is clearly present (more than 5 keys); otherwise falls
    back to the generic "model." prefix.
    """
    candidates = ["model.diffusion_model.", #ldm/sgm models
                  "model.model.", #audio models
                  "net.", #cosmos
                  ]
    tallies = dict.fromkeys(candidates, 0)
    for key in state_dict:
        # each key is credited to at most one candidate (first match wins)
        prefix = next((c for c in candidates if key.startswith(c)), None)
        if prefix is not None:
            tallies[prefix] += 1

    best = max(tallies, key=tallies.get)
    if tallies[best] > 5:
        return best
    return "model." #aura flow and others
656
+
657
+
658
def convert_config(unet_config):
    """Convert a legacy ldm-style unet config to the per-block layout used here.

    Normalizes "num_res_blocks" from a scalar to a per-level list, and expands
    the legacy "attention_resolutions" + per-level "transformer_depth" pair
    into explicit per-resblock "transformer_depth" / "transformer_depth_output"
    lists (the output side has one extra entry per level for the upsample
    block). The input dict is not mutated; a converted copy is returned.
    """
    converted = unet_config.copy()
    num_res_blocks = converted.get("num_res_blocks", None)
    channel_mult = converted.get("channel_mult", None)

    if isinstance(num_res_blocks, int):
        num_res_blocks = [num_res_blocks] * len(channel_mult)

    if "attention_resolutions" in converted:
        attention_resolutions = converted.pop("attention_resolutions")
        depth = converted.get("transformer_depth", None)
        depth_middle = converted.get("transformer_depth_middle", None)

        if isinstance(depth, int):
            depth = [depth] * len(channel_mult)
        if depth_middle is None:
            # default the middle block to the deepest level's depth
            depth_middle = depth[-1]

        depth_in = []
        depth_out = []
        downscale = 1
        for level, res in enumerate(num_res_blocks):
            # attention is only enabled at downscale factors listed in
            # attention_resolutions; other levels get depth 0
            d = depth[level] if downscale in attention_resolutions else 0
            depth_in.extend([d] * res)
            depth_out.extend([d] * (res + 1))
            downscale *= 2

        converted["transformer_depth"] = depth_in
        converted["transformer_depth_output"] = depth_out
        converted["transformer_depth_middle"] = depth_middle

    converted["num_res_blocks"] = num_res_blocks
    return converted
694
+
695
+
696
def unet_config_from_diffusers_unet(state_dict, dtype=None):
    """Identify a diffusers-format SD unet by fingerprinting its state dict.

    Counts down-blocks / attention blocks / transformer depths and a few
    embedding shapes, then compares that fingerprint against a table of known
    model configurations (SD1.5, SD2.1, SDXL and variants, distilled models,
    controlnet-sized stubs, ...). Returns the matching config converted via
    convert_config(), or None when nothing matches or the dict is not a
    diffusers unet at all.
    """
    # diffusers unets always have a conv_in; absence means "not this format"
    if "conv_in.weight" not in state_dict:
        return None

    match = {}
    transformer_depth = []

    attn_res = 1
    down_blocks = count_blocks(state_dict, "down_blocks.{}")
    for i in range(down_blocks):
        attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
        res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
        for ab in range(attn_blocks):
            transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
            transformer_depth.append(transformer_count)
            if transformer_count > 0:
                # cross-attention key projection input size == context dim
                match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]

        attn_res *= 2
        if attn_blocks == 0:
            # attention-free level: record depth 0 once per resnet
            # NOTE(review): this inner loop shadows the outer `i`; harmless
            # because `i` is reassigned by the outer for, but fragile.
            for i in range(res_blocks):
                transformer_depth.append(0)

    match["transformer_depth"] = transformer_depth

    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
    match["adm_in_channels"] = None
    # class/add embedding input width distinguishes unclip / SDXL conditioning
    if "class_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
    elif "add_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]

    # --- Known-model fingerprint table. Every dict below is a complete unet
    # config (legacy layout, converted on return); the values are exact and
    # must not be edited without a reference checkpoint to compare against. ---

    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
            'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
            'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
                    'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
            'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
            'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
            'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
            'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
            'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
                     'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
                     'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                       'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
                       'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
                       'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                              'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                           'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
                           'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                           'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                           'use_temporal_attention': False, 'use_temporal_resblock': False}

    SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
              'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
                    'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                  'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                  'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
                'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}

    SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
             'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
             'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
             'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
                              'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
                              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
                              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
              'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    # order matters: more specific configs (LotusD) are checked first
    supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

    # a config matches when every fingerprinted key agrees exactly
    for unet_config in supported_models:
        matches = True
        for k in match:
            if match[k] != unet_config[k]:
                matches = False
                break
        if matches:
            return convert_config(unet_config)
    return None
848
+
849
def model_config_from_diffusers_unet(state_dict):
    """Resolve a diffusers-format unet state dict to a supported model config.

    Returns None when the dict is not recognized as a diffusers unet or no
    registered model config matches its detected configuration.
    """
    unet_config = unet_config_from_diffusers_unet(state_dict)
    if unet_config is None:
        return None
    return model_config_from_unet_config(unet_config)
854
+
855
def convert_diffusers_mmdit(state_dict, output_prefix=""):
    """Convert a diffusers MMDiT-family state dict to the native key layout.

    Detects the architecture (AuraFlow / PixArt / Flux / SD3) from marker keys,
    builds the corresponding key map, then copies weights over. Map entries may
    be plain strings (rename) or tuples (target key, optional narrow offset,
    optional transform fn) used to fuse several diffusers tensors into one
    native tensor (e.g. separate q/k/v into a fused qkv weight).

    Consumes converted keys from state_dict (pops them). Returns the converted
    dict, or None when no known architecture is detected.
    """
    out_sd = {}

    if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
    elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
    elif 'x_embedder.weight' in state_dict: #Flux
        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        hidden_size = state_dict["x_embedder.bias"].shape[0]
        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        # patch embed projects 2x2 patches of 16 channels: 64 values per unit depth
        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
    else:
        return None

    for k in sd_map:
        weight = state_dict.get(k, None)
        if weight is not None:
            t = sd_map[k]

            if not isinstance(t, str):
                # tuple entry: (target_key, offset_or_None, optional transform)
                if len(t) > 2:
                    fun = t[2]
                else:
                    fun = lambda a: a
                offset = t[1]
                if offset is not None:
                    # offset = (dim, start, length): write this weight into a
                    # slice of the (possibly pre-existing) fused target tensor
                    old_weight = out_sd.get(t[0], None)
                    if old_weight is None:
                        old_weight = torch.empty_like(weight)
                    if old_weight.shape[offset[0]] < offset[1] + offset[2]:
                        # grow the fused tensor to fit this slice
                        exp = list(weight.shape)
                        exp[offset[0]] = offset[1] + offset[2]
                        new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
                        # NOTE(review): only dim-0 data is preserved on grow;
                        # presumably offsets only ever fuse along dim 0 — confirm
                        new[:old_weight.shape[0]] = old_weight
                        old_weight = new

                    w = old_weight.narrow(offset[0], offset[1], offset[2])
                else:
                    old_weight = weight
                    w = weight
                w[:] = fun(weight)
                t = t[0]
                out_sd[t] = old_weight
            else:
                out_sd[t] = weight
            state_dict.pop(k)

    return out_sd
ComfyUI/comfy/model_patcher.py ADDED
@@ -0,0 +1,1215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Comfy
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import collections
22
+ import copy
23
+ import inspect
24
+ import logging
25
+ import math
26
+ import uuid
27
+ from typing import Callable, Optional
28
+
29
+ import torch
30
+
31
+ import comfy.float
32
+ import comfy.hooks
33
+ import comfy.lora
34
+ import comfy.model_management
35
+ import comfy.patcher_extension
36
+ import comfy.utils
37
+ from comfy.comfy_types import UnetWrapperFunction
38
+ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
39
+
40
+
41
def string_to_seed(data):
    """Hash a string (or iterable of byte values) to a deterministic 32-bit seed.

    This is a bit-for-bit implementation of standard CRC-32 (reflected
    polynomial 0xEDB88320), so identical keys always yield identical seeds
    across runs and platforms. String elements are converted via ord().
    """
    crc = 0xFFFFFFFF
    for item in data:
        value = ord(item) if isinstance(item, str) else item
        crc ^= value
        for _ in range(8):
            low_bit = crc & 1
            crc >>= 1
            if low_bit:
                crc ^= 0xEDB88320
    return crc ^ 0xFFFFFFFF
53
+
54
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    """Register a replacement patch for one transformer block in model_options.

    The nested transformer_options / patches_replace / name dicts are shallow
    copied before insertion so other holders of the original model_options are
    not affected. The block is keyed by (block_name, number) or, when given,
    (block_name, number, transformer_index). Mutates and returns model_options.
    """
    transformer_options = model_options["transformer_options"].copy()
    patches_replace = transformer_options.get("patches_replace", {}).copy()
    named_patches = patches_replace.get(name, {}).copy()

    if transformer_index is None:
        block_key = (block_name, number)
    else:
        block_key = (block_name, number, transformer_index)

    named_patches[block_key] = patch
    patches_replace[name] = named_patches
    transformer_options["patches_replace"] = patches_replace
    model_options["transformer_options"] = transformer_options
    return model_options
74
+
75
def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    """Append a post-CFG callback to model_options; mutates and returns it.

    When disable_cfg1_optimization is True, also forces the sampler to run the
    full cond/uncond pass even at cfg == 1.
    """
    existing = model_options.get("sampler_post_cfg_function", [])
    model_options["sampler_post_cfg_function"] = existing + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
80
+
81
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    """Append a pre-CFG callback to model_options; mutates and returns it.

    When disable_cfg1_optimization is True, also forces the sampler to run the
    full cond/uncond pass even at cfg == 1.
    """
    existing = model_options.get("sampler_pre_cfg_function", [])
    model_options["sampler_pre_cfg_function"] = existing + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
86
+
87
def create_model_options_clone(orig_model_options: dict):
    """Return an independent copy of model_options (nested dicts are copied,
    leaf values are shared) via the patcher_extension helper."""
    return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
89
+
90
def create_hook_patches_clone(orig_hook_patches):
    """Clone a hook-patches mapping: fresh outer and inner dicts, with each
    patch list shallow-copied so appends to the clone don't touch the original."""
    return {
        hook_ref: {key: patch_list[:] for key, patch_list in patches.items()}
        for hook_ref, patches in orig_hook_patches.items()
    }
97
+
98
def wipe_lowvram_weight(m):
    """Undo lowvram patching on module m: restore the saved cast-weights flag
    and clear any weight/bias patch function lists. Safe to call on modules
    that were never patched."""
    try:
        m.comfy_cast_weights = m.prev_comfy_cast_weights
    except AttributeError:
        pass  # module was never lowvram-patched
    else:
        del m.prev_comfy_cast_weights

    for attr in ("weight_function", "bias_function"):
        if hasattr(m, attr):
            setattr(m, attr, [])
108
+
109
def move_weight_functions(m, device):
    """Move module m's weight/bias patch functions to device.

    Returns the total memory (in bytes, as reported by each function's
    move_to) that was moved; returns 0 when device is None. Functions
    without a move_to method are skipped.
    """
    if device is None:
        return 0

    moved = 0
    for attr in ("weight_function", "bias_function"):
        for fn in getattr(m, attr, []):
            mover = getattr(fn, "move_to", None)
            if mover is not None:
                moved += mover(device=device)
    return moved
124
+
125
class LowVramPatch:
    # Deferred weight patch used in lowvram mode: instead of baking LoRA-style
    # patches into the stored weight, this callable applies them on the fly
    # each time the weight is cast for use, keeping the unpatched weight as the
    # only stored copy.
    def __init__(self, key, patches):
        # key: state-dict key of the weight; patches: mapping key -> patch list
        self.key = key
        self.patches = patches
    def __call__(self, weight):
        intermediate_dtype = weight.dtype
        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
            intermediate_dtype = torch.float32
            # compute in float32, then stochastically round back to the storage
            # dtype; seeding from the key keeps the rounding deterministic per weight
            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))

        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
136
+
137
def get_key_weight(model, key):
    """Fetch the weight for a dotted state-dict key, plus any custom accessors.

    For "module.path.attr" keys, probes the owning op module for optional
    "set_attr" / "convert_attr" methods (used by custom/quantized ops to
    intercept weight writes and reads).

    Returns:
        (weight, set_func_or_None, convert_func_or_None)
    """
    set_func = None
    convert_func = None
    op_keys = key.rsplit('.', 1)
    if len(op_keys) < 2:
        # key has no dot: plain top-level attribute lookup
        weight = comfy.utils.get_attr(model, key)
    else:
        op = comfy.utils.get_attr(model, op_keys[0])
        try:
            set_func = getattr(op, "set_{}".format(op_keys[1]))
        except AttributeError:
            pass

        try:
            convert_func = getattr(op, "convert_{}".format(op_keys[1]))
        except AttributeError:
            pass

        weight = getattr(op, op_keys[1])
        if convert_func is not None:
            # ops with a convert function store weights in a custom form;
            # re-read through get_attr to obtain the canonical tensor
            weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func
160
+
161
class AutoPatcherEjector:
    # Context manager that temporarily ejects a ModelPatcher's injections and
    # restores them on exit. With skip_and_inject_on_exit_only=True, injection
    # is suppressed for the whole body and performed once on exit instead.
    def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
        self.model = model
        self.was_injected = False          # whether __enter__ actually ejected
        self.prev_skip_injection = False   # saved skip_injection flag to restore
        self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only

    def __enter__(self):
        self.was_injected = False
        self.prev_skip_injection = self.model.skip_injection
        if self.skip_and_inject_on_exit_only:
            self.model.skip_injection = True
        if self.model.is_injected:
            self.model.eject_model()
            self.was_injected = True

    def __exit__(self, *args):
        # Note the ordering: in exit-only mode the skip flag must be restored
        # BEFORE inject_model() so the injection is not suppressed.
        if self.skip_and_inject_on_exit_only:
            self.model.skip_injection = self.prev_skip_injection
            self.model.inject_model()
        if self.was_injected and not self.model.skip_injection:
            self.model.inject_model()
        self.model.skip_injection = self.prev_skip_injection
184
+
185
class MemoryCounter:
    """Tracks a shrinking byte budget used for weight offloading decisions."""

    def __init__(self, initial: int, minimum=0):
        # value: remaining budget in bytes; minimum: floor the budget may not cross
        self.value = initial
        self.minimum = minimum
        # TODO: add a safe limit besides 0

    def use(self, weight: torch.Tensor):
        """Reserve the weight's byte size from the budget; return success."""
        needed = weight.element_size() * weight.nelement()
        fits = self.is_useable(needed)
        if fits:
            self.decrement(needed)
        return fits

    def is_useable(self, used: int):
        """True when spending `used` bytes keeps the budget above the floor."""
        return self.value - used > self.minimum

    def decrement(self, used: int):
        """Unconditionally deduct `used` bytes from the budget."""
        self.value -= used
203
+
204
+ class ModelPatcher:
205
    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
        """Wrap `model` so weights can be patched/unpatched and moved between
        `load_device` (compute) and `offload_device` (storage).

        Args:
            model: the torch module to manage; shared state (device, loaded
                memory counters) is stored ON the model so all patchers of the
                same model agree.
            size: known weight size in bytes; 0 means "compute lazily".
            weight_inplace_update: patch weights in place instead of swapping
                tensors.
        """
        self.size = size
        self.model = model
        # ensure the wrapped model always carries a device attribute
        if not hasattr(self.model, 'device'):
            logging.debug("Model doesn't have a device attribute.")
            self.model.device = offload_device
        elif self.model.device is None:
            self.model.device = offload_device

        self.patches = {}                 # key -> list of weight patches (e.g. loras)
        self.backup = {}                  # original weights saved before patching
        self.object_patches = {}          # attribute-level replacements to apply
        self.object_patches_backup = {}   # originals of replaced attributes
        self.weight_wrapper_patches = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        self.weight_inplace_update = weight_inplace_update
        self.force_cast_weights = False
        # uuid identifying this exact patch set; clones share it until modified
        self.patches_uuid = uuid.uuid4()
        self.parent = None

        self.attachments: dict[str] = {}
        self.additional_models: dict[str, list[ModelPatcher]] = {}
        self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
        self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()

        self.is_injected = False
        self.skip_injection = False
        self.injections: dict[str, list[PatcherInjection]] = {}

        self.hook_patches: dict[comfy.hooks._HookRef] = {}
        self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
        self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
        self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
        self.current_hooks: Optional[comfy.hooks.HookGroup] = None
        self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
        self.is_clip = False
        self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed

        # shared per-model bookkeeping lives on the model itself so multiple
        # patchers wrapping the same model stay consistent
        if not hasattr(self.model, 'model_loaded_weight_memory'):
            self.model.model_loaded_weight_memory = 0

        if not hasattr(self.model, 'lowvram_patch_counter'):
            self.model.lowvram_patch_counter = 0

        if not hasattr(self.model, 'model_lowvram'):
            self.model.model_lowvram = False

        if not hasattr(self.model, 'current_weight_patches_uuid'):
            self.model.current_weight_patches_uuid = None
257
+
258
+ def model_size(self):
259
+ if self.size > 0:
260
+ return self.size
261
+ self.size = comfy.model_management.module_size(self.model)
262
+ return self.size
263
+
264
    def loaded_size(self):
        """Return how many bytes of this model's weights are currently loaded
        (tracked on the model so all patchers of it agree)."""
        return self.model.model_loaded_weight_memory
266
+
267
    def lowvram_patch_counter(self):
        """Return the number of lowvram (deferred) weight patches currently
        applied to the underlying model."""
        return self.model.lowvram_patch_counter
269
+
270
    def clone(self):
        """Create a new patcher over the SAME underlying model.

        Patch lists, options, callbacks, wrappers, injections and hook patches
        are copied so the clone can diverge; the weight backups, object-patch
        backups and hook backups are intentionally SHARED (they describe the
        one real model). Attachments may customize cloning via
        on_model_patcher_clone(); ON_CLONE callbacks run last.
        """
        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
        # copy patch lists (shallow per key) but keep the same uuid: the clone
        # currently represents the identical patch set
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        # backups are shared on purpose: they belong to the shared model
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n.parent = self

        n.force_cast_weights = self.force_cast_weights

        # attachments
        n.attachments = {}
        for k in self.attachments:
            if hasattr(self.attachments[k], "on_model_patcher_clone"):
                n.attachments[k] = self.attachments[k].on_model_patcher_clone()
            else:
                n.attachments[k] = self.attachments[k]
        # additional models
        for k, c in self.additional_models.items():
            n.additional_models[k] = [x.clone() for x in c]
        # callbacks
        for k, c in self.callbacks.items():
            n.callbacks[k] = {}
            for k1, c1 in c.items():
                n.callbacks[k][k1] = c1.copy()
        # sample wrappers
        for k, w in self.wrappers.items():
            n.wrappers[k] = {}
            for k1, w1 in w.items():
                n.wrappers[k][k1] = w1.copy()
        # injection
        n.is_injected = self.is_injected
        n.skip_injection = self.skip_injection
        for k, i in self.injections.items():
            n.injections[k] = i.copy()
        # hooks: patch dicts are cloned, cached per-group weights are shared tensors
        n.hook_patches = create_hook_patches_clone(self.hook_patches)
        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
        for group in self.cached_hook_patches:
            n.cached_hook_patches[group] = {}
            for k in self.cached_hook_patches[group]:
                n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
        n.hook_backup = self.hook_backup
        n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
        n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
        n.is_clip = self.is_clip
        n.hook_mode = self.hook_mode

        for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
            callback(self, n)
        return n
327
+
328
+ def is_clone(self, other):
329
+ if hasattr(other, 'model') and self.model is other.model:
330
+ return True
331
+ return False
332
+
333
+ def clone_has_same_weights(self, clone: 'ModelPatcher'):
334
+ if not self.is_clone(clone):
335
+ return False
336
+
337
+ if self.current_hooks != clone.current_hooks:
338
+ return False
339
+ if self.forced_hooks != clone.forced_hooks:
340
+ return False
341
+ if self.hook_patches.keys() != clone.hook_patches.keys():
342
+ return False
343
+ if self.attachments.keys() != clone.attachments.keys():
344
+ return False
345
+ if self.additional_models.keys() != clone.additional_models.keys():
346
+ return False
347
+ for key in self.callbacks:
348
+ if len(self.callbacks[key]) != len(clone.callbacks[key]):
349
+ return False
350
+ for key in self.wrappers:
351
+ if len(self.wrappers[key]) != len(clone.wrappers[key]):
352
+ return False
353
+ if self.injections.keys() != clone.injections.keys():
354
+ return False
355
+
356
+ if len(self.patches) == 0 and len(clone.patches) == 0:
357
+ return True
358
+
359
+ if self.patches_uuid == clone.patches_uuid:
360
+ if len(self.patches) != len(clone.patches):
361
+ logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
362
+ else:
363
+ return True
364
+
365
    def memory_required(self, input_shape):
        """Estimate memory needed to run the wrapped model on *input_shape* (delegates to the model)."""
        return self.model.memory_required(input_shape=input_shape)
367
+
368
+ def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
369
+ if len(inspect.signature(sampler_cfg_function).parameters) == 3:
370
+ self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
371
+ else:
372
+ self.model_options["sampler_cfg_function"] = sampler_cfg_function
373
+ if disable_cfg1_optimization:
374
+ self.model_options["disable_cfg1_optimization"] = True
375
+
376
    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        # Register a function run after CFG is applied on each sampling step.
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        # Register a function run before CFG is applied on each sampling step.
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
        # Override how conditioning batches are evaluated during sampling.
        self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function

    def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
        # Wrap the diffusion model's forward call entirely.
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_denoise_mask_function(self, denoise_mask_function):
        # Override how the denoise mask is computed during sampling.
        self.model_options["denoise_mask_function"] = denoise_mask_function
390
+
391
+ def set_model_patch(self, patch, name):
392
+ to = self.model_options["transformer_options"]
393
+ if "patches" not in to:
394
+ to["patches"] = {}
395
+ to["patches"][name] = to["patches"].get(name, []) + [patch]
396
+
397
    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        # Replace a specific transformer block computation (e.g. attn1/attn2) with *patch*.
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
399
+
400
    # Thin convenience wrappers around set_model_patch / set_model_patch_replace
    # for the standard transformer patch points.
    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def set_model_emb_patch(self, patch):
        self.set_model_patch(patch, "emb_patch")

    def set_model_forward_timestep_embed_patch(self, patch):
        self.set_model_patch(patch, "forward_timestep_embed_patch")
432
+
433
    def add_object_patch(self, name, obj):
        # Replace an attribute on the model (applied in patch_model, restored on unpatch_model).
        self.object_patches[name] = obj

    def set_model_compute_dtype(self, dtype):
        # Force weights to be cast to *dtype* at compute time via a manual_cast_dtype object patch.
        self.add_object_patch("manual_cast_dtype", dtype)
        if dtype is not None:
            self.force_cast_weights = True
        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this

    def add_weight_wrapper(self, name, function):
        # Register a wrapper function applied to the named weight whenever it is used.
        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
        self.patches_uuid = uuid.uuid4()
445
+
446
+ def get_model_object(self, name: str) -> torch.nn.Module:
447
+ """Retrieves a nested attribute from an object using dot notation considering
448
+ object patches.
449
+
450
+ Args:
451
+ name (str): The attribute path using dot notation (e.g. "model.layer.weight")
452
+
453
+ Returns:
454
+ The value of the requested attribute
455
+
456
+ Example:
457
+ patcher = ModelPatcher()
458
+ weight = patcher.get_model_object("layer1.conv.weight")
459
+ """
460
+ if name in self.object_patches:
461
+ return self.object_patches[name]
462
+ else:
463
+ if name in self.object_patches_backup:
464
+ return self.object_patches_backup[name]
465
+ else:
466
+ return comfy.utils.get_attr(self.model, name)
467
+
468
+ def model_patches_to(self, device):
469
+ to = self.model_options["transformer_options"]
470
+ if "patches" in to:
471
+ patches = to["patches"]
472
+ for name in patches:
473
+ patch_list = patches[name]
474
+ for i in range(len(patch_list)):
475
+ if hasattr(patch_list[i], "to"):
476
+ patch_list[i] = patch_list[i].to(device)
477
+ if "patches_replace" in to:
478
+ patches = to["patches_replace"]
479
+ for name in patches:
480
+ patch_list = patches[name]
481
+ for k in patch_list:
482
+ if hasattr(patch_list[k], "to"):
483
+ patch_list[k] = patch_list[k].to(device)
484
+ if "model_function_wrapper" in self.model_options:
485
+ wrap_func = self.model_options["model_function_wrapper"]
486
+ if hasattr(wrap_func, "to"):
487
+ self.model_options["model_function_wrapper"] = wrap_func.to(device)
488
+
489
    def model_dtype(self):
        """Return the model's dtype via get_dtype() when available; otherwise None."""
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()
492
+
493
+ def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
494
+ with self.use_ejected():
495
+ p = set()
496
+ model_sd = self.model.state_dict()
497
+ for k in patches:
498
+ offset = None
499
+ function = None
500
+ if isinstance(k, str):
501
+ key = k
502
+ else:
503
+ offset = k[1]
504
+ key = k[0]
505
+ if len(k) > 2:
506
+ function = k[2]
507
+
508
+ if key in model_sd:
509
+ p.add(k)
510
+ current_patches = self.patches.get(key, [])
511
+ current_patches.append((strength_patch, patches[k], strength_model, offset, function))
512
+ self.patches[key] = current_patches
513
+
514
+ self.patches_uuid = uuid.uuid4()
515
+ return list(p)
516
+
517
    def get_key_patches(self, filter_prefix=None):
        """Return {key: [(unpatched_weight, convert_func), patch, ...]} per model key.

        The first list entry is always the original (unpatched) weight — taken
        from backups when the live weight has been patched in place — followed by
        any pending patches for that key.
        """
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            bk = self.backup.get(k, None)
            hbk = self.hook_backup.get(k, None)
            weight, set_func, convert_func = get_key_weight(self.model, k)
            if bk is not None:
                # prefer the pre-patch backup over the possibly-patched live weight
                weight = bk.weight
            if hbk is not None:
                # hook backups take precedence over regular backups
                weight = hbk[0]
            if convert_func is None:
                convert_func = lambda a, **kwargs: a

            if k in self.patches:
                p[k] = [(weight, convert_func)] + self.patches[k]
            else:
                p[k] = [(weight, convert_func)]
        return p
539
+
540
+ def model_state_dict(self, filter_prefix=None):
541
+ with self.use_ejected():
542
+ sd = self.model.state_dict()
543
+ keys = list(sd.keys())
544
+ if filter_prefix is not None:
545
+ for k in keys:
546
+ if not k.startswith(filter_prefix):
547
+ sd.pop(k)
548
+ return sd
549
+
550
    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        """Apply all pending patches for *key* to the actual model weight.

        The original weight is backed up (on the offload device) the first time a
        key is patched so unpatch_model() can restore it. Math is done in float32
        and the result is rounded back to the original dtype when the weight has
        no custom setter.
        """
        if key not in self.patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        inplace_update = self.weight_inplace_update or inplace_update

        if key not in self.backup:
            # keep an unpatched copy so the weight can be restored later
            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
        if set_func is None:
            # deterministic per-key seed keeps rounding reproducible across runs
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            if inplace_update:
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                comfy.utils.set_attr_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
576
+
577
    def _load_list(self):
        """Return [(module_size, name, module, param_names), ...] for every leaf
        module worth loading; callers sort this to order (partial) loads by size."""
        loading = []
        for n, m in self.model.named_modules():
            params = []
            skip = False
            for name, param in m.named_parameters(recurse=False):
                params.append(name)
            for name, param in m.named_parameters(recurse=True):
                if name not in params:
                    skip = True # skip random weights in non leaf modules
                    break
            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                loading.append((comfy.model_management.module_size(m), n, m, params))
        return loading
591
+
592
    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        """Move model weights to *device_to*, applying weight patches along the way.

        Modules are processed largest-first; once the lowvram_model_memory budget
        is exhausted, remaining castable modules stay in lowvram mode and their
        patches are applied on the fly via LowVramPatch at cast time instead of
        being baked into the weights.
        """
        with self.use_ejected():
            self.unpatch_hooks()
            mem_counter = 0
            patch_counter = 0
            lowvram_counter = 0
            loading = self._load_list()

            load_completely = []
            loading.sort(reverse=True)  # biggest modules first
            for x in loading:
                n = x[1]
                m = x[2]
                params = x[3]
                module_mem = x[0]

                lowvram_weight = False

                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)

                if not full_load and hasattr(m, "comfy_cast_weights"):
                    if mem_counter + module_mem >= lowvram_model_memory:
                        lowvram_weight = True
                        lowvram_counter += 1
                        if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                            continue

                cast_weight = self.force_cast_weights
                if lowvram_weight:
                    # keep the module offloaded: patches are applied at cast time
                    if hasattr(m, "comfy_cast_weights"):
                        m.weight_function = []
                        m.bias_function = []

                    if weight_key in self.patches:
                        if force_patch_weights:
                            self.patch_weight_to_device(weight_key)
                        else:
                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
                            patch_counter += 1
                    if bias_key in self.patches:
                        if force_patch_weights:
                            self.patch_weight_to_device(bias_key)
                        else:
                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
                            patch_counter += 1

                    cast_weight = True
                else:
                    if hasattr(m, "comfy_cast_weights"):
                        wipe_lowvram_weight(m)

                    if full_load or mem_counter + module_mem < lowvram_model_memory:
                        mem_counter += module_mem
                        load_completely.append((module_mem, n, m, params))

                if cast_weight and hasattr(m, "comfy_cast_weights"):
                    m.prev_comfy_cast_weights = m.comfy_cast_weights
                    m.comfy_cast_weights = True

                if weight_key in self.weight_wrapper_patches:
                    m.weight_function.extend(self.weight_wrapper_patches[weight_key])

                if bias_key in self.weight_wrapper_patches:
                    m.bias_function.extend(self.weight_wrapper_patches[bias_key])

                mem_counter += move_weight_functions(m, device_to)

            load_completely.sort(reverse=True)
            for x in load_completely:
                n = x[1]
                m = x[2]
                params = x[3]
                if hasattr(m, "comfy_patched_weights"):
                    if m.comfy_patched_weights == True:
                        continue

                for param in params:
                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)

                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                m.comfy_patched_weights = True

            for x in load_completely:
                x[2].to(device_to)

            if lowvram_counter > 0:
                logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
                self.model.model_lowvram = True
            else:
                logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                self.model.model_lowvram = False
                if full_load:
                    self.model.to(device_to)
                    mem_counter = self.model_size()

            # record bookkeeping used by partial load/unload and clone detection
            self.model.lowvram_patch_counter += patch_counter
            self.model.device = device_to
            self.model.model_loaded_weight_memory = mem_counter
            self.model.current_weight_patches_uuid = self.patches_uuid

            for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
                callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

            self.apply_hooks(self.forced_hooks, force_apply=True)
697
+
698
    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
        """Apply object patches (and optionally weights) to the model, then inject it.

        lowvram_model_memory == 0 means fully load everything; otherwise only that
        much weight memory is fully loaded and the rest stays in lowvram mode.
        Returns the patched model.
        """
        with self.use_ejected():
            for k in self.object_patches:
                old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
                if k not in self.object_patches_backup:
                    # only back up the first pre-patch value
                    self.object_patches_backup[k] = old

            if lowvram_model_memory == 0:
                full_load = True
            else:
                full_load = False

            if load_weights:
                self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
        self.inject_model()
        return self.model
714
+
715
    def unpatch_model(self, device_to=None, unpatch_weights=True):
        """Undo patch_model(): eject injections, restore backed-up weights (when
        *unpatch_weights* is True) and restore object patches."""
        self.eject_model()
        if unpatch_weights:
            self.unpatch_hooks()
            if self.model.model_lowvram:
                for m in self.model.modules():
                    move_weight_functions(m, device_to)
                    wipe_lowvram_weight(m)

                self.model.model_lowvram = False
                self.model.lowvram_patch_counter = 0

            keys = list(self.backup.keys())

            for k in keys:
                bk = self.backup[k]
                if bk.inplace_update:
                    comfy.utils.copy_to_param(self.model, k, bk.weight)
                else:
                    comfy.utils.set_attr_param(self.model, k, bk.weight)

            self.model.current_weight_patches_uuid = None
            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.model.device = device_to
            self.model.model_loaded_weight_memory = 0

            for m in self.model.modules():
                if hasattr(m, "comfy_patched_weights"):
                    del m.comfy_patched_weights

        # object patches are restored regardless of unpatch_weights
        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup.clear()
753
+
754
    def partially_unload(self, device_to, memory_to_free=0):
        """Move the smallest modules to *device_to* until at least *memory_to_free*
        bytes have been freed; returns the amount actually freed.

        Patched weights are restored from backup before moving, and their patches
        are re-registered as LowVramPatch so they still apply at cast time.
        """
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
            unload_list.sort()  # smallest modules first
            for unload in unload_list:
                if memory_to_free < memory_freed:
                    break
                module_mem = unload[0]
                n = unload[1]
                m = unload[2]
                params = unload[3]

                lowvram_possible = hasattr(m, "comfy_cast_weights")
                if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                    move_weight = True
                    for param in params:
                        key = "{}.{}".format(n, param)
                        bk = self.backup.get(key, None)
                        if bk is not None:
                            if not lowvram_possible:
                                # can't lowvram-patch this module: keep it loaded
                                move_weight = False
                                break

                            if not hooks_unpatched:
                                # hooks must be removed before restoring backups
                                self.unpatch_hooks()
                                hooks_unpatched = True

                            if bk.inplace_update:
                                comfy.utils.copy_to_param(self.model, key, bk.weight)
                            else:
                                comfy.utils.set_attr_param(self.model, key, bk.weight)
                            self.backup.pop(key)

                    weight_key = "{}.weight".format(n)
                    bias_key = "{}.bias".format(n)
                    if move_weight:
                        cast_weight = self.force_cast_weights
                        m.to(device_to)
                        module_mem += move_weight_functions(m, device_to)
                        if lowvram_possible:
                            if weight_key in self.patches:
                                m.weight_function.append(LowVramPatch(weight_key, self.patches))
                                patch_counter += 1
                            if bias_key in self.patches:
                                m.bias_function.append(LowVramPatch(bias_key, self.patches))
                                patch_counter += 1
                            cast_weight = True

                        if cast_weight:
                            m.prev_comfy_cast_weights = m.comfy_cast_weights
                            m.comfy_cast_weights = True
                        m.comfy_patched_weights = False
                        memory_freed += module_mem
                        logging.debug("freed {}".format(n))

            self.model.model_lowvram = True
            self.model.lowvram_patch_counter += patch_counter
            self.model.model_loaded_weight_memory -= memory_freed
            return memory_freed
816
+
817
    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
        """Load up to *extra_memory* additional bytes of weights onto *device_to*.

        Reloads weights from scratch when the applied patches are stale
        (patches_uuid mismatch) or force_patch_weights is set. Returns the number
        of bytes of weights that were newly loaded (0 when already fully loaded).
        """
        with self.use_ejected(skip_and_inject_on_exit_only=True):
            unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
            # TODO: force_patch_weights should not unload + reload full model
            used = self.model.model_loaded_weight_memory
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
            if unpatch_weights:
                # budget also covers re-loading what we just unloaded
                extra_memory += (used - self.model.model_loaded_weight_memory)

            self.patch_model(load_weights=False)
            full_load = False
            if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
                # already fully loaded: just make sure forced hooks are applied
                self.apply_hooks(self.forced_hooks, force_apply=True)
                return 0
            if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
                full_load = True
            current_used = self.model.model_loaded_weight_memory
            try:
                self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
            except Exception as e:
                self.detach()
                raise e

            return self.model.model_loaded_weight_memory - current_used
841
+
842
    def detach(self, unpatch_all=True):
        """Eject and offload the model; optionally restore all patched weights.
        Returns the underlying model."""
        self.eject_model()
        self.model_patches_to(self.offload_device)
        if unpatch_all:
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
        for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
            callback(self, unpatch_all)
        return self.model

    def current_loaded_device(self):
        """Return the device the model currently lives on."""
        return self.model.device

    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
        """Deprecated shim: use comfy.lora.calculate_weight instead."""
        logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
        return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

    def cleanup(self):
        """Release per-run state: hooks, the model's current_patcher marker, then
        fire ON_CLEANUP callbacks."""
        self.clean_hooks()
        if hasattr(self.model, "current_patcher"):
            self.model.current_patcher = None
        for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
            callback(self)
864
+
865
    # ---- callback registry -------------------------------------------------
    # Callbacks are stored as {call_type: {key: [callable, ...]}}; key=None is
    # used for callbacks registered without an explicit key.
    def add_callback(self, call_type: str, callback: Callable):
        self.add_callback_with_key(call_type, None, callback)

    def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
        c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
        c.append(callback)

    def remove_callbacks_with_key(self, call_type: str, key: str):
        c = self.callbacks.get(call_type, {})
        if key in c:
            c.pop(key)

    def get_callbacks(self, call_type: str, key: str):
        return self.callbacks.get(call_type, {}).get(key, [])

    def get_all_callbacks(self, call_type: str):
        # flatten callbacks across all keys for this call type
        c_list = []
        for c in self.callbacks.get(call_type, {}).values():
            c_list.extend(c)
        return c_list

    # ---- wrapper registry --------------------------------------------------
    # Same layout as callbacks: {wrapper_type: {key: [callable, ...]}}.
    def add_wrapper(self, wrapper_type: str, wrapper: Callable):
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
        w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
        w.append(wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str):
        w = self.wrappers.get(wrapper_type, {})
        if key in w:
            w.pop(key)

    def get_wrappers(self, wrapper_type: str, key: str):
        return self.wrappers.get(wrapper_type, {}).get(key, [])

    def get_all_wrappers(self, wrapper_type: str):
        # flatten wrappers across all keys for this wrapper type
        w_list = []
        for w in self.wrappers.get(wrapper_type, {}).values():
            w_list.extend(w)
        return w_list
906
+
907
    # ---- attachments: arbitrary objects carried along with the patcher -----
    def set_attachments(self, key: str, attachment):
        self.attachments[key] = attachment

    def remove_attachments(self, key: str):
        if key in self.attachments:
            self.attachments.pop(key)

    def get_attachment(self, key: str):
        return self.attachments.get(key, None)

    # ---- injections: PatcherInjection lists applied in inject_model() ------
    def set_injections(self, key: str, injections: list[PatcherInjection]):
        self.injections[key] = injections

    def remove_injections(self, key: str):
        if key in self.injections:
            self.injections.pop(key)

    def get_injections(self, key: str):
        return self.injections.get(key, None)

    # ---- additional models loaded/cloned together with this one ------------
    def set_additional_models(self, key: str, models: list['ModelPatcher']):
        self.additional_models[key] = models

    def remove_additional_models(self, key: str):
        if key in self.additional_models:
            self.additional_models.pop(key)

    def get_additional_models_with_key(self, key: str):
        return self.additional_models.get(key, [])

    def get_additional_models(self):
        # flatten additional models across all keys
        all_models = []
        for models in self.additional_models.values():
            all_models.extend(models)
        return all_models
942
+
943
+ def get_nested_additional_models(self):
944
+ def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
945
+ '''Make sure circular references do not cause infinite recursion.'''
946
+ next_models = []
947
+ for model in prev_models:
948
+ candidates = model.get_additional_models()
949
+ for c in candidates:
950
+ if c not in cache_set:
951
+ next_models.append(c)
952
+ cache_set.add(c)
953
+ if len(next_models) == 0:
954
+ return prev_models
955
+ return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
956
+
957
+ all_models = self.get_additional_models()
958
+ models_set = set(all_models)
959
+ real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
960
+ return real_all_models
961
+
962
    def use_ejected(self, skip_and_inject_on_exit_only=False):
        """Context manager that ejects the model for the duration of the block and
        restores the injection state on exit."""
        return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
964
+
965
+ def inject_model(self):
966
+ if self.is_injected or self.skip_injection:
967
+ return
968
+ for injections in self.injections.values():
969
+ for inj in injections:
970
+ inj.inject(self)
971
+ self.is_injected = True
972
+ if self.is_injected:
973
+ for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
974
+ callback(self)
975
+
976
+ def eject_model(self):
977
+ if not self.is_injected:
978
+ return
979
+ for injections in self.injections.values():
980
+ for inj in injections:
981
+ inj.eject(self)
982
+ self.is_injected = False
983
+ for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
984
+ callback(self)
985
+
986
    def pre_run(self):
        """Mark this patcher as the model's active patcher and fire ON_PRE_RUN callbacks."""
        if hasattr(self.model, "current_patcher"):
            self.model.current_patcher = self
        for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
            callback(self)

    def prepare_state(self, timestep):
        """Fire ON_PREPARE_STATE callbacks for the given sampling timestep."""
        for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
            callback(self, timestep)
995
+
996
    def restore_hook_patches(self):
        """Restore hook_patches from the backup made in register_all_hook_patches."""
        if self.hook_patches_backup is not None:
            self.hook_patches = self.hook_patches_backup
            self.hook_patches_backup = None

    def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
        # Controls caching behavior of hooked weights (see patch_hooks / MaxSpeed checks).
        self.hook_mode = hook_mode
1003
+
1004
    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
        """Advance hook keyframes to timestep t[0], invalidating cached hook
        weights (and current_hooks) for any hook whose keyframe changed."""
        curr_t = t[0]
        reset_current_hooks = False
        transformer_options = model_options.get("transformer_options", {})
        for hook in hook_group.hooks:
            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
            # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
            # this will cause the weights to be recalculated when sampling
            if changed:
                # reset current_hooks if contains hook that changed
                if self.current_hooks is not None:
                    for current_hook in self.current_hooks.hooks:
                        if current_hook == hook:
                            reset_current_hooks = True
                            break
                for cached_group in list(self.cached_hook_patches.keys()):
                    if cached_group.contains(hook):
                        self.cached_hook_patches.pop(cached_group)
        if reset_current_hooks:
            self.patch_hooks(None)
1024
+
1025
    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
                                  registered: comfy.hooks.HookGroup = None):
        """Register the weight patches of all WeightHooks in *hooks*, backing up
        hook_patches first so non-dynamic hooks can be restored later. Returns
        the HookGroup of all hooks now registered."""
        self.restore_hook_patches()
        if registered is None:
            registered = comfy.hooks.HookGroup()
        # handle WeightHooks
        weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
        for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
            if hook.hook_ref not in self.hook_patches:
                weight_hooks_to_register.append(hook)
            else:
                registered.add(hook)
        if len(weight_hooks_to_register) > 0:
            # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
            self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
            for hook in weight_hooks_to_register:
                hook.add_hook_patches(self, model_options, target_dict, registered)
        for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
            callback(self, hooks, target_dict, model_options, registered)
        return registered
1045
+
1046
    def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
        """Register weight patches under *hook*'s hook_ref (instead of the regular
        patch list). Returns the list of patch keys that matched a model weight."""
        with self.use_ejected():
            # NOTE: this mirrors behavior of add_patches func
            current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
            p = set()
            model_sd = self.model.state_dict()
            for k in patches:
                offset = None
                function = None
                if isinstance(k, str):
                    key = k
                else:
                    offset = k[1]
                    key = k[0]
                    if len(k) > 2:
                        function = k[2]

                if key in model_sd:
                    p.add(k)
                    current_patches: list[tuple] = current_hook_patches.get(key, [])
                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                    current_hook_patches[key] = current_patches
            self.hook_patches[hook.hook_ref] = current_hook_patches
            # since should care about these patches too to determine if same model, reroll patches_uuid
            self.patches_uuid = uuid.uuid4()
            return list(p)
1072
+
1073
+ def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
1074
+ # combined_patches will contain weights of all relevant hooks, per key
1075
+ combined_patches = {}
1076
+ if hooks is not None:
1077
+ for hook in hooks.hooks:
1078
+ hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
1079
+ for key in hook_patches.keys():
1080
+ current_patches: list[tuple] = combined_patches.get(key, [])
1081
+ if math.isclose(hook.strength, 1.0):
1082
+ current_patches.extend(hook_patches[key])
1083
+ else:
1084
+ # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
1085
+ for patch in hook_patches[key]:
1086
+ new_patch = list(patch)
1087
+ new_patch[0] *= hook.strength
1088
+ current_patches.append(tuple(new_patch))
1089
+ combined_patches[key] = current_patches
1090
+ return combined_patches
1091
+
1092
    def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
        """Patch hook weights into the model (if different from current) and return
        the transformer_options produced for *hooks*."""
        # TODO: return transformer_options dict with any additions from hooks
        if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
            # already applied: skip repatching
            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
        self.patch_hooks(hooks=hooks)
        for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
            callback(self, hooks)
        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
1100
+
1101
    def patch_hooks(self, hooks: comfy.hooks.HookGroup):
        """Apply the weight patches of *hooks* to the model, using cached hooked
        weights when available (MaxSpeed mode); hooks=None removes hook patches."""
        with self.use_ejected():
            if hooks is not None:
                model_sd_keys = list(self.model_state_dict().keys())
                memory_counter = None
                if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                    # TODO: minimum_counter should have a minimum that conforms to loaded model requirements
                    memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
                                                   minimum=comfy.model_management.minimum_inference_memory()*2)
                # if have cached weights for hooks, use it
                cached_weights = self.cached_hook_patches.get(hooks, None)
                if cached_weights is not None:
                    model_sd_keys_set = set(model_sd_keys)
                    for key in cached_weights:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
                        model_sd_keys_set.remove(key)
                    # restore any keys the cache did not cover
                    self.unpatch_hooks(model_sd_keys_set)
                else:
                    self.unpatch_hooks()
                    relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                    original_weights = None
                    if len(relevant_patches) > 0:
                        original_weights = self.get_key_patches()
                    for key in relevant_patches:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                         memory_counter=memory_counter)
            else:
                self.unpatch_hooks()
            self.current_hooks = hooks
1136
+
1137
    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
        """Swap in a previously computed hooked weight for *key*, backing up the
        current weight first (kept on-device when the memory budget allows)."""
        if key not in self.hook_backup:
            weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                used = memory_counter.use(weight)
                if used:
                    # enough free memory: keep the backup on the weight's own device
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))

    def clear_cached_hook_weights(self):
        """Drop all cached hooked weights and remove any applied hook patches."""
        self.cached_hook_patches.clear()
        self.patch_hooks(None)
1151
+
1152
    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
        """Compute and apply the hooked weight for *key* from *combined_patches*,
        backing up the original and (in MaxSpeed mode) caching the result."""
        if key not in combined_patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        weight: torch.Tensor
        if key not in self.hook_backup:
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                used = memory_counter.use(weight)
                if used:
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        # TODO: properly handle LowVramPatch, if it ends up an issue
        temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(combined_patches[key],
                                                 temp_weight,
                                                 key, original_weights=original_weights)
        del original_weights[key]
        if set_func is None:
            # deterministic per-key seed keeps rounding reproducible
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
        if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
            # TODO: disable caching if not enough system RAM to do so
            target_device = self.offload_device
            used = memory_counter.use(weight)
            if used:
                target_device = weight.device
            self.cached_hook_patches.setdefault(hooks, {})
            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
        del temp_weight
        del out_weight
        del weight
1190
+
1191
    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
        """Restore hooked weights from hook_backup.

        With *whitelist_keys_set*, only those keys are restored (and popped from
        the backup) and current_hooks is left alone; otherwise everything is
        restored and current_hooks is reset to None.
        """
        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
            keys = list(self.hook_backup.keys())
            if whitelist_keys_set:
                for k in keys:
                    if k in whitelist_keys_set:
                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
                        self.hook_backup.pop(k)
            else:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

                self.hook_backup.clear()
                self.current_hooks = None
1208
+
1209
    def clean_hooks(self):
        """Fully remove hook weight patches and drop any cached hooked weights."""
        self.unpatch_hooks()
        self.clear_cached_hook_weights()

    def __del__(self):
        # best-effort cleanup on garbage collection; weights are left as-is
        # (unpatch_all=False), only ejection/offloading is performed
        self.detach(unpatch_all=False)
1215
+
ComfyUI/comfy/ops.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import logging
21
+ import comfy.model_management
22
+ from comfy.cli_args import args, PerformanceFeature
23
+ import comfy.float
24
+ import comfy.rmsnorm
25
+ import contextlib
26
+
27
+ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
28
+
29
def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast *weight* to the dtype and device of *input* (convenience wrapper)."""
    return comfy.model_management.cast_to(
        weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
31
+
32
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Cast a module's weight and bias to the requested dtype/device.

    When *input* is given, missing dtype/bias_dtype/device default to the
    input's dtype and device. Any registered weight_function/bias_function
    transforms are applied after casting (on the offload stream when one
    exists). Returns (weight, bias); bias is None when the module has none.
    """
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    # Casts may run on a dedicated offload stream; transforms run in the same
    # context so they observe the casted tensors.
    offload_stream = comfy.model_management.get_offload_stream(device)
    if offload_stream is not None:
        wf_context = offload_stream
    else:
        wf_context = contextlib.nullcontext()

    bias = None
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    if s.bias is not None:
        # copy only when a transform will mutate the result
        has_function = len(s.bias_function) > 0
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

        if has_function:
            with wf_context:
                for f in s.bias_function:
                    bias = f(bias)

    has_function = len(s.weight_function) > 0
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
    if has_function:
        with wf_context:
            for f in s.weight_function:
                weight = f(weight)

    # Ensure the casts/transforms are visible to the caller's stream.
    comfy.model_management.sync_stream(device, offload_stream)
    return weight, bias
67
+
68
class CastWeightBiasOp:
    """Mixin flags/hooks shared by the cast-on-forward op classes below."""
    # When True, forward() routes through forward_comfy_cast_weights.
    comfy_cast_weights = False
    # NOTE(review): these are class-level (shared) lists; code elsewhere appears
    # to replace them per-instance rather than mutate in place — confirm before
    # appending to them directly.
    weight_function = []
    bias_function = []
72
+
73
class disable_weight_init:
    """Namespace of torch.nn layer subclasses with weight init disabled.

    Each subclass overrides reset_parameters() to a no-op (skipping expensive
    random initialization for weights that will be loaded from a checkpoint)
    and gains a forward_comfy_cast_weights() path that casts weight/bias to the
    input's dtype/device on the fly. forward() dispatches to that path only when
    comfy_cast_weights is set or weight/bias transform functions are registered.
    """
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            # Skip default init; weights come from a state dict.
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # elementwise_affine=False leaves weight as None; norm still applies.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            # RMSNorm has no bias; define it so cast_bias_weight can read it.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            # Embeddings have no bias; define it so cast_bias_weight can read it.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # fp16/bf16 weights can be looked up directly; only the final
            # result is converted to the requested output dtype.
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                # plain torch.nn.Embedding does not accept out_dtype
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Build a Conv2d or Conv3d from this namespace; other dims are unsupported."""
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
251
+
252
+
253
class manual_cast(disable_weight_init):
    """Op namespace identical to disable_weight_init, but with the on-the-fly
    weight/bias casting path (forward_comfy_cast_weights) always enabled."""
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True
283
+
284
+
285
def fp8_linear(self, input):
    """Attempt an fp8 matmul (torch._scaled_mm) for a Linear-like module.

    Returns the output tensor, or None when the fast path does not apply
    (non-fp8 weight, or input rank other than 2/3) so the caller can fall
    back to the regular linear path.
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    # _scaled_mm works on 2D operands; a 2D input is temporarily given a
    # middle dim and the result is reshaped back at the end.
    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    input_shape = input.shape
    input_dtype = input.dtype
    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
        w = w.t()

        # Missing scales default to 1.0 (no rescaling).
        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            # 448 is the max finite value representable in float8_e4m3fn.
            input = torch.clamp(input, min=-448, max=448, out=input)
            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
            # De-scale the input before converting to fp8.
            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

        if bias is not None:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

        # Older torch versions return (out, amax).
        if isinstance(o, tuple):
            o = o[0]

        if tensor_2d:
            return o.reshape(input_shape[0], -1)

        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    return None
330
+
331
class fp8_ops(manual_cast):
    """Op namespace whose Linear first tries the fp8 fast path, then falls back
    to the regular cast-and-linear path on any failure."""
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            # Scales are absent unless loaded from a scaled-fp8 checkpoint.
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            try:
                out = fp8_linear(self, input)
                if out is not None:
                    return out
            except Exception as e:
                # Best-effort: log and fall through to the safe path.
                logging.info("Exception during fp8 op: {}".format(e))

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)
348
+
349
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    """Build an op namespace for checkpoints that store fp8 weights with a scale.

    fp8_matrix_mult enables the torch._scaled_mm fast path; scale_input keeps a
    learned input scale; override_dtype forces the Linear parameter dtype.
    Returns the generated class (a manual_cast subclass).
    """
    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                # Create scale parameters only if a loaded state dict has not
                # already provided them.
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)

                # Apply the scale on the smaller operand to save work.
                if weight.numel() < input.numel(): #TODO: optimize
                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                else:
                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                """Return the de-scaled (full-magnitude) weight for patching."""
                if inplace:
                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                    return weight
                else:
                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                """Re-scale and stochastically round a patched weight back into storage."""
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op
397
+
398
# Optional dependency: cublas_ops provides a faster fp16 Linear. The namespace
# below only exists when the import succeeds; pick_operations checks the flag
# before referencing it.
CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            # Both paths delegate to CublasLinear's forward (MRO: CublasLinear first).
            def forward_comfy_cast_weights(self, input):
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
417
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    """Select the op namespace best suited to the given dtypes and device.

    Precedence: scaled-fp8 checkpoint > fast fp8 matmul > cublas fp16 > plain
    ops; falls back to manual_cast whenever weight and compute dtypes differ.
    """
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)

    # Checkpoints shipping a weight scale always use the scaled-fp8 op set.
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations,
                              scale_input=fp8_optimizations,
                              override_dtype=scaled_fp8)

    fast_fp8_wanted = fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast
    if fp8_compute and fast_fp8_wanted and not disable_fast_fp8:
        return fp8_ops

    cublas_ok = (PerformanceFeature.CublasOps in args.fast
                 and CUBLAS_IS_AVAILABLE
                 and weight_dtype == torch.float16
                 and (compute_dtype == torch.float16 or compute_dtype is None))
    if cublas_ok:
        logging.info("Using cublas ops")
        return cublas_ops

    # No runtime casting needed when dtypes already agree (or none requested).
    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
ComfyUI/comfy/patcher_extension.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Callable
3
+
4
class CallbacksMP:
    """Callback type constants for ModelPatcher lifecycle events."""
    ON_CLONE = "on_clone"
    ON_LOAD = "on_load_after"
    ON_DETACH = "on_detach_after"
    ON_CLEANUP = "on_cleanup"
    ON_PRE_RUN = "on_pre_run"
    ON_PREPARE_STATE = "on_prepare_state"
    ON_APPLY_HOOKS = "on_apply_hooks"
    ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
    ON_INJECT_MODEL = "on_inject_model"
    ON_EJECT_MODEL = "on_eject_model"

    # callbacks dict is in the format:
    # {"call_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a fresh, empty callbacks registry."""
        return {}
21
+
22
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Register *callback* for *call_type* under the default (None) key."""
    add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)

def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Append *callback* to callbacks[call_type][key] inside transformer_options.

    When is_model_options is True, *transformer_options* is treated as a
    model_options dict and its nested "transformer_options" entry is used
    (created if absent).
    """
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    registry = transformer_options.setdefault("callbacks", {})
    registry.setdefault(call_type, {}).setdefault(key, []).append(callback)
31
+
32
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a fresh list of the callbacks registered under call_type/key (possibly empty)."""
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry = transformer_options.get("callbacks", {})
    return list(registry.get(call_type, {}).get(key, []))
39
+
40
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
    """Return a fresh list of every callback registered for *call_type*, across all keys."""
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry = transformer_options.get("callbacks", {})
    collected = []
    for group in registry.get(call_type, {}).values():
        collected.extend(group)
    return collected
48
+
49
class WrappersMP:
    """Wrapper type constants for ModelPatcher sampling stages."""
    OUTER_SAMPLE = "outer_sample"
    PREPARE_SAMPLING = "prepare_sampling"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"

    # wrappers dict is in the format:
    # {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a fresh, empty wrappers registry."""
        return {}
62
+
63
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Register *wrapper* for *wrapper_type* under the default (None) key."""
    add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)

def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Append *wrapper* to wrappers[wrapper_type][key] inside transformer_options.

    When is_model_options is True, the nested "transformer_options" entry of
    the given model_options dict is used (created if absent).
    """
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    registry = transformer_options.setdefault("wrappers", {})
    registry.setdefault(wrapper_type, {}).setdefault(key, []).append(wrapper)
72
+
73
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a fresh list of the wrappers registered under wrapper_type/key (possibly empty)."""
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry = transformer_options.get("wrappers", {})
    return list(registry.get(wrapper_type, {}).get(key, []))
80
+
81
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
    """Return a fresh list of every wrapper registered for *wrapper_type*, across all keys."""
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry = transformer_options.get("wrappers", {})
    collected = []
    for group in registry.get(wrapper_type, {}).values():
        collected.extend(group)
    return collected
89
+
90
class WrapperExecutor:
    """Runs a chain of wrappers around a callable, outermost wrapper first."""

    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
        # class_obj lets wrappers around a bound method reach the instance at
        # runtime through executor.class_obj.
        self.original = original
        self.class_obj = class_obj
        self.wrappers = wrappers.copy()
        self.idx = idx
        self.is_last = idx == len(wrappers)

    def __call__(self, *args, **kwargs):
        """Advance to the next wrapper (or the original callable) in the chain."""
        return self._create_next_executor().execute(*args, **kwargs)

    def execute(self, *args, **kwargs):
        """Kick off execution at the current index — internal use only; wrappers
        should call the executor object itself instead."""
        args = list(args)
        kwargs = dict(kwargs)
        if not self.is_last:
            return self.wrappers[self.idx](self, *args, **kwargs)
        return self.original(*args, **kwargs)

    def _create_next_executor(self) -> 'WrapperExecutor':
        next_idx = self.idx + 1
        if next_idx > len(self.wrappers):
            raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
        if self.class_obj is not None:
            return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, next_idx)
        return WrapperExecutor.new_executor(self.original, self.wrappers, next_idx)

    @classmethod
    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
        """Executor for a plain function."""
        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)

    @classmethod
    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
        """Executor for a method, carrying its instance as class_obj."""
        return cls(original, class_obj, wrappers, idx=idx)
129
+
130
class PatcherInjection:
    """Pairs an inject callable with the eject callable that reverses it."""

    def __init__(self, inject: Callable, eject: Callable):
        self.eject = eject
        self.inject = inject
134
+
135
def copy_nested_dicts(input_dict: dict):
    """Copy *input_dict*, recursing into dict values and shallow-copying lists.

    Non-dict, non-list values are shared with the original.
    """
    duplicated = {}
    for key, value in input_dict.items():
        if isinstance(value, dict):
            duplicated[key] = copy_nested_dicts(value)
        elif isinstance(value, list):
            duplicated[key] = value.copy()
        else:
            duplicated[key] = value
    return duplicated

def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    """Merge *dict2* into (a copy of) *dict1* and return the result.

    Lists are concatenated (dict1's elements first). Top-level scalars from
    dict2 override dict1's; for nested dicts the recursion uses dict2's subtree
    as the base, so dict1's nested leaf values win on conflict.
    """
    base = copy_nested_dicts(dict1) if copy_dict1 else dict1
    for key, value in dict2.items():
        if isinstance(value, dict):
            existing = base.setdefault(key, {})
            base[key] = merge_nested_dicts(value, existing)
        elif isinstance(value, list):
            base.setdefault(key, []).extend(value)
        else:
            base[key] = value
    return base
ComfyUI/comfy/sample.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_management
3
+ import comfy.samplers
4
+ import comfy.utils
5
+ import numpy as np
6
+ import logging
7
+
8
def prepare_noise(latent_image, seed, noise_inds=None):
    """Create reproducible CPU sampling noise shaped like *latent_image* from *seed*.

    When *noise_inds* is given, each entry selects the i-th single-sample noise
    draw for that seed (duplicates reuse the same draw), and the selected draws
    are concatenated along the batch dimension.
    """
    generator = torch.manual_seed(seed)
    if noise_inds is None:
        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")

    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    kept = []
    sample_shape = [1] + list(latent_image.size())[1:]
    # Draw sequentially up to the largest requested index so the i-th draw is
    # deterministic for a given seed; keep only the requested ones.
    for idx in range(unique_inds[-1] + 1):
        draw = torch.randn(sample_shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
        if idx in unique_inds:
            kept.append(draw)
    reordered = [kept[i] for i in inverse]
    return torch.cat(reordered, axis=0)
26
+
27
def fix_empty_latent_channels(model, latent_image):
    """Adapt an empty (all-zero) latent to the model's expected channels/dims."""
    latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
    channels_mismatch = latent_format.latent_channels != latent_image.shape[1]
    # Only an all-zero latent is safe to reshape; real latents are left alone.
    if channels_mismatch and torch.count_nonzero(latent_image) == 0:
        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
    # Video-style (3D) latent formats want a temporal axis: (B,C,H,W) -> (B,C,1,H,W).
    if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
        latent_image = latent_image.unsqueeze(2)
    return latent_image
34
+
35
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
    """Deprecated pass-through kept for API compatibility; returns inputs unchanged
    plus an empty additional-models list."""
    logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
    return model, positive, negative, noise_mask, []
38
+
39
def cleanup_additional_models(models):
    """Deprecated no-op kept for API compatibility; only logs a removal warning."""
    logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
41
+
42
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
    """Run a named KSampler over *latent_image* and return samples moved to the
    intermediate device.

    NOTE(review): disable_noise is accepted but unused in this body — presumably
    kept for caller compatibility; confirm.
    """
    ksampler = comfy.samplers.KSampler(
        model,
        steps=steps,
        device=model.load_device,
        sampler=sampler_name,
        scheduler=scheduler,
        denoise=denoise,
        model_options=model.model_options,
    )
    result = ksampler.sample(
        noise, positive, negative, cfg=cfg, latent_image=latent_image,
        start_step=start_step, last_step=last_step,
        force_full_denoise=force_full_denoise, denoise_mask=noise_mask,
        sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed,
    )
    return result.to(comfy.model_management.intermediate_device())
48
+
49
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sample with an explicit sampler object and sigma schedule; returns samples
    moved to the intermediate device."""
    result = comfy.samplers.sample(
        model, noise, positive, negative, cfg, model.load_device, sampler, sigmas,
        model_options=model.model_options, latent_image=latent_image,
        denoise_mask=noise_mask, callback=callback,
        disable_pbar=disable_pbar, seed=seed,
    )
    return result.to(comfy.model_management.intermediate_device())
ComfyUI/comfy/samplers.py ADDED
@@ -0,0 +1,1143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from .k_diffusion import sampling as k_diffusion_sampling
3
+ from .extra_samplers import uni_pc
4
+ from typing import TYPE_CHECKING, Callable, NamedTuple
5
+ if TYPE_CHECKING:
6
+ from comfy.model_patcher import ModelPatcher
7
+ from comfy.model_base import BaseModel
8
+ from comfy.controlnet import ControlBase
9
+ import torch
10
+ from functools import partial
11
+ import collections
12
+ from comfy import model_management
13
+ import math
14
+ import logging
15
+ import comfy.sampler_helpers
16
+ import comfy.model_patcher
17
+ import comfy.patcher_extension
18
+ import comfy.hooks
19
+ import scipy.stats
20
+ import numpy
21
+
22
+
23
def add_area_dims(area, num_dims):
    """Left-pad an area spec until it covers *num_dims* dimensions.

    An area is [size_0..size_{k-1}, offset_0..offset_{k-1}]; each missing
    leading dimension gets size 2147483648 (effectively unbounded) and offset 0.
    Returns a new list when padding occurs; the input list is not mutated.
    """
    while (len(area) // 2) < num_dims:
        half = len(area) // 2
        area = [2147483648] + area[:half] + [0] + area[half:]
    return area
27
+
28
def get_area_and_mult(conds, x_in, timestep_in):
    """Resolve one cond dict into a cond_obj namedtuple for the current step.

    Returns None when the cond is scheduled outside the current sigma range.
    Otherwise returns (input_x, mult, conditioning, area, control, patches,
    uuid, hooks) where input_x is the (possibly area-cropped) latent and mult
    is the blend weight (mask * strength, feathered at area edges).
    """
    dims = tuple(x_in.shape[2:])
    area = None
    strength = 1.0

    # Sigmas decrease over sampling, so "start" is the larger sigma bound.
    if 'timestep_start' in conds:
        timestep_start = conds['timestep_start']
        if timestep_in[0] > timestep_start:
            return None
    if 'timestep_end' in conds:
        timestep_end = conds['timestep_end']
        if timestep_in[0] < timestep_end:
            return None
    if 'area' in conds:
        area = list(conds['area'])
        area = add_area_dims(area, len(dims))
        if (len(area) // 2) > len(dims):
            # Spec has more dims than the latent: keep only the trailing ones.
            area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]

    if 'strength' in conds:
        strength = conds['strength']

    input_x = x_in
    if area is not None:
        for i in range(len(dims)):
            # Clamp the size so the crop stays inside the latent, then crop.
            area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
            input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])

    if 'mask' in conds:
        # Scale the mask to the size of the input.
        # The mask should have been resized as we began the sampling process.
        mask_strength = 1.0
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
        assert (mask.shape[1:] == x_in.shape[2:])

        mask = mask[:input_x.shape[0]]
        if area is not None:
            for i in range(len(dims)):
                mask = mask.narrow(i + 1, area[len(dims) + i], area[i])

        mask = mask * mask_strength
        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
    else:
        mask = torch.ones_like(input_x)
    mult = mask * strength

    if 'mask' not in conds and area is not None:
        # Feather the area edges with a linear ramp (up to 8 px) so adjacent
        # area conds blend smoothly instead of leaving seams.
        fuzz = 8
        for i in range(len(dims)):
            rr = min(fuzz, mult.shape[2 + i] // 4)
            if area[len(dims) + i] != 0:
                for t in range(rr):
                    m = mult.narrow(i + 2, t, 1)
                    m *= ((1.0 / rr) * (t + 1))
            if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
                for t in range(rr):
                    m = mult.narrow(i + 2, area[i] - 1 - t, 1)
                    m *= ((1.0 / rr) * (t + 1))

    conditioning = {}
    model_conds = conds["model_conds"]
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)

    hooks = conds.get('hooks', None)
    control = conds.get('control', None)

    patches = None
    if 'gligen' in conds:
        gligen = conds['gligen']
        patches = {}
        gligen_type = gligen[0]
        gligen_model = gligen[1]
        if gligen_type == "position":
            gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
        else:
            gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)

        patches['middle_patch'] = [gligen_patch]

    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
    return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
112
+
113
def cond_equal_size(c1, c2):
    """True when two conditioning dicts are batchable: same keys and every value pair can_concat."""
    if c1 is c2:
        return True
    if c1.keys() != c2.keys():
        return False
    return all(c1[key].can_concat(c2[key]) for key in c1)
122
+
123
def can_concat_cond(c1, c2):
    """Decide whether two cond_obj entries can run in the same model batch.

    Requires identical latent shapes, identical (or both-absent) control and
    patches objects, and pairwise-concatable conditioning dicts.
    """
    if c1.input_x.shape != c2.input_x.shape:
        return False

    def objects_concatable(obj1, obj2):
        # Both must be None, or the exact same object.
        if (obj1 is None) != (obj2 is None):
            return False
        return obj1 is None or obj1 is obj2

    if not objects_concatable(c1.control, c2.control):
        return False
    if not objects_concatable(c1.patches, c2.patches):
        return False

    # Conditioning dicts are batchable when keys match and values can_concat.
    cond1 = c1.conditioning
    cond2 = c2.conditioning
    if cond1 is cond2:
        return True
    if cond1.keys() != cond2.keys():
        return False
    return all(cond1[key].can_concat(cond2[key]) for key in cond1)
142
+
143
def cond_cat(c_list):
    """Merge a list of conditioning dicts: group values by key, then concat
    each group onto its first element via the values' .concat() method."""
    grouped = collections.defaultdict(list)
    for cond in c_list:
        for key, value in cond.items():
            grouped[key].append(value)
    return {key: values[0].concat(values[1:]) for key, values in grouped.items()}
157
+
158
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
    """Add 'default' conds to hooked_to_run, weighted so that they only cover the
    area left unclaimed by the regular (masked/area) conds of the same cond type.

    Fix: the inner narrow loop reused `i` as its index, shadowing the cond-type
    index from the enclosing loop; renamed to `d` to remove the hazard.
    """
    # need to figure out remaining unmasked area for conds
    default_mults = []
    for _ in default_conds:
        default_mults.append(torch.ones_like(x_in))
    # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
    for lora_hooks, to_run in hooked_to_run.items():
        for cond_obj, i in to_run:
            # if no default_cond for cond_type, do nothing
            if len(default_conds[i]) == 0:
                continue
            area: list[int] = cond_obj.area
            if area is not None:
                curr_default_mult: torch.Tensor = default_mults[i]
                dims = len(area) // 2
                for d in range(dims):  # 'd' (was 'i') so the outer cond-type index is not shadowed
                    curr_default_mult = curr_default_mult.narrow(d + 2, area[d + dims], area[d])
                curr_default_mult -= cond_obj.mult
            else:
                default_mults[i] -= cond_obj.mult
    # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
    for i, mult in enumerate(default_mults):
        # if no default_cond for cond type, do nothing
        if len(default_conds[i]) == 0:
            continue
        torch.nn.functional.relu(mult, inplace=True)
        # if mult is all zeros, then don't add default_cond
        if torch.max(mult) == 0.0:
            continue

        cond = default_conds[i]
        for x in cond:
            # do get_area_and_mult to get all the expected values
            p = get_area_and_mult(x, x_in, timestep)
            if p is None:
                continue
            # replace p's mult with calculated mult
            p = p._replace(mult=mult)
            if p.hooks is not None:
                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
            hooked_to_run.setdefault(p.hooks, list())
            hooked_to_run[p.hooks] += [(p, i)]
200
+
201
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    """Public entry point for batched cond evaluation.

    Wraps _calc_cond_batch with any CALC_COND_BATCH wrappers registered in
    model_options, then executes it.
    """
    wrappers = comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(_calc_cond_batch, wrappers)
    return executor.execute(model, conds, x_in, timestep, model_options)
207
+
208
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    """Evaluate every active cond in `conds` (a list of cond lists, e.g.
    [positive, negative]) and return one blended model output per cond type.

    Conds are grouped by hook group, greedily batched together (bounded by free
    memory), run through the model, and accumulated into per-type output/weight
    buffers which are normalized at the end.

    Fix: removed a dead per-candidate dict build (`cond = {k: v.size() ...}`)
    inside the memory-estimation loop; its result was never used.
    """
    out_conds = []
    out_counts = []
    # separate conds by matching hooks
    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
    default_conds = []
    has_default_conds = False

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        # Tiny epsilon so uncovered regions never divide by zero.
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        default_c = []
        if cond is not None:
            for x in cond:
                if 'default' in x:
                    # Deferred: handled by finalize_default_conds below.
                    default_c.append(x)
                    has_default_conds = True
                    continue
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue
                if p.hooks is not None:
                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
                hooked_to_run.setdefault(p.hooks, list())
                hooked_to_run[p.hooks] += [(p, i)]
        default_conds.append(default_c)

    if has_default_conds:
        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

    model.current_patcher.prepare_state(timestep)

    # run every hooked_to_run separately
    for hooks, to_run in hooked_to_run.items():
        while len(to_run) > 0:
            first = to_run[0]
            first_shape = first[0][0].shape
            to_batch_temp = []
            for x in range(len(to_run)):
                if can_concat_cond(to_run[x][0], first[0]):
                    to_batch_temp += [x]

            to_batch_temp.reverse()
            to_batch = to_batch_temp[:1]

            # Find the largest batch that fits in free memory (with 1.5x headroom).
            free_memory = model_management.get_free_memory(x_in.device)
            for i in range(1, len(to_batch_temp) + 1):
                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
                cond_shapes = collections.defaultdict(list)
                for tt in batch_amount:
                    for k, v in to_run[tt][0].conditioning.items():
                        cond_shapes[k].append(v.size())

                if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                    to_batch = batch_amount
                    break

            input_x = []
            mult = []
            c = []
            cond_or_uncond = []
            uuids = []
            area = []
            control = None
            patches = None
            for x in to_batch:
                o = to_run.pop(x)
                p = o[0]
                input_x.append(p.input_x)
                mult.append(p.mult)
                c.append(p.conditioning)
                area.append(p.area)
                cond_or_uncond.append(o[1])
                uuids.append(p.uuid)
                control = p.control
                patches = p.patches

            batch_chunks = len(cond_or_uncond)
            input_x = torch.cat(input_x)
            c = cond_cat(c)
            timestep_ = torch.cat([timestep] * batch_chunks)

            transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
            if 'transformer_options' in model_options:
                transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
                                                                                 model_options['transformer_options'],
                                                                                 copy_dict1=False)

            if patches is not None:
                # TODO: replace with merge_nested_dicts function
                if "patches" in transformer_options:
                    cur_patches = transformer_options["patches"].copy()
                    for p in patches:
                        if p in cur_patches:
                            cur_patches[p] = cur_patches[p] + patches[p]
                        else:
                            cur_patches[p] = patches[p]
                    transformer_options["patches"] = cur_patches
                else:
                    transformer_options["patches"] = patches

            transformer_options["cond_or_uncond"] = cond_or_uncond[:]
            transformer_options["uuids"] = uuids[:]
            transformer_options["sigmas"] = timestep

            c['transformer_options'] = transformer_options

            if control is not None:
                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

            if 'model_function_wrapper' in model_options:
                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
            else:
                output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

            # Scatter each chunk back into its cond type's accumulation buffer.
            for o in range(batch_chunks):
                cond_index = cond_or_uncond[o]
                a = area[o]
                if a is None:
                    out_conds[cond_index] += output[o] * mult[o]
                    out_counts[cond_index] += mult[o]
                else:
                    out_c = out_conds[cond_index]
                    out_cts = out_counts[cond_index]
                    dims = len(a) // 2
                    for i in range(dims):
                        out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                        out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                    out_c += output[o] * mult[o]
                    out_cts += mult[o]

    # Normalize by accumulated weights.
    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds
347
+
348
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
    """Deprecated shim: forwards to calc_cond_batch and returns (cond_out, uncond_out)."""
    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
    outputs = calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)
    return tuple(outputs)
351
+
352
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
    """Combine cond/uncond denoised predictions via classifier-free guidance.

    Honors a custom "sampler_cfg_function" (which works on noise predictions,
    hence the x - pred conversions), then runs any "sampler_post_cfg_function"
    hooks on the denoised result.
    """
    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        # Standard CFG: uncond + (cond - uncond) * scale.
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    return cfg_result
366
+
367
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
    """Run the cond (and, unless CFG==1, uncond) model passes and combine them with CFG."""
    # When cfg is exactly 1.0 the uncond pass has no effect, so skip it
    # entirely unless the optimization is explicitly disabled.
    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond

    conds = [cond, uncond_]
    if "sampler_calc_cond_batch_function" in model_options:
        args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
        out = model_options["sampler_calc_cond_batch_function"](args)
    else:
        out = calc_cond_batch(model, conds, x, timestep, model_options)

    for fn in model_options.get("sampler_pre_cfg_function", []):
        args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
                "input": x, "sigma": timestep, "model": model, "model_options": model_options}
        out = fn(args)

    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
388
+
389
+
390
class KSamplerX0Inpaint:
    """Denoiser wrapper that implements inpainting via a denoise mask.

    When a denoise_mask is given, the masked-out (preserved) region of x is
    re-filled each step from the appropriately-noised latent_image, and the
    final output is blended the same way. The `latent_image` and `noise`
    attributes are assigned externally (see KSAMPLER.sample) before calling.
    """
    def __init__(self, model, sigmas):
        self.inner_model = model
        self.sigmas = sigmas

    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
        if denoise_mask is not None:
            if "denoise_mask_function" in model_options:
                denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
            latent_mask = 1. - denoise_mask
            # Pin the preserved region to the noised original latent.
            x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        if denoise_mask is not None:
            out = out * denoise_mask + self.latent_image * latent_mask
        return out
404
+
405
def simple_scheduler(model_sampling, steps):
    """Evenly stride through the model's sigma table from high to low, appending a final 0."""
    sigmas = model_sampling.sigmas
    stride = len(sigmas) / steps
    sigs = [float(sigmas[-(1 + int(step * stride))]) for step in range(steps)]
    sigs.append(0.0)
    return torch.FloatTensor(sigs)
413
+
414
def ddim_scheduler(model_sampling, steps):
    """DDIM-uniform schedule: every (len/steps)-th sigma, returned high to low."""
    s = model_sampling
    x = 1
    # If sigmas[1] is already ~0 the table ends at zero: take one extra
    # stride instead of appending an explicit 0.
    if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
        steps += 1
        sigs = []
    else:
        sigs = [0.0]

    stride = max(len(s.sigmas) // steps, 1)
    while x < len(s.sigmas):
        sigs.append(float(s.sigmas[x]))
        x += stride
    return torch.FloatTensor(sigs[::-1])
430
+
431
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
    """Sigmas spaced uniformly in timestep space between sigma_max and sigma_min.

    sgm=True uses the SGM convention: steps+1 linspace points with the last
    dropped and no trailing 0 appended here.
    """
    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)

    append_zero = True
    if sgm:
        timesteps = torch.linspace(start, end, steps + 1)[:-1]
    else:
        if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
            # sigma_min is effectively 0: include it as the final step
            # rather than appending an extra 0.
            steps += 1
            append_zero = False
        timesteps = torch.linspace(start, end, steps)

    sigs = [float(s.sigma(ts)) for ts in timesteps]
    if append_zero:
        sigs.append(0.0)
    return torch.FloatTensor(sigs)
454
+
455
# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
    """Beta-distribution timestep schedule; consecutive duplicate indices are collapsed."""
    total_timesteps = (len(model_sampling.sigmas) - 1)
    ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
    # Map uniform fractions through the Beta(alpha, beta) quantile function.
    ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)

    sigs = []
    last_t = -1
    for t in ts:
        if t != last_t:
            sigs.append(float(model_sampling.sigmas[int(t)]))
        last_t = t
    sigs.append(0.0)
    return torch.FloatTensor(sigs)
469
+
470
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
    """Mochi-style schedule: linear ramp up to threshold_noise, then a quadratic
    tail, mirrored to descend from sigma_max."""
    if steps == 1:
        sigma_schedule = [1.0, 0.0]
    else:
        if linear_steps is None:
            linear_steps = steps // 2
        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
        threshold_noise_step_diff = linear_steps - threshold_noise * steps
        quadratic_steps = steps - linear_steps
        # Coefficients chosen so the quadratic segment continues the linear one
        # smoothly and reaches 1.0 at the final step.
        quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
        linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
        const = quadratic_coef * (linear_steps ** 2)
        quadratic_sigma_schedule = [
            quadratic_coef * (i ** 2) + linear_coef * i + const
            for i in range(linear_steps, steps)
        ]
        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
        sigma_schedule = [1.0 - x for x in sigma_schedule]
    return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
490
+
491
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
    """KL-optimal schedule: interpolate linearly between atan(sigma_max) and
    atan(sigma_min), map back through tan; the final entry is 0."""
    fractions = torch.arange(n, dtype=torch.float).div_(n - 1)
    sigmas = fractions.new_zeros(n + 1)
    sigmas[:-1] = (fractions * math.atan(sigma_min) + (1 - fractions) * math.atan(sigma_max)).tan_()
    return sigmas
497
+
498
def get_mask_aabb(masks):
    """Compute an axis-aligned bounding box for each mask in a (B, H, W) tensor.

    Returns:
        (bounding_boxes, is_empty): a (B, 4) int tensor of
        [x_min, y_min, x_max, y_max] per mask, and a (B,) bool tensor flagging
        all-zero masks.

    Fix: the empty-input branch previously returned a bare tensor while every
    caller unpacks two values; it now returns a consistent (boxes, is_empty) pair.
    """
    if masks.numel() == 0:
        return (torch.zeros((0, 4), device=masks.device, dtype=torch.int),
                torch.zeros((0,), device=masks.device, dtype=torch.bool))

    b = masks.shape[0]

    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
    is_empty = torch.zeros((b,), device=masks.device, dtype=torch.bool)
    for i in range(b):
        mask = masks[i]
        if mask.numel() == 0:
            continue
        if torch.max(mask != 0) == False:
            # All-zero mask: leave its box at zeros and flag it.
            is_empty[i] = True
            continue
        y, x = torch.where(mask)
        bounding_boxes[i, 0] = torch.min(x)
        bounding_boxes[i, 1] = torch.min(y)
        bounding_boxes[i, 2] = torch.max(x)
        bounding_boxes[i, 3] = torch.max(y)

    return bounding_boxes, is_empty
520
+
521
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
    for i in range(len(conditions)):
        c = conditions[i]
        if 'area' in c:
            area = c['area']
            if area[0] == "percentage":
                # Convert percentage-based sizes/offsets to absolute latent units.
                modified = c.copy()
                a = area[1:]
                a_len = len(a) // 2
                area = ()
                for d in range(len(dims)):
                    area += (max(1, round(a[d] * dims[d])),)
                for d in range(len(dims)):
                    area += (round(a[d + a_len] * dims[d]),)

                modified['area'] = area
                c = modified
                conditions[i] = c

        if 'mask' in c:
            mask = c['mask'].to(device=device)
            modified = c.copy()
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)
            if mask.shape[1:] != dims:
                # Resize once up-front so sampling never has to rescale the mask.
                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)

            if modified.get("set_area_to_bounds", False):  # TODO: handle dim != 2
                bounds = torch.max(torch.abs(mask), dim=0).values.unsqueeze(0)
                boxes, is_empty = get_mask_aabb(bounds)
                if is_empty[0]:
                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
                    modified['area'] = (8, 8, 0, 0)
                else:
                    box = boxes[0]
                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
                    H = max(8, H)
                    W = max(8, W)
                    modified['area'] = (int(H), int(W), int(Y), int(X))

            modified['mask'] = mask
            conditions[i] = modified
567
+
568
def resolve_areas_and_cond_masks(conditions, h, w, device):
    """Deprecated 2D wrapper around resolve_areas_and_cond_masks_multidim."""
    logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
    return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
571
+
572
def create_cond_with_same_area_if_none(conds, c):
    """Ensure the opposite cond list covers c's area.

    If no cond in `conds` matches c's exact area, clone the smallest containing
    cond's model_conds into a new cond with c's coverage and append it.
    """
    if 'area' not in c:
        return

    def area_inside(a, area_cmp):
        # Normalize both specs to a common number of dims before comparing.
        a = add_area_dims(a, len(area_cmp) // 2)
        area_cmp = add_area_dims(area_cmp, len(a) // 2)

        a_l = len(a) // 2
        area_cmp_l = len(area_cmp) // 2
        # a's offset must not start before area_cmp's...
        for i in range(min(a_l, area_cmp_l)):
            if a[a_l + i] < area_cmp[area_cmp_l + i]:
                return False
        # ...and a's extent must not end past area_cmp's.
        for i in range(min(a_l, area_cmp_l)):
            if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
                return False
        return True

    c_area = c['area']
    smallest = None
    for x in conds:
        if 'area' in x:
            a = x['area']
            if area_inside(c_area, a):
                if smallest is None:
                    smallest = x
                elif 'area' not in smallest:
                    smallest = x
                else:
                    # Prefer the candidate with the smaller covered volume.
                    if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
                        smallest = x
        else:
            if smallest is None:
                smallest = x
    if smallest is None:
        return
    if 'area' in smallest:
        if smallest['area'] == c_area:
            return

    out = c.copy()
    out['model_conds'] = smallest['model_conds'].copy()  # TODO: which fields should be copied?
    conds += [out]
615
+
616
def calculate_start_end_timesteps(model, conds):
    """Translate percent-based scheduling on conds into sigma bounds.

    Writes 'timestep_start'/'timestep_end' (sigma values) onto copies of any
    cond carrying 'start_percent'/'end_percent' or clip hook percent keys.
    """
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        timestep_start = None
        timestep_end = None
        # handle clip hook schedule, if needed
        if 'clip_start_percent' in x:
            timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
            timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
        else:
            if 'start_percent' in x:
                timestep_start = s.percent_to_sigma(x['start_percent'])
            if 'end_percent' in x:
                timestep_end = s.percent_to_sigma(x['end_percent'])

        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
            if timestep_start is not None:
                n['timestep_start'] = timestep_start
            if timestep_end is not None:
                n['timestep_end'] = timestep_end
            conds[t] = n
640
+
641
def pre_run_control(model, conds):
    """Give every ControlNet attached to the conds a chance to initialize for sampling."""
    s = model.model_sampling
    percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
    for x in conds:
        if 'control' in x:
            x['control'].pre_run(model, percent_to_timestep_function)
649
+
650
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    """Mirror area-less `name` attachments (e.g. 'control', 'gligen') from cond onto uncond.

    If uncond already carries any such attachment, nothing is done. Otherwise
    each cond attachment is assigned round-robin onto uncond entries that lack
    one (either replacing the entry in place or appending a copy).

    Fix: previously raised ZeroDivisionError when cond had attachments but
    uncond had no area-less entries to attach them to; that case is now a no-op.
    """
    cond_cnets = []
    cond_other = []
    uncond_cnets = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))

    if len(uncond_cnets) > 0:
        return

    if len(uncond_other) == 0:
        # Nothing to attach to; avoid modulo-by-zero below.
        return

    for x in range(len(cond_cnets)):
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
        if name in o and o[name] is not None:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond += [n]
        else:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond[temp[1]] = n
684
+
685
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
    """Run the model's extra-conds hook over every cond.

    Each cond's fields (plus device/noise/width/height/prompt_type and any
    kwargs not already present) are passed to model_function; its returned
    entries are merged into a copy of the cond's 'model_conds' dict.
    Returns the updated list (also mutated in place).
    """
    for t in range(len(conds)):
        x = conds[t]
        params = x.copy()
        params["device"] = device
        params["noise"] = noise
        default_width = None
        if len(noise.shape) >= 4:  # TODO: 8 multiple should be set by the model
            default_width = noise.shape[3] * 8
        params["width"] = params.get("width", default_width)
        params["height"] = params.get("height", noise.shape[2] * 8)
        params["prompt_type"] = params.get("prompt_type", prompt_type)
        for k, v in kwargs.items():
            params.setdefault(k, v)

        out = model_function(**params)
        x = x.copy()
        model_conds = x['model_conds'].copy()
        model_conds.update(out)
        x['model_conds'] = model_conds
        conds[t] = x
    return conds
709
+
710
class Sampler:
    """Base class for sampler objects; subclasses implement sample()."""
    def sample(self):
        pass

    def max_denoise(self, model_wrap, sigmas):
        """True when the schedule starts at (or above) the model's sigma_max,
        i.e. the starting latent should be treated as pure noise."""
        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
        first_sigma = float(sigmas[0])
        return math.isclose(max_sigma, first_sigma, rel_tol=1e-05) or first_sigma > max_sigma
718
+
719
# All sampler names exposed to the UI; each maps to sample_<name> in
# k_diffusion.sampling except dpm_fast/dpm_adaptive (see ksampler()).
KSAMPLER_NAMES = [
    "euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
    "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
    "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
    "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
    "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
]
724
+
725
class KSAMPLER(Sampler):
    """Sampler that drives a k-diffusion style sampler_function."""
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
        self.sampler_function = sampler_function
        self.extra_options = extra_options
        self.inpaint_options = inpaint_options

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        if self.inpaint_options.get("random", False):  # TODO: Should this be the default?
            # Deterministic per-seed CPU noise for the inpainted region.
            generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
            model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
        else:
            model_k.noise = noise

        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))

        k_callback = None
        total_steps = len(sigmas) - 1
        if callback is not None:
            # Adapt k-diffusion's dict callback to ComfyUI's positional one.
            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
        return samples
751
+
752
+
753
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
    """Build a KSAMPLER for the named k-diffusion sampler.

    dpm_fast and dpm_adaptive need the sigma range / step count extracted from
    the sigma schedule; every other name maps directly to sample_<name>.
    """
    if sampler_name == "dpm_fast":
        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
            total_steps = len(sigmas) - 1
            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
        sampler_function = dpm_fast_function
    elif sampler_name == "dpm_adaptive":
        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
        sampler_function = dpm_adaptive_function
    else:
        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))

    return KSAMPLER(sampler_function, extra_options, inpaint_options)
779
+
780
+
781
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
    """Prepare every cond list for sampling.

    Resolves areas/masks, converts percent schedules to sigma bounds, encodes
    model-specific conds, balances areas across cond types, initializes hook
    timesteps and ControlNets, and mirrors control/gligen onto negatives.
    """
    for k in conds:
        conds[k] = conds[k][:]
        resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)

    for k in conds:
        calculate_start_end_timesteps(model, conds[k])

    if hasattr(model, 'extra_conds'):
        for k in conds:
            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

    # make sure each cond area has an opposite one with the same area
    for k in conds:
        for c in conds[k]:
            for kk in conds:
                if k != kk:
                    create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        for c in conds[k]:
            if 'hooks' in c:
                for hook in c['hooks'].hooks:
                    hook.initialize_timesteps(model)

    for k in conds:
        pre_run_control(model, conds[k])

    if "positive" in conds:
        positive = conds["positive"]
        for k in conds:
            if k != "positive":
                apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
                apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])

    return conds
817
+
818
+
819
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
    """Merge ControlNet extra_hooks into each cond's hook group.

    Conds sharing the same (control, hooks) pair end up referencing one
    combined HookGroup instance so they can batch together during sampling.
    """
    # determine which ControlNets have extra_hooks that should be combined with normal hooks
    hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
    for k in conds:
        for kk in conds[k]:
            if 'control' in kk:
                control: 'ControlBase' = kk['control']
                extra_hooks = control.get_extra_hooks()
                if len(extra_hooks) > 0:
                    hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
                    hook_replacement.setdefault((control, hooks), []).append(kk)
    # if nothing to replace, do nothing
    if len(hook_replacement) == 0:
        return

    # for optimal sampling performance, common ControlNets + hook combos should have identical hooks
    # on the cond dicts
    for (control, hooks), conds_to_modify in hook_replacement.items():
        combined = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
        # if combined hooks are not None, set as new hooks for all relevant conds
        if combined is not None:
            for cond in conds_to_modify:
                cond['hooks'] = combined
845
+
846
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
    '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
    HookGroups that have the same reference.'''
    registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
    # if None were registered, make sure all hooks are cleaned from conds
    if registered is None:
        for k in conds:
            for kk in conds[k]:
                kk.pop('hooks', None)
        return
    # find conds that contain hooks to be replaced - group by common HookGroup refs
    hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
    for k in conds:
        for kk in conds[k]:
            hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
            if hooks is not None and not hooks.is_subset_of(registered):
                hook_replacement.setdefault(hooks, []).append(kk)
    # for each hook to replace, create a new proper HookGroup and assign to all common conds
    for hooks, conds_to_modify in hook_replacement.items():
        new_hooks = hooks.new_with_common_hooks(registered)
        if len(new_hooks) == 0:
            new_hooks = None
        for kk in conds_to_modify:
            kk['hooks'] = new_hooks
872
+
873
+
874
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
    '''Count the distinct HookGroup references (including None) present across all cond dicts.'''
    distinct_hooks = {cond.get('hooks', None) for cond_list in conds.values() for cond in cond_list}
    return len(distinct_hooks)
880
+
881
+
882
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
    '''
    If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
    '''
    if model_options is None:
        return
    to_load_options = model_options.get("to_load_options", None)
    if to_load_options is None:
        return

    # collect the requested casts in a fixed order: device first, then dtype
    casts = [c for c in (device, dtype) if c is not None]
    if not casts:
        return

    def _apply_casts(obj):
        # only objects exposing .to are cast; everything else passes through untouched
        if hasattr(obj, "to"):
            for cast in casts:
                obj = obj.to(cast)
        return obj

    # "patches" maps name -> flat list of patch objects
    for patch_list in to_load_options.get("patches", {}).values():
        for i in range(len(patch_list)):
            patch_list[i] = _apply_casts(patch_list[i])
    # "patches_replace" maps name -> dict of patch objects
    for patch_dict in to_load_options.get("patches_replace", {}).values():
        for key in patch_dict:
            patch_dict[key] = _apply_casts(patch_dict[key])
    # wrappers and callbacks share the same nested dict[str, dict[str, list]] layout
    for wc_name in ("wrappers", "callbacks"):
        for wc_dict in to_load_options.get(wc_name, {}).values():
            for wc_list in wc_dict.values():
                for i in range(len(wc_list)):
                    wc_list[i] = _apply_casts(wc_list[i])
929
+
930
+
931
class CFGGuider:
    """Default noise-prediction guider implementing classifier-free guidance (CFG).

    Holds a ModelPatcher plus positive/negative conds and orchestrates the full
    sampling pipeline: cond preprocessing, hook registration, model loading,
    wrapper execution, and latent in/out processing.
    """
    def __init__(self, model_patcher: ModelPatcher):
        self.model_patcher = model_patcher
        # NOTE: shares the patcher's model_options reference; sample() swaps in a clone per run
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0

    def set_conds(self, positive, negative):
        """Store positive and negative conditioning (converted to internal cond format)."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the classifier-free guidance scale."""
        self.cfg = cfg

    def inner_set_conds(self, conds):
        # convert each named cond list into the internal representation once, up front
        for k in conds:
            self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        # calling the guider directly is equivalent to predict_noise (used by samplers)
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run one CFG-combined model evaluation at the given timestep."""
        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
        """Run the sampler over the prepared conds/latents; returns decoded-space latents."""
        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        # clone model_options so per-run additions (sample_sigmas) don't leak out
        extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
        extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
        extra_args = {"model_options": extra_model_options, "seed": seed}

        # wrap sampler.sample with any registered SAMPLER_SAMPLE wrappers
        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            sampler.sample,
            sampler,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
        )
        samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Load models onto the device, run inner_sample, then unload/cleanup."""
        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
        device = self.model_patcher.load_device

        if denoise_mask is not None:
            denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)
        # move any castable patches/wrappers/callbacks to the load device and model dtype
        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())

        try:
            self.model_patcher.pre_run()
            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # cleanup must run even if sampling raises
            self.model_patcher.cleanup()

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.loaded_models
        return output

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Top-level entry point: prepare conds/hooks, run outer_sample via wrappers, restore state."""
        if sigmas.shape[-1] == 0:
            return latent_image

        # work on shallow copies so hook preprocessing never mutates the originals
        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
        preprocess_conds_hooks(self.conds)

        try:
            orig_model_options = self.model_options
            self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
            # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
            orig_hook_mode = self.model_patcher.hook_mode
            if get_total_hook_groups_in_conds(self.conds) <= 1:
                self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
            comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
            filter_registered_hooks_on_conds(self.conds, self.model_options)
            executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                self.outer_sample,
                self,
                comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
            )
            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # restore everything even on failure: offload castables, original options, hook mode/patches
            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
            self.model_options = orig_model_options
            self.model_patcher.hook_mode = orig_hook_mode
            self.model_patcher.restore_hook_patches()

        del self.conds
        return output
1027
+
1028
+
1029
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Convenience wrapper: sample with standard classifier-free guidance via CFGGuider.

    `device` is accepted but not used here (the guider takes the device from the
    model patcher) — presumably kept for backward API compatibility; confirm with callers.
    """
    cfg_guider = CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
1034
+
1035
+
1036
# All selectable sampler names: the k-diffusion samplers plus special-cased ones.
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

class SchedulerHandler(NamedTuple):
    # Callable that produces the sigma schedule tensor for a scheduler.
    handler: Callable[..., torch.Tensor]
    # Boolean indicates whether to call the handler like:
    #  scheduler_function(model_sampling, steps) or
    #  scheduler_function(n, sigma_min: float, sigma_max: float)
    use_ms: bool = True

# Registry mapping scheduler name -> handler; extend here to add new schedulers.
SCHEDULER_HANDLERS = {
    "simple": SchedulerHandler(simple_scheduler),
    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
    "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
    "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
    "ddim_uniform": SchedulerHandler(ddim_scheduler),
    "beta": SchedulerHandler(beta_scheduler),
    "normal": SchedulerHandler(normal_scheduler),
    "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
    "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}
# Ordered list of scheduler names (dict preserves insertion order).
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
1057
+
1058
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
    """Build the sigma schedule for the named scheduler.

    Raises ValueError (after logging) when the scheduler name is unknown.
    """
    entry = SCHEDULER_HANDLERS.get(scheduler_name)
    if entry is None:
        err = f"error invalid scheduler {scheduler_name}"
        logging.error(err)
        raise ValueError(err)
    if not entry.use_ms:
        # sigma-range style handler: takes the step count plus min/max sigma floats
        return entry.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    # model-sampling style handler: takes the model_sampling object directly
    return entry.handler(model_sampling, steps)
1067
+
1068
def sampler_object(name):
    """Instantiate the sampler object for a sampler name.

    uni_pc variants get dedicated KSAMPLER wrappers; "ddim" is served by the
    euler sampler with randomized inpaint noise; everything else goes through
    the generic ksampler factory.
    """
    if name == "uni_pc":
        return KSAMPLER(uni_pc.sample_unipc)
    if name == "uni_pc_bh2":
        return KSAMPLER(uni_pc.sample_unipc_bh2)
    if name == "ddim":
        return ksampler("euler", inpaint_options={"random": True})
    return ksampler(name)
1078
+
1079
class KSampler:
    """High-level sampler wrapper: resolves sampler/scheduler names, builds the
    sigma schedule (including denoise-strength truncation), and dispatches to sample()."""
    SCHEDULERS = SCHEDULER_NAMES
    SAMPLERS = SAMPLER_NAMES
    # These samplers need an extra step generated and the penultimate sigma dropped.
    DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))

    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
        self.model = model
        self.device = device
        # unknown names silently fall back to the first entry of each list
        if scheduler not in self.SCHEDULERS:
            scheduler = self.SCHEDULERS[0]
        if sampler not in self.SAMPLERS:
            sampler = self.SAMPLERS[0]
        self.scheduler = scheduler
        self.sampler = sampler
        self.set_steps(steps, denoise)
        self.denoise = denoise
        self.model_options = model_options

    def calculate_sigmas(self, steps):
        """Compute the sigma schedule for `steps`, handling penultimate-sigma discard samplers."""
        sigmas = None

        discard_penultimate_sigma = False
        if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
            steps += 1
            discard_penultimate_sigma = True

        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

        if discard_penultimate_sigma:
            # drop the second-to-last sigma, keep the final one
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
        return sigmas

    def set_steps(self, steps, denoise=None):
        """Set step count and (re)build self.sigmas, honoring partial denoise strength."""
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            # full denoise: use the schedule as-is
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
            if denoise <= 0.0:
                self.sigmas = torch.FloatTensor([])
            else:
                # partial denoise: build a longer schedule, keep only its tail
                new_steps = int(steps/denoise)
                sigmas = self.calculate_sigmas(new_steps).to(self.device)
                self.sigmas = sigmas[-(steps + 1):]

    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
        """Run sampling over (a slice of) the sigma schedule and return latents."""
        if sigmas is None:
            sigmas = self.sigmas

        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[:last_step + 1]
            if force_full_denoise:
                # force the final sigma to zero so the output is fully denoised
                sigmas[-1] = 0

        if start_step is not None:
            if start_step < (len(sigmas) - 1):
                sigmas = sigmas[start_step:]
            else:
                # nothing left to sample; return the input latent (or zeros if none given)
                if latent_image is not None:
                    return latent_image
                else:
                    return torch.zeros_like(noise)

        sampler = sampler_object(self.sampler)

        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
ComfyUI/comfy/sd1_clip.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from transformers import CLIPTokenizer
4
+ import comfy.ops
5
+ import torch
6
+ import traceback
7
+ import zipfile
8
+ from . import model_management
9
+ import comfy.clip_model
10
+ import json
11
+ import logging
12
+ import numbers
13
+ import re
14
+
15
def gen_empty_tokens(special_tokens, length):
    """Build an 'empty prompt' token sequence of the given length.

    Starts with the start/end special tokens (when present), then fills the
    remainder with the pad token.
    """
    tokens = [special_tokens[key] for key in ("start", "end") if special_tokens.get(key, None) is not None]
    pad_token = special_tokens.get("pad")
    tokens.extend([pad_token] * (length - len(tokens)))
    return tokens
26
+
27
class ClipTokenWeightEncoder:
    """Mixin providing weighted token encoding on top of a model exposing
    .encode(), .special_tokens, and optionally .gen_empty_tokens."""
    def encode_token_weights(self, token_weight_pairs):
        """Encode batches of (token, weight) pairs.

        When any weight differs from 1.0, an extra "empty" sequence is encoded
        alongside and each weighted embedding is interpolated between the empty
        embedding and its own: z = (z - z_empty) * w + z_empty.
        Returns (embeddings, first_pooled[, extra]) on the intermediate device.
        """
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            # any single non-1.0 weight triggers the weighted path for all sections
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            # append an empty-prompt sequence as the interpolation baseline
            if hasattr(self, "gen_empty_tokens"):
                to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
            else:
                to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]

        # only the first section's pooled output is kept
        if pooled is not None:
            first_pooled = pooled[0:1].to(model_management.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k:k+1]
            if has_weights:
                # last encoded row is the empty baseline appended above
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if (len(output) == 0):
            # no input sections: return the (empty-prompt) encoding itself
            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
        else:
            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)

        if len(o) > 2:
            # pass through any extra outputs, flattening the attention mask over real sections
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
                extra[k] = v

            r = r + (extra,)
        return r
80
+
81
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Wraps a CLIP-style text transformer for prompt encoding.

    Supports selecting which hidden layer to output, attention masking,
    textual-inversion style embeddings mixed into the token stream, and
    weighted encoding via ClipTokenWeightEncoder.
    """
    LAYERS = [
        "last",
        "pooled",
        "hidden",
        "all"
    ]
    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                 return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS

        # default to the bundled SD1 CLIP-L config when none is provided
        if textmodel_json_config is None:
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
            if "model_name" not in model_options:
                model_options = {**model_options, "model_name": "clip_l"}

        if isinstance(textmodel_json_config, dict):
            config = textmodel_json_config
        else:
            with open(textmodel_json_config) as f:
                config = json.load(f)

        # allow per-text-encoder config overrides via "<model_name>_model_config"
        te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
        for k, v in te_model_options.items():
            config[k] = v

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None

        if operations is None:
            # pick ops implementation: scaled-fp8 aware ops when requested, manual cast otherwise
            scaled_fp8 = model_options.get("scaled_fp8", None)
            if scaled_fp8 is not None:
                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
            else:
                operations = comfy.ops.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        if scaled_fp8 is not None:
            # marker parameter so state dicts record the scaled-fp8 dtype
            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        # snapshot defaults so reset_clip_options can restore them
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options):
        """Select the output layer and pooled-projection behavior at runtime."""
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if self.layer == "all":
            pass
        elif layer_idx is None or abs(layer_idx) > self.num_layers:
            # out-of-range or unset index falls back to the final layer
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        """Restore the layer/pooled options captured at construction time."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def process_tokens(self, tokens, device):
        """Turn token lists (ints and/or embedding tensors) into input embeddings.

        Returns (embeds, attention_masks, num_tokens). Integer tokens are looked
        up in the input embedding table; tensor/dict entries are spliced in as
        precomputed embeddings. The attention mask is 1 up to the first
        end/pad token and 0 afterwards.
        """
        end_token = self.special_tokens.get("end", None)
        if end_token is None:
            # no end token: detect end-of-sequence by the pad token instead
            cmp_token = self.special_tokens.get("pad", -1)
        else:
            cmp_token = end_token

        embeds_out = []
        attention_masks = []
        num_tokens = []

        for x in tokens:
            attention_mask = []
            tokens_temp = []
            other_embeds = []
            eos = False
            index = 0
            for y in x:
                if isinstance(y, numbers.Integral):
                    if eos:
                        attention_mask.append(0)
                    else:
                        attention_mask.append(1)
                    token = int(y)
                    tokens_temp += [token]
                    if not eos and token == cmp_token:
                        if end_token is None:
                            # pad-as-eos: the pad token itself is masked out
                            attention_mask[-1] = 0
                        eos = True
                else:
                    # non-integer entries (embedding tensors/dicts) are spliced in later
                    other_embeds.append((index, y))
                index += 1

            tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
            index = 0
            pad_extra = 0
            for o in other_embeds:
                emb = o[1]
                if torch.is_tensor(emb):
                    emb = {"type": "embedding", "data": emb}

                emb_type = emb.get("type", None)
                if emb_type == "embedding":
                    emb = emb.get("data", None)
                else:
                    # non-embedding payloads are handled by the transformer, if it supports them
                    if hasattr(self.transformer, "preprocess_embed"):
                        emb = self.transformer.preprocess_embed(emb, device=device)
                    else:
                        emb = None

                if emb is None:
                    # skipped entry shifts subsequent splice positions back by one
                    index += -1
                    continue

                ind = index + o[0]
                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                emb_shape = emb.shape[1]
                if emb.shape[-1] == tokens_embed.shape[-1]:
                    # splice the embedding rows into the sequence at its original position
                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
                    index += emb_shape - 1
                else:
                    # hidden-size mismatch: drop the embedding, pad later to keep lengths aligned
                    index += -1
                    pad_extra += emb_shape
                    logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))

            if pad_extra > 0:
                padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
                tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
                attention_mask = attention_mask + [0] * pad_extra

            embeds_out.append(tokens_embed)
            attention_masks.append(attention_mask)
            num_tokens.append(sum(attention_mask))

        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens

    def forward(self, tokens):
        """Encode token batches; returns (hidden states, pooled output[, extra])."""
        device = self.transformer.get_input_embeddings().weight.device
        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        if self.layer == "all":
            intermediate_output = "all"
        else:
            intermediate_output = self.layer_idx

        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            # zero embeddings at masked (padding) positions
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            # prefer the unprojected pooled output when requested and available
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens):
        """Alias for forward(); used by ClipTokenWeightEncoder."""
        return self(tokens)

    def load_sd(self, sd):
        """Load a state dict into the transformer (non-strict); returns (missing, unexpected)."""
        return self.transformer.load_state_dict(sd, strict=False)
292
+
293
def parse_parentheses(string):
    """Split a string into top-level segments, keeping each top-level
    parenthesized group (parens included) as a single segment.

    Nested parens stay inside their enclosing group; unbalanced closing
    parens are kept as literal characters.
    """
    segments = []
    buffer = ""
    depth = 0
    for ch in string:
        if ch == "(":
            if depth == 0:
                # flush any accumulated plain text before opening a new group
                if buffer:
                    segments.append(buffer)
                buffer = "("
            else:
                buffer += ch
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                # top-level group closed: emit it as one segment
                segments.append(buffer + ")")
                buffer = ""
            else:
                buffer += ch
        else:
            buffer += ch
    if buffer:
        segments.append(buffer)
    return segments
320
+
321
def token_weights(string, current_weight):
    """Recursively expand attention-weight parentheses into (text, weight) pairs.

    Each paren nesting level multiplies the weight by 1.1 unless the group
    carries an explicit "(text:1.5)" weight, which replaces the multiplier.
    """
    result = []
    for segment in parse_parentheses(string):
        if len(segment) >= 2 and segment[0] == '(' and segment[-1] == ')':
            inner = segment[1:-1]
            weight = current_weight * 1.1
            colon = inner.rfind(":")
            if colon > 0:
                # explicit weight after the last colon overrides the 1.1 default
                try:
                    weight = float(inner[colon + 1:])
                    inner = inner[:colon]
                except:
                    pass
            result += token_weights(inner, weight)
        else:
            result += [(segment, current_weight)]
    return result
340
+
341
def escape_important(text):
    """Replace escaped parens ("\\)" and "\\(") with control-byte sentinels so the
    weight parser treats them as literal text; reversed by unescape_important."""
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")
345
+
346
def unescape_important(text):
    """Restore the literal parens that escape_important replaced with sentinel bytes."""
    return text.replace("\0\1", ")").replace("\0\2", "(")
350
+
351
def safe_load_embed_zip(embed_path):
    """Fallback loader for legacy torch zip-format embedding files.

    Reads raw tensor payloads out of the archive's data/ entries without
    unpickling (safe for untrusted files). Assumes float32 data with an
    embedding width of 768 (SD1.x) or 1024 (SD2.x); returns the first
    plausible tensor reshaped to (num_embeds, width), or None if nothing
    matches (implicit).
    """
    with zipfile.ZipFile(embed_path) as myzip:
        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
        # iterate entries last-to-first; presumably the wanted payload sits late
        # in the archive — TODO confirm against the legacy torch zip layout
        names.reverse()
        for n in names:
            with myzip.open(n) as myfile:
                data = myfile.read()
                # payload is raw float32, 4 bytes per element
                number = len(data) // 4
                length_embed = 1024  #sd2.x
                if number < 768:
                    continue
                if number % 768 == 0:
                    length_embed = 768  #sd1.x
                num_embeds = number // length_embed
                embed = torch.frombuffer(data, dtype=torch.float)
                # clone() detaches the tensor from the (read-only) source buffer
                out = embed.reshape((num_embeds, length_embed)).clone()
                del embed
                return out
369
+
370
def expand_directory_list(directories):
    """Return the given directories plus every subdirectory beneath them
    (symlinks followed), deduplicated; order is not guaranteed."""
    expanded = set(directories)
    for directory in directories:
        for root, _subdirs, _files in os.walk(directory, followlinks=True):
            expanded.add(root)
    return list(expanded)
377
+
378
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
    """Collect every tensor in a state dict whose key matches prefix...suffix
    and concatenate them along dim 0; returns None when nothing matches."""
    matches = [tensor for key, tensor in embed.items() if key.startswith(prefix) and key.endswith(suffix)]
    if not matches:
        return None
    return torch.cat(matches, dim=0)
387
+
388
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    """Locate and load a textual-inversion embedding by name.

    Searches the given directory (or directories) and all subdirectories,
    trying common file extensions, then extracts the embedding tensor from
    the several formats in the wild (A1111 'string_to_param', plain dicts,
    lists of dicts, lora-bundled embeddings). Returns the tensor, or None
    when the file is not found or cannot be loaded.
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # path-traversal guard: the resolved path must stay inside embed_dir
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except:
            continue
        if not os.path.isfile(embed_path):
            # try the common embedding file extensions in preference order
            extensions = ['.safetensors', '.pt', '.bin']
            for x in extensions:
                t = embed_path + x
                if os.path.isfile(t):
                    valid_file = t
                    break
        else:
            valid_file = embed_path
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file

    embed_out = None

    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch
            embed = safetensors.torch.load_file(embed_path, device="cpu")
        else:
            try:
                # weights_only=True avoids executing pickled code from untrusted files
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            except:
                # legacy zip-format files that weights_only can't read
                embed_out = safe_load_embed_zip(embed_path)
    except Exception:
        # best-effort loader: log and skip rather than abort prompt encoding
        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
        return None

    if embed_out is None:
        if 'string_to_param' in embed:
            # A1111/textual-inversion format: single tensor under string_to_param
            values = embed['string_to_param'].values()
            embed_out = next(iter(values))
        elif isinstance(embed, list):
            # list of dicts: keep only tensors matching this encoder's embedding size
            out_list = []
            for x in range(len(embed)):
                for k in embed[x]:
                    t = embed[x][k]
                    if t.shape[-1] != embedding_size:
                        continue
                    out_list.append(t.reshape(-1, t.shape[-1]))
            embed_out = torch.cat(out_list, dim=0)
        elif embed_key is not None and embed_key in embed:
            embed_out = embed[embed_key]
        else:
            # lora-bundled embeddings, then fall back to the first value in the dict
            embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
            if embed_out is None:
                embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
            if embed_out is None:
                values = embed.values()
                embed_out = next(iter(values))
    return embed_out
458
+
459
+ class SDTokenizer:
460
+ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
461
+ if tokenizer_path is None:
462
+ tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
463
+ self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
464
+ self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
465
+ self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
466
+ self.end_token = None
467
+ self.min_padding = min_padding
468
+
469
+ empty = self.tokenizer('')["input_ids"]
470
+ self.tokenizer_adds_end_token = has_end_token
471
+ if has_start_token:
472
+ self.tokens_start = 1
473
+ self.start_token = empty[0]
474
+ if end_token is not None:
475
+ self.end_token = end_token
476
+ else:
477
+ if has_end_token:
478
+ self.end_token = empty[1]
479
+ else:
480
+ self.tokens_start = 0
481
+ self.start_token = None
482
+ if end_token is not None:
483
+ self.end_token = end_token
484
+ else:
485
+ if has_end_token:
486
+ self.end_token = empty[0]
487
+
488
+ if pad_token is not None:
489
+ self.pad_token = pad_token
490
+ elif pad_with_end:
491
+ self.pad_token = self.end_token
492
+ else:
493
+ self.pad_token = 0
494
+
495
+ self.pad_with_end = pad_with_end
496
+ self.pad_to_max_length = pad_to_max_length
497
+
498
+ vocab = self.tokenizer.get_vocab()
499
+ self.inv_vocab = {v: k for k, v in vocab.items()}
500
+ self.embedding_directory = embedding_directory
501
+ self.max_word_length = 8
502
+ self.embedding_identifier = "embedding:"
503
+ self.embedding_size = embedding_size
504
+ self.embedding_key = embedding_key
505
+
506
+ def _try_get_embedding(self, embedding_name:str):
507
+ '''
508
+ Takes a potential embedding name and tries to retrieve it.
509
+ Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
510
+ '''
511
+ split_embed = embedding_name.split()
512
+ embedding_name = split_embed[0]
513
+ leftover = ' '.join(split_embed[1:])
514
+ embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
515
+ if embed is None:
516
+ stripped = embedding_name.strip(',')
517
+ if len(stripped) < len(embedding_name):
518
+ embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
519
+ return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
520
+ return (embed, leftover)
521
+
522
+
523
    def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        # Per-call padding overrides, keyed by embedding_key (e.g. "clip_l_min_length")
        # so each encoder in a multi-encoder pipeline can be tuned independently.
        min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
        min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment)
            # Split out "embedding:<name>" markers (preceded by a space or newline)
            # so they can be resolved to tensors instead of being tokenized as text.
            split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
            to_tokenize = [split[0]]
            for i in range(1, len(split)):
                to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))

            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        # 1-D embedding -> one pseudo-token; 2-D -> one pseudo-token per row.
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # When the underlying tokenizer appends its own end token, slice it
                # off ([:-1]); otherwise keep everything (huge sentinel end index).
                end = 999999999999
                if self.tokenizer_adds_end_token:
                    end = -1
                #parse word
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])

        #reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        # NOTE: batch is appended by reference and filled in place below.
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            #determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            if self.end_token is not None:
                has_end_token = 1
            else:
                has_end_token = 0

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - has_end_token:
                    remaining_length = self.max_length - len(batch) - has_end_token
                    #break word in two and add end token
                    if is_large:
                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    #add end token and pad
                    else:
                        # Short word that doesn't fit: pad this batch out and retry
                        # the whole word in the next batch (t_group is kept intact).
                        if self.end_token is not None:
                            batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
                    #start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t,w,i+1) for t,w in t_group])
                    t_group = []

        #fill last batch
        if self.end_token is not None:
            batch.append((self.end_token, 1.0, 0))
        if min_padding is not None:
            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
        if self.pad_to_max_length and len(batch) < self.max_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if min_length is not None and len(batch) < min_length:
            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

        return batched_tokens
621
+
622
+
623
+ def untokenize(self, token_weight_pair):
624
+ return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
625
+
626
+ def state_dict(self):
627
+ return {}
628
+
629
class SD1Tokenizer:
    '''Holds a single sub-tokenizer under a dynamically named attribute and
    routes the tokenizer API to it, keying tokenize results by clip_name.'''

    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
        # An explicit "name" is used verbatim as the attribute name; otherwise
        # the attribute becomes "clip_<clip_name>".
        if name is None:
            self.clip_name = clip_name
            self.clip = "clip_{}".format(self.clip_name)
        else:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)

        # tokenizer_data may override the tokenizer class per attribute name
        # (e.g. "clip_l_tokenizer_class").
        tokenizer_class = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
        setattr(self, self.clip, tokenizer_class(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        sub_tokenizer = getattr(self, self.clip)
        return {self.clip_name: sub_tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)}

    def untokenize(self, token_weight_pair):
        return getattr(self, self.clip).untokenize(token_weight_pair)

    def state_dict(self):
        return getattr(self, self.clip).state_dict()
651
+
652
class SD1CheckpointClipModel(SDClipModel):
    # Thin SDClipModel variant used when loading SD1 text encoders from a full
    # checkpoint: it only pins return_projected_pooled=False — presumably
    # selecting the unprojected pooled output; confirm against SDClipModel.
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
655
+
656
class SD1ClipModel(torch.nn.Module):
    '''Container module that owns exactly one CLIP text encoder under a
    dynamically named attribute and forwards the standard encoder API to it.'''

    def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
        super().__init__()

        # Attribute naming mirrors SD1Tokenizer: explicit "name" wins, else
        # the attribute becomes "clip_<clip_name>".
        if name is None:
            self.clip_name = clip_name
            self.clip = "clip_{}".format(self.clip_name)
        else:
            self.clip_name = name
            self.clip = "{}".format(self.clip_name)

        # model_options may override the encoder class per attribute name.
        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
        model_options = {**model_options, "model_name": self.clip}
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

        self.dtypes = set() if dtype is None else {dtype}

    def _encoder(self):
        # The wrapped encoder lives under a dynamic attribute name.
        return getattr(self, self.clip)

    def set_clip_options(self, options):
        self._encoder().set_clip_options(options)

    def reset_clip_options(self):
        self._encoder().reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        return self._encoder().encode_token_weights(token_weight_pairs[self.clip_name])

    def load_sd(self, sd):
        return self._encoder().load_sd(sd)
ComfyUI/comfy/sd1_clip_config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 49407,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.24.0",
24
+ "vocab_size": 49408
25
+ }
ComfyUI/comfy/sd1_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
ComfyUI/comfy/sd1_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 8192,
22
+ "name_or_path": "openai/clip-vit-large-patch14",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
ComfyUI/comfy/sd1_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
ComfyUI/comfy/supported_models.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import model_base
3
+ from . import utils
4
+
5
+ from . import sd1_clip
6
+ from . import sdxl_clip
7
+ import comfy.text_encoders.sd2_clip
8
+ import comfy.text_encoders.sd3_clip
9
+ import comfy.text_encoders.sa_t5
10
+ import comfy.text_encoders.aura_t5
11
+ import comfy.text_encoders.pixart_t5
12
+ import comfy.text_encoders.hydit
13
+ import comfy.text_encoders.flux
14
+ import comfy.text_encoders.genmo
15
+ import comfy.text_encoders.lt
16
+ import comfy.text_encoders.hunyuan_video
17
+ import comfy.text_encoders.cosmos
18
+ import comfy.text_encoders.lumina2
19
+ import comfy.text_encoders.wan
20
+ import comfy.text_encoders.ace
21
+ import comfy.text_encoders.omnigen2
22
+
23
+ from . import supported_models_base
24
+ from . import latent_formats
25
+
26
+ from . import diffusers_convert
27
+
28
class SD15(supported_models_base.BASE):
    """Stable Diffusion 1.x model definition (768-dim CLIP-L context)."""

    # Matched against the detected UNet configuration to identify this family.
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15
    memory_usage_factor = 1.0

    def process_clip_state_dict(self, state_dict):
        """Rewrite checkpoint CLIP keys into the internal "clip_l." layout."""
        k = list(state_dict.keys())
        for x in k:
            # Some checkpoints omit the "text_model." level; insert it.
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        # Round position_ids stored as float32 back to integer index values.
        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Inverse mapping for saving: drop internal-only keys, restore checkpoint prefix."""
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict:
                state_dict.pop(p)

        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
73
+
74
class SD20(supported_models_base.BASE):
    """Stable Diffusion 2.x model definition (1024-dim context, clip_h encoder)."""

    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
        "attn_precision": torch.float32,
    }

    latent_format = latent_formats.SD15
    memory_usage_factor = 1.0

    def model_type(self, state_dict, prefix=""):
        """Heuristically choose EPS vs V_PREDICTION from a norm-bias statistic."""
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict.get(k, None)
            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

    def process_clip_state_dict(self, state_dict):
        """Map both sgm-format and legacy text-encoder keys to the "clip_h." layout."""
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Convert internal clip_h keys back to the original checkpoint layout."""
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
117
+
118
class SD21UnclipL(SD20):
    # SD 2.1 unCLIP variant (adm_in_channels=1536); conditions on CLIP vision
    # embeddings with noise augmentation.
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}
129
+
130
+
131
class SD21UnclipH(SD20):
    # SD 2.1 unCLIP variant with larger conditioning (adm_in_channels=2048,
    # 1024-dim noise-augmentation timestep embedding).
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
142
+
143
class SDXLRefiner(supported_models_base.BASE):
    """SDXL refiner model (single clip_g text encoder, 1280-dim context)."""

    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL
    memory_usage_factor = 1.0

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        """Map the sgm conditioner keys to the internal "clip_g." layout."""
        keys_to_replace = {}
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Inverse mapping for saving the clip_g encoder under the sgm prefix."""
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
180
+
181
class SDXL(supported_models_base.BASE):
    """SDXL base model (dual clip_l + clip_g text encoders, 2048-dim context)."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    memory_usage_factor = 0.8

    def model_type(self, state_dict, prefix=""):
        """Pick the sampling model type from marker keys embedded in the checkpoint."""
        if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
            self.latent_format = latent_formats.SDXL_Playground_2_5()
            self.sampling_settings["sigma_data"] = 0.5
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
        elif "edm_vpred.sigma_max" in state_dict:
            # Checkpoint carries its own sigma range for EDM v-prediction.
            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
            if "edm_vpred.sigma_min" in state_dict:
                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            if "ztsnr" in state_dict: #Some zsnr anime checkpoints
                self.sampling_settings["zsnr"] = True
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Map the two sgm conditioner embedders to internal clip_l/clip_g layouts."""
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Inverse mapping for saving both encoders under the sgm conditioner prefixes."""
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        # Recreate position_ids (not stored internally) for checkpoint compatibility.
        state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict_g:
                state_dict_g.pop(p)

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
252
+
253
class SSD1B(SDXL):
    # SDXL-family variant distinguished only by a shallower transformer depth.
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
262
+
263
class Segmind_Vega(SDXL):
    # SDXL-family variant with further reduced transformer depth.
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 1, 1, 2, 2],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
272
+
273
class KOALA_700M(SDXL):
    # Compact SDXL-family variant (3-level transformer depth [0, 2, 5]).
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 5],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
282
+
283
class KOALA_1B(SDXL):
    # Compact SDXL-family variant (3-level transformer depth [0, 2, 6]).
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 6],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
292
+
293
class SVD_img2vid(supported_models_base.BASE):
    """Stable Video Diffusion image-to-video model (temporal attention, no text encoder)."""

    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
        "attn_precision": torch.float32,
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        # Conditioned on CLIP vision embeddings only; there is no text encoder.
        return None
323
+
324
class SV3D_u(SVD_img2vid):
    # SV3D orbit variant; differs from SVD by adm_in_channels=256 and the VAE prefix.
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 256,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    vae_key_prefix = ["conditioner.embedders.1.encoder."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_u(self, device=device)
        return out
341
+
342
class SV3D_p(SV3D_u):
    # SV3D pose-conditioned variant (adm_in_channels=1280).
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 1280,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }


    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_p(self, device=device)
        return out
358
+
359
class Stable_Zero123(supported_models_base.BASE):
    """Zero123 novel-view model; identified by its extra cc_projection layer."""

    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    # These keys must exist for detection to pick this class over plain SD15.
    required_keys = {
        "cc_projection.weight": None,
        "cc_projection.bias": None,
    }

    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        # The cc_projection weights are passed through from the checkpoint.
        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
        return out

    def clip_target(self, state_dict={}):
        # Image-conditioned only; no text encoder.
        return None
389
+
390
class SD_X4Upscaler(SD20):
    """SD 4x latent upscaler (7 input channels, class-conditioned noise level)."""

    unet_config = {
        "context_dim": 1024,
        "model_channels": 256,
        'in_channels': 7,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "disable_self_attentions": [True, True, True, False],
        "num_classes": 1000,
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD_X4

    sampling_settings = {
        "linear_start": 0.0001,
        "linear_end": 0.02,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD_X4Upscaler(self, device=device)
        return out
417
+
418
class Stable_Cascade_C(supported_models_base.BASE):
    """Stable Cascade stage C (prior); bf16/fp32 only per supported dtypes."""

    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 2.0,
    }

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoder."]
    clip_vision_prefix = "clip_l_vision."

    def process_unet_state_dict(self, state_dict):
        """Split fused attention in_proj weights/biases into to_q/to_k/to_v thirds."""
        key_list = list(state_dict.keys())
        for y in ["weight", "bias"]:
            suffix = "in_proj_{}".format(y)
            keys = filter(lambda a: a.endswith(suffix), key_list)
            for k_from in keys:
                weights = state_dict.pop(k_from)
                prefix = k_from[:-(len(suffix) + 1)]
                # Fused tensor stacks q, k, v along dim 0 in equal thirds.
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["to_q", "to_k", "to_v"]
                    k_to = "{}.{}.{}".format(prefix, p[x], y)
                    state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
        return state_dict

    def process_clip_state_dict(self, state_dict):
        """Strip the text-encoder prefix and convert the fused text projection to a weight."""
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        if "clip_g.text_projection" in state_dict:
            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_C(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
463
+
464
class Stable_Cascade_B(Stable_Cascade_C):
    """Stable Cascade stage B (decoder); inherits stage C key processing."""

    unet_config = {
        "stable_cascade_stage": 'b',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_B
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 1.0,
    }

    # Stage B has no CLIP vision tower.
    clip_vision_prefix = None

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_B(self, device=device)
        return out
483
+
484
class SD15_instructpix2pix(SD15):
    # InstructPix2Pix on SD1.5: 8 input channels (noise latent + image latent).
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SD15_instructpix2pix(self, device=device)
496
+
497
class SDXL_instructpix2pix(SDXL):
    # InstructPix2Pix on SDXL: 8 input channels (noise latent + image latent).
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
510
+
511
class LotusD(SD20):
    # Lotus dense-prediction model on the SD2 backbone (sequential class conditioning).
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "use_temporal_attention": False,
        "adm_in_channels": 4,
        "in_channels": 4,
    }

    unet_extra_config = {
        "num_classes": 'sequential'
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.Lotus(self, device=device)
526
+
527
class SD3(supported_models_base.BASE):
    """Stable Diffusion 3 (MMDiT, 16-channel latents, up to three text encoders)."""

    unet_config = {
        "in_channels": 16,
        "pos_embed_scaling_factor": None,
    }

    sampling_settings = {
        "shift": 3.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SD3

    memory_usage_factor = 1.2

    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD3(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        """Detect which of clip_l / clip_g / t5xxl are present and build the matching CLIP target."""
        clip_l = False
        clip_g = False
        t5 = False
        pref = self.text_encoder_key_prefix[0]
        if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
            clip_l = True
        if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
            clip_g = True
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        if "dtype_t5" in t5_detect:
            t5 = True

        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
562
+
563
class StableAudio(supported_models_base.BASE):
    """Stable Audio 1.0 (DiT audio model with T5 text encoder)."""

    unet_config = {
        "audio_model": "dit1.0",
    }

    sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}

    unet_extra_config = {}
    latent_format = latent_formats.StableAudio1

    text_encoder_key_prefix = ["text_encoders."]
    vae_key_prefix = ["pretransform.model."]

    def get_model(self, state_dict, prefix="", device=None):
        # The seconds_start / seconds_total conditioner weights are extracted
        # from the checkpoint and handed to the model separately.
        seconds_start_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
        seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
        return model_base.StableAudio1(self, seconds_start_embedder_weights=seconds_start_sd, seconds_total_embedder_weights=seconds_total_sd, device=device)

    def process_unet_state_dict(self, state_dict):
        for k in list(state_dict.keys()):
            if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
                state_dict.pop(k)
        return state_dict

    def process_unet_state_dict_for_saving(self, state_dict):
        replace_prefix = {"": "model.model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
593
+
594
class AuraFlow(supported_models_base.BASE):
    """AuraFlow model (2048-dim conditioning sequence, Pile-T5 text encoder)."""

    unet_config = {
        "cond_seq_dim": 2048,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.73,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SDXL

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.AuraFlow(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
616
+
617
class PixArtAlpha(supported_models_base.BASE):
    """Config for the "pixart_alpha" image model."""

    unet_config = {
        "image_model": "pixart_alpha",
    }

    sampling_settings = {
        "beta_schedule": "sqrt_linear",
        "linear_start": 0.0001,
        "linear_end": 0.02,
        "timesteps": 1000,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SD15

    memory_usage_factor = 0.5

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the PixArt model in eval mode."""
        return model_base.PixArt(self, device=device).eval()

    def clip_target(self, state_dict={}):
        """Pair the PixArt tokenizer with the T5-XXL text encoder."""
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
643
+
644
class PixArtSigma(PixArtAlpha):
    # Same pipeline as PixArtAlpha but detected by the "pixart_sigma" tag and
    # decoded with the SDXL latent format instead of SD15.
    unet_config = {
        "image_model": "pixart_sigma",
    }
    latent_format = latent_formats.SDXL
649
+
650
class HunyuanDiT(supported_models_base.BASE):
    """Config for the "hydit" image model; forces fp32 attention precision."""

    unet_config = {
        "image_model": "hydit",
    }

    unet_extra_config = {
        "attn_precision": torch.float32,
    }

    sampling_settings = {
        "linear_start": 0.00085,
        "linear_end": 0.018,
    }

    latent_format = latent_formats.SDXL

    memory_usage_factor = 1.3

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the HunyuanDiT diffusion model."""
        return model_base.HunyuanDiT(self, device=device)

    def clip_target(self, state_dict={}):
        """Pair the Hydit tokenizer with its text encoder."""
        return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
677
+
678
class HunyuanDiT1(HunyuanDiT):
    # Variant detected by the "hydit1" tag; drops the fp32 attention override
    # and uses a different beta-schedule end point.
    unet_config = {
        "image_model": "hydit1",
    }

    unet_extra_config = {}

    sampling_settings = {
        "linear_start" : 0.00085,
        "linear_end" : 0.03,
    }
689
+
690
class Flux(supported_models_base.BASE):
    """Config for the guidance-distilled "flux" image model (Flux dev family)."""

    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
    }

    sampling_settings = {
    }

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    memory_usage_factor = 2.8

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Flux diffusion model."""
        return model_base.Flux(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Flux CLIP target, detecting the T5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
717
+
718
class FluxInpaint(Flux):
    # Flux inpainting variant: 96 input channels (latent + mask/conditioning).
    # Listed before Flux in the detection table so the extra key wins.
    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
        "in_channels": 96,
    }

    supported_inference_dtypes = [torch.bfloat16, torch.float32]
726
+
727
class FluxSchnell(Flux):
    """Flux variant without guidance embedding, sampled as a flow model."""

    unet_config = {
        "image_model": "flux",
        "guidance_embed": False,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate Flux with the FLOW model type (no distilled guidance)."""
        return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
741
+
742
class GenmoMochi(supported_models_base.BASE):
    """Config for the "mochi_preview" video model."""

    unet_config = {
        "image_model": "mochi_preview",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Mochi

    memory_usage_factor = 2.0  # TODO: refine estimate

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Genmo Mochi video model."""
        return model_base.GenmoMochi(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Mochi CLIP target, detecting the T5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
770
+
771
class LTXV(supported_models_base.BASE):
    """Config for the "ltxv" video model; memory factor scales with attention width."""

    unet_config = {
        "image_model": "ltxv",
    }

    sampling_settings = {
        "shift": 2.37,
    }

    unet_extra_config = {}
    latent_format = latent_formats.LTXV

    memory_usage_factor = 5.5  # TODO: img2vid is about 2x vs txt2vid

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale the 5.5 baseline by the cross-attention width relative to 2048.
        self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the LTXV video model."""
        return model_base.LTXV(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the LTXV CLIP target, detecting the T5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
802
+
803
class HunyuanVideo(supported_models_base.BASE):
    """Config for the "hunyuan_video" model; remaps upstream UNet key names."""

    unet_config = {
        "image_model": "hunyuan_video",
    }

    sampling_settings = {
        "shift": 7.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.HunyuanVideo

    memory_usage_factor = 1.8  # TODO: refine estimate

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the HunyuanVideo model."""
        return model_base.HunyuanVideo(self, device=device)

    def process_unet_state_dict(self, state_dict):
        """Rename upstream checkpoint keys to ComfyUI's layer naming.

        Each substitution is applied to every key, in order; the order matters
        because later patterns can match text produced by earlier ones.
        """
        renames = (
            ("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer."),
            ("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer."),
            ("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer."),
            ("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer."),
            ("_mod.linear.", "_mod.lin."),
            ("_attn_qkv.", "_attn.qkv."),
            ("mlp.fc1.", "mlp.0."),
            ("mlp.fc2.", "mlp.2."),
            ("_attn_q_norm.weight", "_attn.norm.query_norm.scale"),
            ("_attn_k_norm.weight", "_attn.norm.key_norm.scale"),
            (".q_norm.weight", ".norm.query_norm.scale"),
            (".k_norm.weight", ".norm.key_norm.scale"),
            ("_attn_proj.", "_attn.proj."),
            (".modulation.linear.", ".modulation.lin."),
            ("_in.mlp.2.", "_in.out_layer."),
            ("_in.mlp.0.", "_in.in_layer."),
        )
        out_sd = {}
        for k in list(state_dict.keys()):
            key_out = k
            for old, new in renames:
                key_out = key_out.replace(old, new)
            out_sd[key_out] = state_dict[k]
        return out_sd

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-prefix UNet keys for the on-disk checkpoint format."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.model."})

    def clip_target(self, state_dict={}):
        """Build the HunyuanVideo CLIP target, detecting the LLaMA variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, prefix + "llama.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
850
+
851
class HunyuanVideoI2V(HunyuanVideo):
    """Image-to-video HunyuanVideo variant, detected by 33 input channels."""

    unet_config = {
        "image_model": "hunyuan_video",
        "in_channels": 33,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the I2V HunyuanVideo model."""
        return model_base.HunyuanVideoI2V(self, device=device)
860
+
861
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
    """Skyreels image-to-video HunyuanVideo variant, detected by 32 input channels."""

    unet_config = {
        "image_model": "hunyuan_video",
        "in_channels": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Skyreels I2V HunyuanVideo model."""
        return model_base.HunyuanVideoSkyreelsI2V(self, device=device)
870
+
871
class CosmosT2V(supported_models_base.BASE):
    """Config for the Cosmos text-to-video model (16 latent input channels)."""

    unet_config = {
        "image_model": "cosmos",
        "in_channels": 16,
    }

    sampling_settings = {
        "sigma_data": 0.5,
        "sigma_max": 80.0,
        "sigma_min": 0.002,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Cosmos1CV8x8x8

    memory_usage_factor = 1.6  # TODO: refine estimate

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]  # TODO

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Cosmos video model."""
        return model_base.CosmosVideo(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Cosmos CLIP target, detecting the T5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
901
+
902
class CosmosI2V(CosmosT2V):
    """Image-to-video Cosmos variant, detected by the extra (17th) input channel."""

    unet_config = {
        "image_model": "cosmos",
        "in_channels": 17,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate Cosmos with image-to-video conditioning enabled."""
        return model_base.CosmosVideo(self, image_to_video=True, device=device)
911
+
912
class CosmosT2IPredict2(supported_models_base.BASE):
    """Config for the Cosmos Predict2 text-to-image model.

    NOTE(review): unlike the other configs here this class does not override
    vae_key_prefix / text_encoder_key_prefix, so it inherits the SD-style
    defaults from BASE — confirm that is intentional for these checkpoints.
    """

    unet_config = {
        "image_model": "cosmos_predict2",
        "in_channels": 16,
    }

    sampling_settings = {
        "sigma_data": 1.0,
        "sigma_max": 80.0,
        "sigma_min": 0.002,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Wan21

    memory_usage_factor = 1.0

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale a 0.9 baseline by the model width relative to 2048 channels.
        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Cosmos Predict2 model."""
        return model_base.CosmosPredict2(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Cosmos CLIP target, detecting the T5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
943
+
944
class CosmosI2VPredict2(CosmosT2IPredict2):
    """Image-to-video Predict2 variant, detected by the extra (17th) input channel."""

    unet_config = {
        "image_model": "cosmos_predict2",
        "in_channels": 17,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate Cosmos Predict2 with image-to-video conditioning enabled."""
        return model_base.CosmosPredict2(self, image_to_video=True, device=device)
953
+
954
class Lumina2(supported_models_base.BASE):
    """Config for the "lumina2" image model (Gemma-2 2B text encoder)."""

    unet_config = {
        "image_model": "lumina2",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }

    memory_usage_factor = 1.2

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Lumina2 model."""
        return model_base.Lumina2(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Lumina CLIP target; llama_detect also handles Gemma-style weights here."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, prefix + "gemma2_2b.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**detected))
982
+
983
class WAN21_T2V(supported_models_base.BASE):
    """Config for the WAN 2.1 text-to-video model; memory factor scales with width."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
    }

    sampling_settings = {
        "shift": 8.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Wan21

    memory_usage_factor = 1.0

    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale memory estimate by the transformer width relative to 2000.
        self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the WAN 2.1 model."""
        return model_base.WAN21(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the WAN CLIP target, detecting the UMT5-XXL variant from the state dict."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "umt5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
1015
+
1016
class WAN21_I2V(WAN21_T2V):
    """WAN 2.1 image-to-video variant, detected by in_dim == 36."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 36,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate WAN 2.1 with image-to-video conditioning enabled."""
        return model_base.WAN21(self, image_to_video=True, device=device)
1026
+
1027
class WAN21_FunControl2V(WAN21_T2V):
    """WAN 2.1 Fun-Control variant: i2v checkpoint layout (in_dim == 48) run without I2V conditioning."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 48,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate WAN 2.1 with image-to-video conditioning disabled."""
        return model_base.WAN21(self, image_to_video=False, device=device)
1037
+
1038
class WAN21_Camera(WAN21_T2V):
    """WAN 2.1 camera-control variant, detected by model_type "camera" and in_dim == 32."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "camera",
        "in_dim": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the camera-control WAN 2.1 model."""
        return model_base.WAN21_Camera(self, image_to_video=False, device=device)
1048
class WAN21_Vace(WAN21_T2V):
    """WAN 2.1 VACE variant; bumps the inherited memory estimate by 20%."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "vace",
    }

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # VACE adds extra conditioning blocks on top of the base model.
        self.memory_usage_factor = 1.2 * self.memory_usage_factor

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the VACE WAN 2.1 model."""
        return model_base.WAN21_Vace(self, image_to_video=False, device=device)
1061
+
1062
class WAN22_T2V(WAN21_T2V):
    """WAN 2.2 variant, detected by out_dim == 48; uses the Wan22 latent format."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
        "out_dim": 48,
    }

    latent_format = latent_formats.Wan22

    def get_model(self, state_dict, prefix="", device=None):
        # NOTE(review): image_to_video=True on a T2V config looks odd but is
        # kept as-is — confirm against model_base.WAN22 before changing.
        return model_base.WAN22(self, image_to_video=True, device=device)
1074
+
1075
class Hunyuan3Dv2(supported_models_base.BASE):
    """Config for the "hunyuan3d2" shape-generation model (image-conditioned, no text encoder)."""

    unet_config = {
        "image_model": "hunyuan3d2",
    }

    unet_extra_config = {}

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }

    memory_usage_factor = 3.5

    clip_vision_prefix = "conditioner.main_image_encoder.model."
    vae_key_prefix = ["vae."]

    latent_format = latent_formats.Hunyuan3Dv2

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-prefix UNet keys for the on-disk checkpoint format."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model."})

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Hunyuan3D v2 model."""
        return model_base.Hunyuan3Dv2(self, device=device)

    def clip_target(self, state_dict={}):
        # Conditioning comes from the CLIP-vision encoder only; no text encoder.
        return None
1104
+
1105
class Hunyuan3Dv2mini(Hunyuan3Dv2):
    # Mini variant distinguished by depth == 8; listed before Hunyuan3Dv2 in
    # the detection table so the more specific config matches first.
    unet_config = {
        "image_model": "hunyuan3d2",
        "depth": 8,
    }

    latent_format = latent_formats.Hunyuan3Dv2mini
1112
+
1113
class HiDream(supported_models_base.BASE):
    """Config for the "hidream" image model."""

    unet_config = {
        "image_model": "hidream",
    }

    # NOTE(review): an earlier duplicate class attribute set
    # {"shift": 3.0} but was silently shadowed by this empty dict, so the
    # effective sampling settings have always been empty. The dead duplicate
    # is removed here; confirm whether shift=3.0 was actually intended.
    sampling_settings = {
    }

    # memory_usage_factor = 1.2 # TODO

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the HiDream model."""
        out = model_base.HiDream(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        return None # TODO
1141
+
1142
class Chroma(supported_models_base.BASE):
    """Config for the "chroma" image model (Flux-family latents, T5 text encoder)."""

    unet_config = {
        "image_model": "chroma",
    }

    unet_extra_config = {
    }

    sampling_settings = {
        "multiplier": 1.0,
    }

    latent_format = comfy.latent_formats.Flux

    memory_usage_factor = 3.2

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Chroma model."""
        return model_base.Chroma(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the T5 CLIP target (reuses the PixArt tokenizer/encoder wrappers)."""
        prefix = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
1169
+
1170
class ACEStep(supported_models_base.BASE):
    """Config for the "ace" audio model."""

    unet_config = {
        "audio_model": "ace",
    }

    unet_extra_config = {
    }

    sampling_settings = {
        "shift": 3.0,
    }

    latent_format = comfy.latent_formats.ACEAudio

    memory_usage_factor = 0.5

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the ACE-Step audio model."""
        return model_base.ACEStep(self, device=device)

    def clip_target(self, state_dict={}):
        """Pair the ACE T5 tokenizer with its text encoder."""
        return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
1197
+
1198
class Omnigen2(supported_models_base.BASE):
    """Config for the "omnigen2" image model (Qwen-2.5 3B text encoder)."""

    unet_config = {
        "image_model": "omnigen2",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 2.6,
    }

    memory_usage_factor = 1.65  # TODO: refine estimate

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Prefer fp16 when the backend handles it well enough for this model.
        if comfy.model_management.extended_fp16_support():
            self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Omnigen2 model."""
        return model_base.Omnigen2(self, device=device)

    def clip_target(self, state_dict={}):
        """Build the Omnigen2 CLIP target; llama_detect also handles Qwen-style weights here."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, prefix + "qwen25_3b.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**detected))
1231
+
1232
+
1233
# Order matters: model detection walks this list and returns the first config
# whose unet_config matches, so more specific entries must precede more
# generic ones (e.g. FluxInpaint before Flux, Hunyuan3Dv2mini before Hunyuan3Dv2).
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

# SVD is appended last so every other config gets a chance to match first.
models += [SVD_img2vid]
ComfyUI/comfy/supported_models_base.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Comfy
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from . import model_base
21
+ from . import utils
22
+ from . import latent_formats
23
+
24
class ClipTarget:
    """Pairs a tokenizer class with a text-encoder (CLIP) class for later instantiation."""

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        # Extra keyword arguments passed when the encoder is instantiated.
        self.params = {}
29
+
30
class BASE:
    """Base class for supported-model configs.

    Subclasses declare how a checkpoint family is detected (unet_config,
    required_keys), how its state dict is split/renamed, and which model,
    latent format and text encoder to instantiate.
    """

    # Key/value pairs the detected unet config must contain for a match.
    unet_config = {}
    # Extra config merged into unet_config at construction time.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # State-dict keys that must be present for this config to match.
    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
    sampling_settings = {}
    latent_format = latent_formats.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    optimizations = {"fp8": False}

    @classmethod
    def matches(s, unet_config, state_dict=None):
        # A config matches when every declared unet_config entry agrees and,
        # if a state dict is given, every required key is present.
        for k in s.unet_config:
            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in s.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict, prefix=""):
        """Return the prediction type for this checkpoint (default: epsilon)."""
        return model_base.ModelType.EPS

    def inpaint_model(self):
        # More than 4 input channels means the model takes mask/conditioning
        # latents in addition to the image latent.
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        # Copy mutable class-level defaults so instances don't share state.
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the diffusion model (unCLIP variant when noise_aug_config is set)."""
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
        else:
            out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Strip the text-encoder prefix, keeping only matching keys."""
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        return state_dict

    def process_unet_state_dict(self, state_dict):
        """Hook for subclasses to rename UNet keys on load (identity by default)."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Hook for subclasses to rename VAE keys on load (identity by default)."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Re-add the text-encoder prefix for the on-disk format."""
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Re-add the CLIP-vision prefix for the on-disk format (if one is defined)."""
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-add the diffusion-model prefix for the on-disk format."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Re-add the VAE prefix for the on-disk format."""
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record the compute dtype and the dtype to manually cast to (if any)."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
ComfyUI/comfy/t2i_adapter/adapter.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #taken from https://github.com/TencentARC/T2I-Adapter
2
+ import torch
3
+ import torch.nn as nn
4
+ from collections import OrderedDict
5
+
6
+
7
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    Raises ValueError for any other dimensionality.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        return conv_types[dims](*args, **kwargs)
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}")
18
+
19
+
20
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    Raises ValueError for any other dimensionality.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    try:
        return pool_types[dims](*args, **kwargs)
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}")
31
+
32
+
33
+ class Downsample(nn.Module):
34
+ """
35
+ A downsampling layer with an optional convolution.
36
+ :param channels: channels in the inputs and outputs.
37
+ :param use_conv: a bool determining if a convolution is applied.
38
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
39
+ downsampling occurs in the inner-two dimensions.
40
+ """
41
+
42
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.dims = dims
48
+ stride = 2 if dims != 3 else (1, 2, 2)
49
+ if use_conv:
50
+ self.op = conv_nd(
51
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
52
+ )
53
+ else:
54
+ assert self.channels == self.out_channels
55
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
56
+
57
+ def forward(self, x):
58
+ assert x.shape[1] == self.channels
59
+ if not self.use_conv:
60
+ padding = [x.shape[2] % 2, x.shape[3] % 2]
61
+ self.op.padding = padding
62
+
63
+ x = self.op(x)
64
+ return x
65
+
66
+
67
+ class ResnetBlock(nn.Module):
68
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
69
+ super().__init__()
70
+ ps = ksize // 2
71
+ if in_c != out_c or sk == False:
72
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
73
+ else:
74
+ # print('n_in')
75
+ self.in_conv = None
76
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
77
+ self.act = nn.ReLU()
78
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
79
+ if sk == False:
80
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
81
+ else:
82
+ self.skep = None
83
+
84
+ self.down = down
85
+ if self.down == True:
86
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
87
+
88
+ def forward(self, x):
89
+ if self.down == True:
90
+ x = self.down_opt(x)
91
+ if self.in_conv is not None: # edit
92
+ x = self.in_conv(x)
93
+
94
+ h = self.block1(x)
95
+ h = self.act(h)
96
+ h = self.block2(h)
97
+ if self.skep is not None:
98
+ return h + self.skep(x)
99
+ else:
100
+ return h + x
101
+
102
+
103
+ class Adapter(nn.Module):
104
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
105
+ super(Adapter, self).__init__()
106
+ self.unshuffle_amount = 8
107
+ resblock_no_downsample = []
108
+ resblock_downsample = [3, 2, 1]
109
+ self.xl = xl
110
+ if self.xl:
111
+ self.unshuffle_amount = 16
112
+ resblock_no_downsample = [1]
113
+ resblock_downsample = [2]
114
+
115
+ self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
116
+ self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
117
+ self.channels = channels
118
+ self.nums_rb = nums_rb
119
+ self.body = []
120
+ for i in range(len(channels)):
121
+ for j in range(nums_rb):
122
+ if (i in resblock_downsample) and (j == 0):
123
+ self.body.append(
124
+ ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
125
+ elif (i in resblock_no_downsample) and (j == 0):
126
+ self.body.append(
127
+ ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
128
+ else:
129
+ self.body.append(
130
+ ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
131
+ self.body = nn.ModuleList(self.body)
132
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
133
+
134
+ def forward(self, x):
135
+ # unshuffle
136
+ x = self.unshuffle(x)
137
+ # extract features
138
+ features = []
139
+ x = self.conv_in(x)
140
+ for i in range(len(self.channels)):
141
+ for j in range(self.nums_rb):
142
+ idx = i * self.nums_rb + j
143
+ x = self.body[idx](x)
144
+ if self.xl:
145
+ features.append(None)
146
+ if i == 0:
147
+ features.append(None)
148
+ features.append(None)
149
+ if i == 2:
150
+ features.append(None)
151
+ else:
152
+ features.append(None)
153
+ features.append(None)
154
+ features.append(x)
155
+
156
+ features = features[::-1]
157
+
158
+ if self.xl:
159
+ return {"input": features[1:], "middle": features[:1]}
160
+ else:
161
+ return {"input": features}
162
+
163
+
164
+
165
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16: normalize in fp32, cast back."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
172
+
173
+
174
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
178
+
179
+
180
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: self-attention and an MLP, each residual."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        mlp_layers = OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ])
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Lazily move the optional mask to the input's dtype/device.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
201
+
202
+
203
class StyleAdapter(nn.Module):
    """Maps a sequence of image tokens to `num_token` learned style tokens
    projected into the conditioning space (`context_dim`)."""

    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
        super().__init__()

        scale = width ** -0.5  # 1/sqrt(width), standard transformer init scale
        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
        self.num_token = num_token
        # Learned query tokens appended to the input sequence.
        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
        self.ln_post = LayerNorm(width)
        self.ln_pre = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))

    def forward(self, x):
        # x shape [N, HW+1, C]
        # Broadcast the learned style queries across the batch via zero-add.
        style_embedding = self.style_embedding + torch.zeros(
            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
        x = torch.cat([x, style_embedding], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND (nn.MultiheadAttention default layout)
        x = self.transformer_layes(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the appended style-token positions, then project.
        x = self.ln_post(x[:, -self.num_token:, :])
        x = x @ self.proj

        return x
230
+
231
+
232
class ResnetBlock_light(nn.Module):
    """Minimal residual block: two 3x3 convs with a ReLU in between,
    no normalization, identity skip."""

    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)

    def forward(self, x):
        residual = x
        out = self.block2(self.act(self.block1(x)))
        return out + residual
245
+
246
+
247
class extractor(nn.Module):
    """Feature extractor stage: 1x1 conv in -> `nums_rb` light resblocks ->
    1x1 conv out, optionally downsampling the input by 2 first."""

    def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
        super().__init__()
        self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
        blocks = [ResnetBlock_light(inter_c) for _ in range(nums_rb)]
        self.body = nn.Sequential(*blocks)
        self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
        self.down = down
        if self.down == True:
            self.down_opt = Downsample(in_c, use_conv=False)

    def forward(self, x):
        if self.down == True:
            x = self.down_opt(x)
        return self.out_conv(self.body(self.in_conv(x)))
268
+
269
+
270
class Adapter_light(nn.Module):
    """Lightweight T2I-Adapter: pixel-unshuffle the hint image, then run one
    `extractor` stage per channel level, downsampling between stages."""

    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        self.unshuffle_amount = 8
        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
        # Channels per pixel the caller must supply (unshuffle multiplies by 8*8).
        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
        self.channels = channels
        self.nums_rb = nums_rb
        self.body = []
        self.xl = False  # this variant is not for the SDXL layout

        for i in range(len(channels)):
            if i == 0:
                # First stage keeps resolution; later stages downsample by 2.
                self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
            else:
                self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
        self.body = nn.ModuleList(self.body)

    def forward(self, x):
        """Return {"input": features} with one real feature per stage,
        padded with None placeholders and ordered deepest-first."""
        # Trade spatial size for channels so stage 0 sees `cin` channels.
        x = self.unshuffle(x)
        features = []
        for i in range(len(self.channels)):
            x = self.body[i](x)
            # Two None slots per stage: presumably one slot per UNet block,
            # with Nones skipped by the consumer — TODO confirm against caller.
            features.append(None)
            features.append(None)
            features.append(x)

        return {"input": features[::-1]}
ComfyUI/comfy/taesd/taesd.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tiny AutoEncoder for Stable Diffusion
4
+ (DNN for encoding / decoding SD's latent space)
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ import comfy.utils
10
+ import comfy.ops
11
+
12
def conv(n_in, n_out, **kwargs):
    """3x3 'same'-padded conv using comfy's no-weight-init Conv2d
    (weights are expected to come from a checkpoint)."""
    return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
14
+
15
class Clamp(nn.Module):
    """Soft clamp to (-3, 3) via a scaled tanh."""

    def forward(self, x):
        return 3 * torch.tanh(x / 3)
18
+
19
class Block(nn.Module):
    """TAESD residual block: three 3x3 convs with ReLUs, plus a skip that is
    a 1x1 projection when the channel count changes (identity otherwise)."""
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        # Project the skip only when channel counts differ.
        self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))
27
+
28
def Encoder(latent_channels=4):
    """Build the TAESD encoder: 3-channel image -> latents.
    Three stride-2 convs give an overall 8x spatial downscale."""
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, latent_channels),
    )
36
+
37
+
38
def Decoder(latent_channels=4):
    """Build the TAESD decoder: latents -> 3-channel image.
    Clamp soft-limits latents to (-3, 3); three 2x upsamples undo the 8x encode."""
    return nn.Sequential(
        Clamp(), conv(latent_channels, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )
46
+
47
class TAESD(nn.Module):
    """Tiny AutoEncoder for SD latents: fast approximate encode/decode."""
    # Latents are mapped into [0, 1] via x / (2 * magnitude) + shift.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.taesd_encoder = Encoder(latent_channels=latent_channels)
        self.taesd_decoder = Decoder(latent_channels=latent_channels)
        # Per-model latent scale/shift; defaults are identity, normally
        # overwritten when the checkpoint is loaded.
        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
        self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
        if encoder_path is not None:
            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
        if decoder_path is not None:
            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

    def decode(self, x):
        # Undo the VAE scale/shift, decode, then remap output to [-1, 1].
        x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
        x_sample = x_sample.sub(0.5).mul(2)
        return x_sample

    def encode(self, x):
        # Input assumed in [-1, 1]; remapped to [0, 1] before encoding — confirm with callers.
        return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
ComfyUI/comfy/text_encoders/ace.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ from .spiece_tokenizer import SPieceTokenizer
3
+ import comfy.text_encoders.t5
4
+ import os
5
+ import re
6
+ import torch
7
+ import logging
8
+
9
+ from tokenizers import Tokenizer
10
+ from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
11
+
12
# Language code -> token id; presumably the id of the "[xx]" language tag
# in the lyrics vocab — confirm against ace_lyrics_tokenizer/vocab.json.
SUPPORT_LANGUAGES = {
    "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
    "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
    "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
    "ko": 6152, "hi": 6680
}

# Matches structure markers such as "[verse]" / "[chorus]" at line start.
structure_pattern = re.compile(r"\[.*?\]")

DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
22
+
23
+
24
class VoiceBpeTokenizer:
    """BPE tokenizer for ACE-Step lyrics with per-line language handling."""

    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt, lang):
        # Normalize numbers / abbreviations / symbols for the given language.
        txt = multilingual_cleaners(txt, lang)
        return txt

    def encode(self, txt, lang='en'):
        """Encode one cleaned line; prefixes a [lang] tag and encodes spaces as [SPACE]."""
        # lang = lang.split("-")[0] # remove the region
        # self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def get_lang(self, line):
        """Parse a leading two-letter "[xx]" language tag; return (lang, rest)."""
        if line.startswith("[") and line[3:4] == ']':
            lang = line[1:3].lower()
            if lang in SUPPORT_LANGUAGES:
                return lang, line[4:]
        return "en", line

    def __call__(self, string):
        """Tokenize full lyrics (newline-separated); returns {"input_ids": [...]}."""
        lines = string.split("\n")
        lyric_token_idx = [261]  # leading token; presumably a lyrics BOS — confirm against vocab
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]  # separator token for blank lines
                continue

            lang, line = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            # Best effort: if the line transliterates (i.e. it contained kana),
            # treat it as Japanese.
            try:
                line_out = japanese_to_romaji(line)
                if line_out != line:
                    lang = "ja"
                    line = line_out
            except:
                pass

            try:
                # Structure markers like "[chorus]" are always encoded as English.
                if structure_pattern.match(line):
                    token_idx = self.encode(line, "en")
                else:
                    token_idx = self.encode(line, lang)
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
        return {"input_ids": lyric_token_idx}

    @staticmethod
    def from_pretrained(path, **kwargs):
        # Mirrors the transformers-style constructor expected by SDTokenizer.
        return VoiceBpeTokenizer(path, **kwargs)

    def get_vocab(self):
        # Name-based embedding lookup is unsupported for this tokenizer.
        return {}
92
+
93
+
94
class UMT5BaseModel(sd1_clip.SDClipModel):
    """UMT5-base text encoder used by ACE-Step for the prompt."""
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
98
+
99
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for the UMT5-base prompt encoder; the spiece
    model itself comes in via tokenizer_data (usually from the checkpoint)."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)

    def state_dict(self):
        # Round-trips the spiece model so it can be saved with the checkpoint.
        return {"spiece_model": self.tokenizer.serialize_model()}
106
+
107
class LyricsTokenizer(sd1_clip.SDTokenizer):
    """Lyrics tokenizer wrapping VoiceBpeTokenizer with the bundled vocab."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
111
+
112
class AceT5Tokenizer:
    """Composite tokenizer: prompt text -> UMT5, lyrics (kwarg) -> VoiceBpe."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        # Lyrics come in through the "lyrics" kwarg, not the positional text.
        out = {}
        out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
        out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.umt5base.untokenize(token_weight_pair)

    def state_dict(self):
        return self.umt5base.state_dict()
128
+
129
class AceT5Model(torch.nn.Module):
    """ACE-Step conditioner: UMT5 encodes the prompt; lyrics token ids are
    passed through raw as "conditioning_lyrics" (presumably embedded later
    downstream — confirm against the ACE model)."""
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__()
        self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        self.umt5base.set_clip_options(options)

    def reset_clip_options(self):
        self.umt5base.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

        t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)

        # Keep only the token ids (weights dropped), batched as [1, n_tokens].
        lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
        return t5_out, None, {"conditioning_lyrics": lyrics_embeds}

    def load_sd(self, sd):
        return self.umt5base.load_sd(sd)
ComfyUI/comfy/text_encoders/ace_text_cleaners.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # basic text cleaners for the ACE step model
2
+ # I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
3
+ # TODO: more languages than english?
4
+
5
+ import re
6
+
7
def japanese_to_romaji(japanese_text):
    """
    Convert Japanese hiragana and katakana to romaji (Latin alphabet).

    Characters not in the kana table (kanji, latin, 'ー', etc.) pass
    through unchanged.

    Args:
        japanese_text (str): Text containing hiragana and/or katakana characters

    Returns:
        str: The romaji (Latin alphabet) equivalent
    """
    # Dictionary mapping kana characters to their romaji equivalents
    kana_map = {
        # Katakana characters
        'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
        'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
        'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
        'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
        'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
        'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
        'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
        'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
        'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
        'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',

        # Katakana voiced consonants
        'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
        'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
        'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
        'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
        'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',

        # Katakana combinations
        'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
        'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
        'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
        'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
        'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
        'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
        'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
        'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
        'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
        'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
        'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',

        # Katakana small characters and special cases
        'ッ': '',  # Small tsu (doubles the following consonant)
        'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',

        # Katakana extras. NOTE: the multi-char keys below are only reachable
        # if the second char is small ya/yu/yo, so 'ファ' etc. currently never
        # match via the lookup logic — kept for parity with the original table.
        'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
        'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',

        # Hiragana characters
        'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
        'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
        'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
        'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
        'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
        'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
        'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
        'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
        'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
        'わ': 'wa', 'を': 'wo', 'ん': 'n',

        # Hiragana voiced consonants
        'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
        'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
        'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
        'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
        'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',

        # Hiragana combinations
        'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
        'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
        'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
        'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
        'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
        'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
        'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
        'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
        'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
        'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
        'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',

        # Hiragana small characters and special cases
        'っ': '',  # Small tsu (doubles the following consonant)
        'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',

        # Common punctuation and spaces
        '\u3000': ' ',  # Japanese (ideographic) space
        '、': ', ', '。': '. ',
    }

    result = []
    i = 0
    n = len(japanese_text)

    while i < n:
        ch = japanese_text[i]

        # Small tsu geminates (doubles) the next consonant.
        # BUG FIX: the katakana small tsu 'ッ' had been corrupted into a
        # mojibake character, so gemination never applied to katakana text;
        # restored so both 'っ' and 'ッ' are handled.
        if ch in ('っ', 'ッ') and i < n - 1:
            nxt = japanese_text[i + 1]
            if nxt in kana_map:
                next_romaji = kana_map[nxt]
                if next_romaji and next_romaji[0] not in 'aiueon':
                    result.append(next_romaji[0])  # double the consonant
                    i += 1
                    continue

        # Combinations with small ya / yu / yo (e.g. 'きょ' -> 'kyo').
        if i < n - 1 and japanese_text[i + 1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ'):
            combo = japanese_text[i:i + 2]
            if combo in kana_map:
                result.append(kana_map[combo])
                i += 2
                continue

        # Regular character; unknown chars (kanji, romaji, ...) pass through.
        result.append(kana_map.get(ch, ch))
        i += 1

    return ''.join(result)
131
+
132
def number_to_text(num, ordinal=False):
    """Spell out a number in English words.

    Args:
        num: int or float to convert.
        ordinal: accepted for API compatibility but currently unused.

    Returns:
        str: the spelled-out number, or an error string for non-numbers.
    """
    if not isinstance(num, (int, float)):
        return "Input must be a number"

    if num == 0:
        return "zero"

    is_negative = num < 0
    num = abs(num)

    if isinstance(num, float):
        # e.g. 12.34 -> "twelve point three four"
        whole_words = _int_to_text(int(num))
        decimal_digits = str(num).split('.')[1]
        spoken_digits = " ".join(_digit_to_text(int(d)) for d in decimal_digits)
        words = whole_words + " point " + spoken_digits
    else:
        words = _int_to_text(num)

    if is_negative:
        words = "negative " + words
    return words
176
+
177
+
178
+ def _int_to_text(num):
179
+ """Helper function to convert an integer to text"""
180
+
181
+ ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
182
+ "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
183
+ "seventeen", "eighteen", "nineteen"]
184
+
185
+ tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
186
+
187
+ if num < 20:
188
+ return ones[num]
189
+
190
+ if num < 100:
191
+ return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
192
+
193
+ if num < 1000:
194
+ return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
195
+
196
+ if num < 1000000:
197
+ return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
198
+
199
+ if num < 1000000000:
200
+ return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
201
+
202
+ return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
203
+
204
+
205
+ def _digit_to_text(digit):
206
+ """Convert a single digit to text"""
207
+ digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
208
+ return digits[digit]
209
+
210
+
211
# Any run of whitespace (used by collapse_whitespace).
_whitespace_re = re.compile(r"\s+")


# List of (regular expression, replacement) pairs for abbreviations,
# keyed by language; each pattern matches the dotted form, e.g. "mr.".
_abbreviations = {
    "en": [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "misess"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ],
}
239
+ }
240
+
241
+
242
def expand_abbreviations_multilingual(text, lang="en"):
    """Expand dotted abbreviations ("mr.", "dr.", ...) for *lang*.

    Raises KeyError for languages without an entry in _abbreviations
    (callers wrap this in try/except).
    """
    for pattern, full_word in _abbreviations[lang]:
        text = pattern.sub(full_word, text)
    return text
246
+
247
+
248
# (regex, spoken replacement) pairs for common symbols, keyed by language.
# Replacements are padded with spaces; the double spaces this can create
# are collapsed by expand_symbols_multilingual.
_symbols_multilingual = {
    "en": [
        (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
        for x in [
            ("&", " and "),
            ("@", " at "),
            ("%", " percent "),
            ("#", " hash "),
            ("$", " dollar "),
            ("£", " pound "),
            ("°", " degree "),
        ]
    ],
}
262
+
263
+
264
def expand_symbols_multilingual(text, lang="en"):
    """Replace symbols ('&', '@', '%', ...) with their spoken words for *lang*.

    Raises KeyError for languages without an entry in _symbols_multilingual
    (callers wrap this in try/except).
    """
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
        # BUG FIX: this was `text.replace(" ", " ")` — a no-op that contradicted
        # its own comment; collapse the double spaces the padded replacements
        # introduce, as intended.
        text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()
269
+
270
+
271
+ _ordinal_re = {
272
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
273
+ }
274
+ _number_re = re.compile(r"[0-9]+")
275
+ _currency_re = {
276
+ "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
277
+ "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
278
+ "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
279
+ }
280
+
281
+ _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
282
+ _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
283
+ _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
284
+
285
+
286
+ def _remove_commas(m):
287
+ text = m.group(0)
288
+ if "," in text:
289
+ text = text.replace(",", "")
290
+ return text
291
+
292
+
293
+ def _remove_dots(m):
294
+ text = m.group(0)
295
+ if "." in text:
296
+ text = text.replace(".", "")
297
+ return text
298
+
299
+
300
def _expand_decimal_point(m, lang="en"):
    # Both "1,5" (continental) and "1.5" become float 1.5 before spelling out.
    amount = m.group(1).replace(",", ".")
    return number_to_text(float(amount))
303
+
304
+
305
def _expand_currency(m, lang="en", currency="USD"):
    """Spell out a matched currency amount (the currency symbol is dropped).

    NOTE(review): ',' is treated as a decimal point and then every non-[0-9.]
    char is stripped, so amounts with both separators (e.g. "1,000.50")
    produce multiple dots and raise ValueError — swallowed by the caller's
    try/except in expand_numbers_multilingual.
    """
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    full_amount = number_to_text(amount)

    # Per-language connector between whole and fractional parts.
    # NOTE(review): no "hi" entry, so lang="hi" raises KeyError below for
    # whole amounts (also swallowed by the caller).
    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }

    # For whole amounts, drop a trailing decimal connector if one is present.
    if amount.is_integer():
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]

    return full_amount
332
+
333
+
334
def _expand_ordinal(m, lang="en"):
    # NOTE: number_to_text currently ignores ordinal=True, so "1st" -> "one", not "first".
    return number_to_text(int(m.group(1)), ordinal=True)
336
+
337
+
338
def _expand_number(m, lang="en"):
    # Spell out a plain integer match.
    return number_to_text(int(m.group(0)))
340
+
341
+
342
def expand_numbers_multilingual(text, lang="en"):
    """Spell out currencies, decimals, ordinals and plain numbers in *text*.

    Order matters: currencies are handled before decimals and bare numbers
    so "$5" is not consumed as a plain "5" first.
    """
    # Strip group separators first: ',' for en/ru, '.' for other languages.
    if lang in ["en", "ru"]:
        text = re.sub(_comma_number_re, _remove_commas, text)
    else:
        text = re.sub(_dot_number_re, _remove_dots, text)
    try:
        text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
        text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
        text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
    except:
        pass

    text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
    # NOTE: _ordinal_re only has an "en" entry; other languages raise KeyError
    # here (swallowed by the try/except around this call in multilingual_cleaners).
    text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
    text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text
358
+
359
+
360
def lowercase(text):
    """Lowercase helper used by the cleaner pipelines."""
    return text.lower()
362
+
363
+
364
def collapse_whitespace(text):
    """Collapse any whitespace run (spaces, tabs, newlines) to a single space."""
    return re.sub(_whitespace_re, " ", text)
366
+
367
+
368
def multilingual_cleaners(text, lang):
    """Full cleaning pipeline: strip quotes, lowercase, expand numbers,
    abbreviations and symbols, then collapse whitespace.

    Each expansion step is best-effort: any failure (e.g. a language with no
    table entry) is silently ignored and the text passes through unchanged.
    """
    text = text.replace('"', "")
    if lang == "tr":
        # Turkish dotted capitals don't lowercase correctly with plain lower().
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    try:
        text = expand_numbers_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except:
        pass
    text = collapse_whitespace(text)
    return text
389
+
390
+
391
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    # No number/abbreviation/symbol expansion — language-agnostic fallback.
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
ComfyUI/comfy/text_encoders/aura_t5.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ from .spiece_tokenizer import SPieceTokenizer
3
+ import comfy.text_encoders.t5
4
+ import os
5
+
6
class PT5XlModel(sd1_clip.SDClipModel):
    """Pile T5-XL text encoder (AuraFlow); attention-masked, masked tokens zeroed."""
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
10
+
11
class PT5XlTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for Pile T5-XL (pads to at least 256 tokens)."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
15
+
16
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-encoder tokenizer wrapper exposing PT5XlTokenizer as "pile_t5xl"."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
19
+
20
class AuraT5Model(sd1_clip.SD1ClipModel):
    """Single-encoder model wrapper exposing PT5XlModel as "pile_t5xl"."""
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
ComfyUI/comfy/text_encoders/bert.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from comfy.ldm.modules.attention import optimized_attention_for_device
3
+ import comfy.ops
4
+
5
class BertAttention(torch.nn.Module):
    """QKV projections for BERT self-attention; the attention itself is
    computed by the `optimized_attention` callable passed to forward."""

    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        return optimized_attention(self.query(x), self.key(x), self.value(x), self.heads, mask)
22
+
23
class BertOutput(torch.nn.Module):
    """Projection + residual add + LayerNorm (post-LN, BERT style)."""

    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
        self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
        # Dropout from the reference implementation is intentionally omitted (p=0.0).

    def forward(self, x, y):
        projected = self.dense(x)
        return self.LayerNorm(projected + y)
35
+
36
class BertAttentionBlock(torch.nn.Module):
    """Self-attention followed by projection/residual/LayerNorm."""
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        # Attribute names mirror HF BERT ("self", "output") for checkpoint compatibility.
        self.self = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        y = self.self(x, mask, optimized_attention)
        return self.output(y, x)
45
+
46
class BertIntermediate(torch.nn.Module):
    """Feed-forward expansion followed by GELU."""

    def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.dense(x))
54
+
55
+
56
class BertBlock(torch.nn.Module):
    """One full transformer encoder layer: attention sub-layer + feed-forward sub-layer."""

    def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
        self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
        self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        x = self.attention(x, mask, optimized_attention)
        expanded = self.intermediate(x)
        # Project back down with a residual connection onto the attention output.
        return self.output(expanded, x)
67
+
68
class BertEncoder(torch.nn.Module):
    """Stack of BertBlocks; optionally also returns one intermediate hidden state."""

    def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList(
            BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations)
            for _ in range(num_layers)
        )

    def forward(self, x, mask=None, intermediate_output=None):
        """Run all layers; return (final hidden state, cloned intermediate or None)."""
        # Pick the fastest attention kernel for this device/mask combination once.
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        if intermediate_output is not None and intermediate_output < 0:
            # Negative indices count back from the last layer, like list indexing.
            intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for idx, block in enumerate(self.layer):
            x = block(x, mask, optimized_attention)
            if idx == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
86
+
87
class BertEmbeddings(torch.nn.Module):
    """Sum of word, position and token-type embeddings, followed by LayerNorm."""

    def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
        self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
        self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)

        self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
        """Embed input_tokens (or use precomputed `embeds`) and add positional/type terms.

        NOTE(review): the `+=` adds are in place, so a caller-provided `embeds`
        tensor is mutated — confirm callers do not reuse it.
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.word_embeddings(input_tokens, out_dtype=dtype)
        # Position embeddings are sliced to the actual sequence length.
        x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
        if token_type_ids is not None:
            x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
        else:
            # Default: every token belongs to segment 0.
            x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
        x = self.LayerNorm(x)
        return x
108
+
109
+
110
class BertModel_(torch.nn.Module):
    """Bare BERT transformer: embeddings followed by the encoder stack."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        embed_dim = config_dict["hidden_size"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
        self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        """Return (final hidden state, requested intermediate hidden state or None).

        `num_tokens` and `final_layer_norm_intermediate` are accepted for interface
        parity with the other text-encoder models but are unused here.
        """
        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Turn the (batch, seq) 0/1 mask into an additive (batch, 1, seq, seq)
            # attention bias: masked positions become a large negative value.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        x, i = self.encoder(x, mask, intermediate_output)
        return x, i
128
+
129
+
130
class BertModel(torch.nn.Module):
    """Wrapper that namespaces BertModel_ under the `bert.` prefix used by checkpoints."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.bert = BertModel_(config_dict, dtype, device, operations)
        self.num_layers = config_dict["num_hidden_layers"]

    def get_input_embeddings(self):
        return self.bert.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.bert.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        # Pure delegation; keeps the public call signature of BertModel_.
        return self.bert(*args, **kwargs)
ComfyUI/comfy/text_encoders/cosmos.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import comfy.text_encoders.t5
3
+ import os
4
+ from transformers import T5TokenizerFast
5
+
6
+
7
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder (old T5 config variant) used by Cosmos."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        # Model config JSON ships next to this module.
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
        if t5xxl_scaled_fp8 is not None:
            # Copy before rewriting so the caller's dict is not mutated.
            model_options = model_options.copy()
            model_options["scaled_fp8"] = t5xxl_scaled_fp8

        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
16
+
17
class CosmosT5XXL(sd1_clip.SD1ClipModel):
    """Single-encoder Cosmos text model built around the T5XXLModel above."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
20
+
21
+
22
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5 tokenizer for Cosmos; pads to a minimum of 512 tokens, no start token."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
26
+
27
+
28
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch tokenizer wrapper exposing the Cosmos T5 tokenizer as "t5xxl"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
31
+
32
+
33
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a CosmosT5XXL subclass with the given T5 dtype / scaled-fp8 defaults baked in."""
    class CosmosTEModel_(CosmosT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Do not mutate the caller's dict; rebuild with the override.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            # Fall back to the factory-supplied T5 dtype when none is given.
            super().__init__(device=device, dtype=dtype if dtype is not None else dtype_t5, model_options=model_options)
    return CosmosTEModel_
ComfyUI/comfy/text_encoders/flux.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import comfy.text_encoders.t5
3
+ import comfy.text_encoders.sd3_clip
4
+ import comfy.model_management
5
+ from transformers import T5TokenizerFast
6
+ import torch
7
+ import os
8
+
9
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5 tokenizer for Flux; pads to a minimum of 256 tokens, no start token."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
13
+
14
+
15
class FluxTokenizer:
    """Tokenizes a prompt for both Flux text encoders (CLIP-L and T5-XXL)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        return {
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        # Nothing to persist; both sub-tokenizers are stateless here.
        return {}
31
+
32
+
33
class FluxClipModel(torch.nn.Module):
    """Dual text encoder for Flux: CLIP-L supplies the pooled vector, T5-XXL the sequence."""

    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
        self.dtypes = {dtype, dtype_t5}

    def set_clip_options(self, options):
        for encoder in (self.clip_l, self.t5xxl):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        for encoder in (self.clip_l, self.t5xxl):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Return (T5 sequence embedding, CLIP-L pooled embedding)."""
        t5_out = self.t5xxl.encode_token_weights(token_weight_pairs["t5xxl"])[0]
        l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])[1]
        return t5_out, l_pooled

    def load_sd(self, sd):
        # CLIP text-model checkpoints are identified by their layer naming scheme.
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        return self.t5xxl.load_sd(sd)
62
+
63
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a FluxClipModel subclass with the T5 dtype / scaled-fp8 settings pre-bound."""
    class FluxClipModel_(FluxClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Rebuild instead of mutating the caller's dict.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
    return FluxClipModel_
ComfyUI/comfy/text_encoders/genmo.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import comfy.text_encoders.sd3_clip
3
+ import os
4
+ from transformers import T5TokenizerFast
5
+
6
+
7
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    """SD3 T5-XXL variant with attention masking forced on for Mochi."""

    def __init__(self, **kwargs):
        # Always enable the attention mask, regardless of what the caller passed.
        kwargs["attention_mask"] = True
        super().__init__(**kwargs)
11
+
12
+
13
class MochiT5XXL(sd1_clip.SD1ClipModel):
    """Single-encoder Mochi text model built around the masked T5XXLModel above."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
16
+
17
+
18
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5 tokenizer for Mochi; pads to a minimum of 256 tokens, no start token."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
22
+
23
+
24
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch tokenizer wrapper exposing the Mochi T5 tokenizer as "t5xxl"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
27
+
28
+
29
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a MochiT5XXL subclass with the T5 dtype / scaled-fp8 defaults baked in."""
    class MochiTEModel_(MochiT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Rebuild instead of mutating the caller's dict.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            # Fall back to the factory-supplied T5 dtype when none is given.
            super().__init__(device=device, dtype=dtype if dtype is not None else dtype_t5, model_options=model_options)
    return MochiTEModel_
ComfyUI/comfy/text_encoders/hidream.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import hunyuan_video
2
+ from . import sd3_clip
3
+ from comfy import sd1_clip
4
+ from comfy import sdxl_clip
5
+ import comfy.model_management
6
+ import torch
7
+ import logging
8
+
9
+
10
class HiDreamTokenizer:
    """Tokenizes a prompt for all four HiDream encoders (clip_l, clip_g, t5xxl, llama)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
        self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        out = {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
        }
        t5_tokens = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        # Keep only the first chunk: the t5xxl branch is capped at 128 tokens.
        out["t5xxl"] = [t5_tokens[0]]
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        # Nothing to persist; all sub-tokenizers are stateless here.
        return {}
31
+
32
+
33
class HiDreamTEModel(torch.nn.Module):
    """Combined text encoder for HiDream with optional CLIP-L, CLIP-G, T5-XXL and Llama-3 branches.

    encode_token_weights returns (t5 sequence embedding, concatenated CLIP pooled
    embedding, extras containing "conditioning_llama3"). Disabled branches are
    replaced by zero tensors so downstream shapes stay consistent.
    """

    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        if llama:
            dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
            if "vocab_size" not in model_options:
                # Bug fix: copy before writing. The original mutated the caller's
                # dict (frequently the shared default {}), leaking "vocab_size"
                # into every later construction that reused the same dict.
                model_options = model_options.copy()
                model_options["vocab_size"] = 128256
            self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
            self.dtypes.add(dtype_llama)
        else:
            self.llama = None

        logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))

    def set_clip_options(self, options):
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)
        if self.llama is not None:
            self.llama.set_clip_options(options)

    def reset_clip_options(self):
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()
        if self.llama is not None:
            self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        token_weight_pairs_llama = token_weight_pairs["llama"]
        pooled = None
        extra = {}

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
            if self.clip_l is not None:
                l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)[1]
            else:
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)[1]
            else:
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        if self.t5xxl is not None:
            t5_out = self.t5xxl.encode_token_weights(token_weight_pairs_t5)[0]
        else:
            t5_out = None

        if self.llama is not None:
            ll_out = self.llama.encode_token_weights(token_weight_pairs_llama)[0]
            # Drop the first entry along dim 1 (presumably the start-token position
            # of the layer="all" output — TODO confirm against LLAMAModel).
            ll_out = ll_out[:, 1:]
        else:
            ll_out = None

        # Zero placeholders keep output shapes stable when a branch is disabled.
        if t5_out is None:
            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        extra["conditioning_llama3"] = ll_out
        return t5_out, pooled, extra

    def load_sd(self, sd):
        # Route the checkpoint to the right encoder by its characteristic keys.
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
            return self.t5xxl.load_sd(sd)
        else:
            return self.llama.load_sd(sd)
+
144
+
145
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
    """Build a HiDreamTEModel subclass with encoder selection and dtypes pre-bound."""
    class HiDreamTEModel_(HiDreamTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            overrides = {}
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                overrides["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                overrides["llama_scaled_fp8"] = llama_scaled_fp8
            if overrides:
                # One rebuild covers both overrides without touching the caller's dict.
                model_options = {**model_options, **overrides}
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama,
                             dtype_t5=dtype_t5, dtype_llama=dtype_llama,
                             device=device, dtype=dtype, model_options=model_options)
    return HiDreamTEModel_
ComfyUI/comfy/text_encoders/hunyuan_video.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import comfy.model_management
3
+ import comfy.text_encoders.llama
4
+ from transformers import LlamaTokenizerFast
5
+ import torch
6
+ import os
7
+ import numbers
8
+
9
+
10
def llama_detect(state_dict, prefix=""):
    """Inspect a Llama state dict and report detected weight dtypes.

    Returns a dict with "dtype_llama" (dtype of the final norm weight, when
    present) and "llama_scaled_fp8" (dtype of the scaled-fp8 marker tensor,
    when present).
    """
    out = {}
    norm_key = "{}model.norm.weight".format(prefix)
    if norm_key in state_dict:
        out["dtype_llama"] = state_dict[norm_key].dtype

    scaled_fp8_key = "{}scaled_fp8".format(prefix)
    if scaled_fp8_key in state_dict:
        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

    return out
21
+
22
+
23
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
    """Llama-3 tokenizer wrapper; start token only, configurable pad token and min length."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
27
+
28
class LLAMAModel(sd1_clip.SDClipModel):
    """Llama text encoder wrapped in the SDClipModel interface (hidden-layer output)."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
        if llama_scaled_fp8 is not None:
            # Copy before rewriting so the caller's dict stays untouched.
            model_options = model_options.copy()
            model_options["scaled_fp8"] = llama_scaled_fp8

        # Only a vocab_size override goes into the inline config; everything
        # else comes from the model class defaults.
        textmodel_json_config = {}
        vocab_size = model_options.get("vocab_size", None)
        if vocab_size is not None:
            textmodel_json_config["vocab_size"] = vocab_size

        model_options = {**model_options, "model_name": "llama"}
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
42
+
43
+
44
class HunyuanVideoTokenizer:
    """Tokenizer pairing CLIP-L with a Llama-3 chat-template prompt for Hunyuan Video."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        # System+user chat template the prompt text is substituted into.
        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
        """Tokenize for both branches; splice image embeddings into the llama stream.

        Token id 128257 acts as an image placeholder (presumably a reserved
        special token — TODO confirm against the llama tokenizer config); each
        occurrence is replaced, in order, by one row of `image_embeds`.
        """
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

        if llama_template is None:
            llama_text = self.llama_template.format(text)
        else:
            llama_text = llama_template.format(text)
        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
        embed_count = 0
        for r in llama_text_tokens:
            for i in range(len(r)):
                if r[i][0] == 128257:
                    if image_embeds is not None and embed_count < image_embeds.shape[0]:
                        # Keep the original weight tuple tail; swap the token id
                        # for an embedding-dict payload.
                        r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
                        embed_count += 1
        out["llama"] = llama_text_tokens
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        return {}
74
+
75
+
76
class HunyuanVideoClipModel(torch.nn.Module):
    """CLIP-L + Llama text encoder pair for Hunyuan Video.

    encode_token_weights trims the Llama hidden states down to the user section
    of the chat template (dropping the fixed template prefix) and returns
    (llama hidden states, clip_l pooled output, llama extra outputs).
    """

    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
        self.dtypes = set([dtype, dtype_llama])

    def set_clip_options(self, options):
        self.clip_l.set_clip_options(options)
        self.llama.set_clip_options(options)

    def reset_clip_options(self):
        self.clip_l.reset_clip_options()
        self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_llama = token_weight_pairs["llama"]

        llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)

        template_end = 0            # index just past the user header of the chat template
        extra_template_end = 0      # extra positions contributed by image embeds after the header
        extra_sizes = 0             # cumulative positions inserted by earlier image embeds
        user_end = 9999999999999    # sentinel "keep to the end" until an eot token is seen
        images = []                 # (start, end, interleave) slices of image positions

        tok_pairs = token_weight_pairs_llama[0]
        for i, v in enumerate(tok_pairs):
            elem = v[0]
            if not torch.is_tensor(elem):
                if isinstance(elem, numbers.Integral):
                    # 128006/882/128007 presumably encode
                    # "<|start_header_id|>user<|end_header_id|>" and 128009
                    # "<|eot_id|>" — TODO confirm against the llama tokenizer.
                    if elem == 128006:
                        if tok_pairs[i + 1][0] == 882:
                            if tok_pairs[i + 2][0] == 128007:
                                template_end = i + 2
                                user_end = -1
                    if elem == 128009 and user_end == -1:
                        user_end = i + 1
                else:
                    # Non-int, non-tensor entries are embedding dicts from the tokenizer.
                    if elem.get("original_type") == "image":
                        elem_size = elem.get("data").shape[0]
                        if template_end > 0:
                            if user_end == -1:
                                extra_template_end += elem_size - 1
                        else:
                            image_start = i + extra_sizes
                            image_end = i + elem_size + extra_sizes
                            images.append((image_start, image_end, elem.get("image_interleave", 1)))
                            extra_sizes += elem_size - 1

        if llama_out.shape[1] > (template_end + 2):
            # Token 271 presumably tokenizes the "\n\n" right after the user
            # header — skip it too when present. TODO confirm.
            if tok_pairs[template_end + 1][0] == 271:
                template_end += 2
        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
        if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
            llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements

        if len(images) > 0:
            # Prepend the (possibly interleaved) image-token hidden states.
            out = []
            for i in images:
                out.append(llama_out[:, i[0]: i[1]: i[2]])
            llama_output = torch.cat(out + [llama_output], dim=1)

        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return llama_output, l_pooled, llama_extra_out

    def load_sd(self, sd):
        # CLIP checkpoints are identified by their text_model layer naming.
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.llama.load_sd(sd)
150
+
151
+
152
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
    """Build a HunyuanVideoClipModel subclass with llama dtype / scaled-fp8 pre-bound."""
    class HunyuanVideoClipModel_(HunyuanVideoClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                # Rebuild instead of mutating the caller's dict.
                model_options = {**model_options, "llama_scaled_fp8": llama_scaled_fp8}
            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HunyuanVideoClipModel_
ComfyUI/comfy/text_encoders/hydit.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ from transformers import BertTokenizer
3
+ from .spiece_tokenizer import SPieceTokenizer
4
+ from .bert import BertModel
5
+ import comfy.text_encoders.t5
6
+ import os
7
+ import torch
8
+
9
class HyditBertModel(sd1_clip.SDClipModel):
    """Chinese RoBERTa/BERT text encoder used by HunyuanDiT."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
        model_options = {**model_options, "model_name": "hydit_clip"}
        # 101/102 are BERT's [CLS]/[SEP] token ids; 0 is [PAD].
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
14
+
15
class HyditBertTokenizer(sd1_clip.SDTokenizer):
    """BERT tokenizer for the HunyuanDiT CLIP branch (max 512, min 77 tokens).

    NOTE(review): `embedding_directory` is accepted but not forwarded to the
    base class — confirm whether textual-inversion embeddings are intentionally
    unsupported for this branch.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
19
+
20
+
21
class MT5XLModel(sd1_clip.SDClipModel):
    """mT5-XL text encoder used as the second HunyuanDiT branch."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
        model_options = {**model_options, "model_name": "mt5xl"}
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
26
+
27
class MT5XLTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for mT5-XL; the spiece model arrives via tokenizer_data."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
        # The serialized sentencepiece model is passed in instead of a file path.
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)

    def state_dict(self):
        # Persist the sentencepiece model so the tokenizer can be restored later.
        return {"spiece_model": self.tokenizer.serialize_model()}
35
+
36
class HyditTokenizer:
    """Joint tokenizer for HunyuanDiT: Chinese BERT CLIP branch + mT5-XL branch."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        mt5_spiece = tokenizer_data.get("mt5xl.spiece_model", None)
        self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
        self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_spiece}, embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        return {
            "hydit_clip": self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs),
            "mt5xl": self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        return self.hydit_clip.untokenize(token_weight_pair)

    def state_dict(self):
        # Persist the mT5 sentencepiece model so this tokenizer can be rebuilt.
        return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
53
+
54
class HyditModel(torch.nn.Module):
    """Paired text encoders for HunyuanDiT (Chinese BERT + mT5-XL)."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
        self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)

        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def encode_token_weights(self, token_weight_pairs):
        """Return (BERT hidden states, BERT pooled, extras incl. mT5 conditioning)."""
        bert_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
        mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
        extras = {
            "attention_mask": bert_out[2]["attention_mask"],
            "conditioning_mt5xl": mt5_out[0],
            "attention_mask_mt5xl": mt5_out[2]["attention_mask"],
        }
        return bert_out[0], bert_out[1], extras

    def load_sd(self, sd):
        # BERT checkpoints are recognized by their layer-0 attention key.
        if "bert.encoder.layer.0.attention.self.query.weight" in sd:
            return self.hydit_clip.load_sd(sd)
        return self.mt5xl.load_sd(sd)

    def set_clip_options(self, options):
        self.hydit_clip.set_clip_options(options)
        self.mt5xl.set_clip_options(options)

    def reset_clip_options(self):
        self.hydit_clip.reset_clip_options()
        self.mt5xl.reset_clip_options()
ComfyUI/comfy/text_encoders/hydit_clip.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "directionality": "bidi",
10
+ "eos_token_id": 2,
11
+ "hidden_act": "gelu",
12
+ "hidden_dropout_prob": 0.1,
13
+ "hidden_size": 1024,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 4096,
16
+ "layer_norm_eps": 1e-12,
17
+ "max_position_embeddings": 512,
18
+ "model_type": "bert",
19
+ "num_attention_heads": 16,
20
+ "num_hidden_layers": 24,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "pooler_fc_size": 768,
24
+ "pooler_num_attention_heads": 12,
25
+ "pooler_num_fc_layers": 3,
26
+ "pooler_size_per_head": 128,
27
+ "pooler_type": "first_token_transform",
28
+ "position_embedding_type": "absolute",
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.22.1",
31
+ "type_vocab_size": 2,
32
+ "use_cache": true,
33
+ "vocab_size": 47020
34
+ }
35
+
ComfyUI/comfy/text_encoders/llama.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Any
5
+
6
+ from comfy.ldm.modules.attention import optimized_attention_for_device
7
+ import comfy.model_management
8
+ import comfy.ldm.common_dit
9
+
10
+ import comfy.model_management
11
+
12
@dataclass
class Llama2Config:
    """Hyper-parameters for a Llama-style decoder-only transformer."""
    vocab_size: int = 128320
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8  # < num_attention_heads => grouped-query attention
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    transformer_type: str = "llama"  # selects TransformerBlock in Llama2_
    # NOTE: the attributes below carry no type annotation, so the dataclass
    # machinery treats them as plain class constants — they are NOT __init__
    # parameters and cannot be overridden via Llama2Config(**config_dict).
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
28
+
29
@dataclass
class Qwen25_3BConfig:
    """Hyper-parameters for the Qwen 2.5 3B decoder (qkv projections carry biases)."""
    vocab_size: int = 151936
    hidden_size: int = 2048
    intermediate_size: int = 11008
    num_hidden_layers: int = 36
    num_attention_heads: int = 16
    num_key_value_heads: int = 2  # grouped-query attention: 8 query heads per KV head
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    # Un-annotated class constants — not dataclass fields (see Llama2Config).
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = True
45
+
46
@dataclass
class Gemma2_2B_Config:
    """Hyper-parameters for the Gemma 2 2B decoder (gemma2 block layout, (1+w) RMSNorm)."""
    vocab_size: int = 256000
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    transformer_type: str = "gemma2"  # selects TransformerBlockGemma2 + input scaling
    # Un-annotated class constants — not dataclass fields (see Llama2Config).
    head_dim = 256
    rms_norm_add = True  # RMSNorm applies weight as (1 + w)
    mlp_activation = "gelu_pytorch_tanh"
    qkv_bias = False
62
+
63
class RMSNorm(nn.Module):
    """Root-mean-square normalization; with ``add=True`` the weight acts as (1 + w),
    the convention used here by the gemma2 configuration."""

    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.add = add
        self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        weight = self.weight + 1.0 if self.add else self.weight
        return comfy.ldm.common_dit.rms_norm(x, weight, self.eps)
76
+
77
+
78
+
79
def rotate_half(x):
    """Rotate the last dimension by swapping its halves and negating the second:
    (a, b) -> (-b, a). Used by the rotary position embedding."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
84
+
85
+
86
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
    """Precompute rotary-embedding tables.

    Returns a ``(cos, sin)`` pair, each of shape (1, seq_len, head_dim); the
    per-pair frequency vector is duplicated along the last dim so it lines up
    with the half-rotation layout used by apply_rope().
    """
    exponent = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponent)

    positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
    freqs = (inv_freq[None, :, None].float().expand(positions.shape[0], -1, 1) @ positions[:, None, :].float()).transpose(1, 2)
    table = torch.cat((freqs, freqs), dim=-1)
    return (table.cos(), table.sin())
99
+
100
+
101
def apply_rope(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors.

    ``freqs_cis`` is the (cos, sin) pair from precompute_freqs_cis(); a head
    axis is inserted so the tables broadcast over attention heads.
    """
    cos = freqs_cis[0].unsqueeze(1)
    sin = freqs_cis[1].unsqueeze(1)

    def _rotate(t):
        # Inlined rotate_half(): (a, b) -> (-b, a) on the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    return (xq * cos) + (_rotate(xq) * sin), (xk * cos) + (_rotate(xk) * sin)
107
+
108
+
109
class Attention(nn.Module):
    """Multi-head self-attention with grouped-query KV heads and rotary embeddings."""
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.head_dim = config.head_dim
        self.inner_size = self.num_heads * self.head_dim

        # `ops` supplies the Linear implementation (e.g. weight-casting ops);
        # fall back to plain torch.nn when none is provided.
        ops = ops or nn
        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        """Attend over `hidden_states` (batch, seq, hidden) and return the same shape."""
        batch_size, seq_length, _ = hidden_states.shape
        xq = self.q_proj(hidden_states)
        xk = self.k_proj(hidden_states)
        xv = self.v_proj(hidden_states)

        # (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)
        xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        # Grouped-query attention: duplicate each KV head so K/V match the query head count.
        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(output)
148
+
149
class MLP(nn.Module):
    """Gated feed-forward block (SwiGLU/GeGLU style): down(act(gate(x)) * up(x)).

    The activation is chosen from ``config.mlp_activation``; an unsupported
    value raises immediately instead of failing later in forward().
    """

    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        ops = ops or nn
        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
        if config.mlp_activation == "silu":
            self.activation = torch.nn.functional.silu
        elif config.mlp_activation == "gelu_pytorch_tanh":
            self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
        else:
            # Fail fast: previously an unknown activation left self.activation
            # unset and only surfaced as an AttributeError at forward() time.
            raise ValueError("unsupported mlp_activation: {}".format(config.mlp_activation))

    def forward(self, x):
        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
163
+
164
class TransformerBlock(nn.Module):
    """Standard pre-norm decoder block: attention then MLP, each with a residual."""

    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Attention sub-layer with residual connection.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(x),
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = x + attn_out

        # Feed-forward sub-layer with residual connection.
        return x + self.mlp(self.post_attention_layernorm(x))
197
+
198
class TransformerBlockGemma2(nn.Module):
    """Gemma-2 style decoder block: each sub-layer is sandwiched between a
    pre-norm and a post-norm before its residual is added."""

    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)

        def make_norm():
            return RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()
        self.pre_feedforward_layernorm = make_norm()
        self.post_feedforward_layernorm = make_norm()

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Attention sub-layer: pre-norm -> attention -> post-norm -> residual.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(x),
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = x + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-layer: pre-norm -> MLP -> post-norm -> residual.
        ff_out = self.post_feedforward_layernorm(self.mlp(self.pre_feedforward_layernorm(x)))
        return x + ff_out
236
+
237
class Llama2_(nn.Module):
    """Decoder-only transformer backbone shared by the Llama/Qwen/Gemma text encoders."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = ops.Embedding(
            config.vocab_size,
            config.hidden_size,
            device=device,
            dtype=dtype
        )
        # gemma2 uses a different block layout and scales embeddings on input.
        if self.config.transformer_type == "gemma2":
            transformer = TransformerBlockGemma2
            self.normalize_in = True
        else:
            transformer = TransformerBlock
            self.normalize_in = False

        self.layers = nn.ModuleList([
            transformer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        # lm_head intentionally omitted: only hidden states are needed for conditioning.
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        """Run the decoder stack.

        Either token ids `x` or precomputed `embeds` may be given (embeds win).
        `intermediate_output` selects an extra hidden state to return: a layer
        index (negative counts from the end) or "all" for every layer stacked
        along dim 1. Returns (final_hidden, intermediate_or_None).
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        if self.normalize_in:
            # gemma2 convention: scale embeddings by sqrt(hidden_size).
            x *= self.config.hidden_size ** 0.5

        freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                         x.shape[1],
                                         self.config.rope_theta,
                                         device=x.device)

        # Convert the (batch, seq) padding mask to an additive (-inf) bias.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        # Always apply a causal mask, combined with the padding mask when present.
        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        if intermediate_output is not None:
            if intermediate_output == "all":
                all_intermediate = []
                intermediate_output = None
            elif intermediate_output < 0:
                # Negative index counts back from the last layer.
                intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                # Capture the layer *input*; the final output is appended after the loop.
                all_intermediate.append(x.unsqueeze(1).clone())
            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )
            if i == intermediate_output:
                intermediate = x.clone()

        x = self.norm(x)
        if all_intermediate is not None:
            all_intermediate.append(x.unsqueeze(1).clone())

        if all_intermediate is not None:
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.norm(intermediate)

        return x, intermediate
321
+
322
class BaseLlama:
    """Mixin forwarding the embedding accessors and forward() to the wrapped `self.model`."""
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.model.embed_tokens = embeddings

    def forward(self, input_ids, *args, **kwargs):
        return self.model(input_ids, *args, **kwargs)
331
+
332
+
333
class Llama2(BaseLlama, torch.nn.Module):
    """Llama-family text encoder wrapper built from a config dict."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        cfg = Llama2Config(**config_dict)
        self.dtype = dtype
        self.num_layers = cfg.num_hidden_layers
        self.model = Llama2_(cfg, device=device, dtype=dtype, ops=operations)
342
class Qwen25_3B(BaseLlama, torch.nn.Module):
    """Qwen 2.5 3B text encoder wrapper built from a config dict."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        cfg = Qwen25_3BConfig(**config_dict)
        self.dtype = dtype
        self.num_layers = cfg.num_hidden_layers
        self.model = Llama2_(cfg, device=device, dtype=dtype, ops=operations)
350
+
351
class Gemma2_2B(BaseLlama, torch.nn.Module):
    """Gemma 2 2B text encoder wrapper built from a config dict."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        cfg = Gemma2_2B_Config(**config_dict)
        self.dtype = dtype
        self.num_layers = cfg.num_hidden_layers
        self.model = Llama2_(cfg, device=device, dtype=dtype, ops=operations)
ComfyUI/comfy/text_encoders/long_clipl.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
def model_options_long_clip(sd, tokenizer_data, model_options):
    """Detect a long-CLIP checkpoint and extend the max token length accordingly.

    Looks for a position-embedding weight under the "clip_l.", "clip_g." or
    bare "text_model." prefix and, when found, copies `tokenizer_data` /
    `model_options` with the model's max_position_embeddings / max_length set
    to the embedding's row count. Returns (tokenizer_data, model_options),
    unchanged when no position embedding is found.

    Fixes: the previous flow labeled clip_g-prefixed checkpoints as "clip_l"
    (the branch assignments were crossed) and raised UnboundLocalError for a
    bare-prefix checkpoint matching neither layer-count marker.
    """
    model_name = None
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is not None:
        model_name = "clip_l"
    else:
        w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
        if w is not None:
            model_name = "clip_g"

    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
        if w is not None:
            # Un-prefixed checkpoint: tell clip_g apart from clip_l by depth
            # (clip_g has 32 layers, clip_l only 12); default to clip_l.
            if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
                model_name = "clip_g"
            else:
                model_name = "clip_l"

    if w is not None:
        tokenizer_data = tokenizer_data.copy()
        model_options = model_options.copy()
        # NOTE: shallow copy — a pre-existing "model_config" dict is mutated in place.
        model_config = model_options.get("model_config", {})
        model_config["max_position_embeddings"] = w.shape[0]
        model_options["{}_model_config".format(model_name)] = model_config
        tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
    return tokenizer_data, model_options
ComfyUI/comfy/text_encoders/lt.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import os
3
+ from transformers import T5TokenizerFast
4
+ import comfy.text_encoders.genmo
5
+
6
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer with a 128-token minimum length (LTX-Video flavour)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(here, "t5_tokenizer"),
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=128,  # pad to 128?
            tokenizer_data=tokenizer_data,
        )
10
+
11
+
12
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch wrapper exposing the T5-XXL tokenizer under the 't5xxl' name."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="t5xxl",
            tokenizer=T5XXLTokenizer,
        )
15
+
16
+
17
def ltxv_te(*args, **kwargs):
    """LTX-Video reuses the Mochi (genmo) T5-XXL text-encoder factory unchanged."""
    return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
ComfyUI/comfy/text_encoders/lumina2.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ from .spiece_tokenizer import SPieceTokenizer
3
+ import comfy.text_encoders.llama
4
+
5
+
6
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for the Gemma-2 2B encoder (BOS added, no EOS)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        spiece_model = tokenizer_data.get("spiece_model", None)
        super().__init__(
            spiece_model,
            pad_with_end=False,
            embedding_size=2304,
            embedding_key='gemma2_2b',
            tokenizer_class=SPieceTokenizer,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            tokenizer_args={"add_bos": True, "add_eos": False},
            tokenizer_data=tokenizer_data,
        )

    def state_dict(self):
        """Persist the sentencepiece model so it can be reloaded with the weights."""
        return {"spiece_model": self.tokenizer.serialize_model()}
13
+
14
+
15
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch wrapper exposing the Gemma-2 tokenizer under 'gemma2_2b'."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="gemma2_2b",
            tokenizer=Gemma2BTokenizer,
        )
18
+
19
+
20
class Gemma2_2BModel(sd1_clip.SDClipModel):
    """Gemma-2 2B encoder; by default returns the penultimate hidden layer."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma2_2B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
23
+
24
+
25
class LuminaModel(sd1_clip.SD1ClipModel):
    """Single-branch text-encoder container for Lumina (Gemma-2 2B)."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="gemma2_2b",
            clip_model=Gemma2_2BModel,
            model_options=model_options,
        )
28
+
29
+
30
def te(dtype_llama=None, llama_scaled_fp8=None):
    """Build a Lumina text-encoder class with baked-in dtype / scaled-fp8 settings."""
    class LuminaTEModel_(LuminaModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            opts = model_options
            if llama_scaled_fp8 is not None and "scaled_fp8" not in opts:
                opts = opts.copy()
                opts["scaled_fp8"] = llama_scaled_fp8
            # An explicit dtype_llama overrides the caller-supplied dtype.
            super().__init__(device=device, dtype=dtype_llama if dtype_llama is not None else dtype, model_options=opts)
    return LuminaTEModel_
+ return LuminaTEModel_
ComfyUI/comfy/text_encoders/mt5_config_xl.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "d_ff": 5120,
3
+ "d_kv": 64,
4
+ "d_model": 2048,
5
+ "decoder_start_token_id": 0,
6
+ "dropout_rate": 0.1,
7
+ "eos_token_id": 1,
8
+ "dense_act_fn": "gelu_pytorch_tanh",
9
+ "initializer_factor": 1.0,
10
+ "is_encoder_decoder": true,
11
+ "is_gated_act": true,
12
+ "layer_norm_epsilon": 1e-06,
13
+ "model_type": "mt5",
14
+ "num_decoder_layers": 24,
15
+ "num_heads": 32,
16
+ "num_layers": 24,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "relative_attention_num_buckets": 32,
20
+ "tie_word_embeddings": false,
21
+ "vocab_size": 250112
22
+ }
ComfyUI/comfy/text_encoders/omnigen2.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Tokenizer
2
+ from comfy import sd1_clip
3
+ import comfy.text_encoders.llama
4
+ import os
5
+
6
+
7
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
    """Qwen 2.5 tokenizer (no start/end tokens, pads with token 151643)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(here, "qwen25_tokenizer"),
            pad_with_end=False,
            embedding_size=2048,
            embedding_key='qwen25_3b',
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=151643,
            tokenizer_data=tokenizer_data,
        )
11
+
12
+
13
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer that wraps prompts in Omnigen2's Qwen chat template before tokenizing."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
        self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        """Format *text* into the chat template (caller may override it) and tokenize."""
        template = self.llama_template if llama_template is None else llama_template
        return super().tokenize_with_weights(template.format(text), return_word_ids=return_word_ids, **kwargs)
24
+
25
class Qwen25_3BModel(sd1_clip.SDClipModel):
    """Qwen 2.5 3B encoder returning the last hidden layer with attention masks."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": 151643},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Qwen25_3B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
28
+
29
+
30
class Omnigen2Model(sd1_clip.SD1ClipModel):
    """Single-branch text-encoder container for Omnigen2 (Qwen 2.5 3B)."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen25_3b",
            clip_model=Qwen25_3BModel,
            model_options=model_options,
        )
33
+
34
+
35
def te(dtype_llama=None, llama_scaled_fp8=None):
    """Build an Omnigen2 text-encoder class with baked-in dtype / scaled-fp8 settings."""
    class Omnigen2TEModel_(Omnigen2Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            opts = model_options
            if llama_scaled_fp8 is not None and "scaled_fp8" not in opts:
                opts = opts.copy()
                opts["scaled_fp8"] = llama_scaled_fp8
            # An explicit dtype_llama overrides the caller-supplied dtype.
            super().__init__(device=device, dtype=dtype_llama if dtype_llama is not None else dtype, model_options=opts)
    return Omnigen2TEModel_
+ return Omnigen2TEModel_
ComfyUI/comfy/text_encoders/pixart_t5.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from comfy import sd1_clip
4
+ import comfy.text_encoders.t5
5
+ import comfy.text_encoders.sd3_clip
6
+ from comfy.sd1_clip import gen_empty_tokens
7
+
8
+ from transformers import T5TokenizerFast
9
+
10
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    """T5-XXL variant for PixArt; only empty-prompt token generation differs.

    The redundant ``__init__`` that merely forwarded ``**kwargs`` to the base
    class has been removed — inheritance already provides that behavior.
    """

    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
        # PixArt expects the negative prompt to be all pad tokens, so drop the
        # end token before generating the empty sequence.
        special_tokens = special_tokens.copy()
        special_tokens.pop("end")
        return gen_empty_tokens(special_tokens, *args, **kwargs)
19
+
20
class PixArtT5XXL(sd1_clip.SD1ClipModel):
    """Single-branch text-encoder container for PixArt (T5-XXL)."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="t5xxl",
            clip_model=T5XXLModel,
            model_options=model_options,
        )
23
+
24
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for PixArt: variable length, no padding (min_length=1)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(here, "t5_tokenizer"),
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,  # no padding
            tokenizer_data=tokenizer_data,
        )
28
+
29
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch wrapper exposing the T5-XXL tokenizer under 't5xxl'."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="t5xxl",
            tokenizer=T5XXLTokenizer,
        )
32
+
33
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a PixArt T5-XXL encoder class with baked-in dtype / scaled-fp8 settings."""
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            opts = model_options
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in opts:
                opts = opts.copy()
                opts["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            # dtype_t5 is only a fallback here; an explicit dtype wins.
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=opts)
    return PixArtTEModel_
ComfyUI/comfy/text_encoders/sd2_clip.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ import os
3
+
4
class SD2ClipHModel(sd1_clip.SDClipModel):
    """SD 2.x ViT-H text encoder; 'penultimate' maps to hidden layer -2."""

    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2

        config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
        super().__init__(
            device=device,
            freeze=freeze,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"start": 49406, "end": 49407, "pad": 0},
            return_projected_pooled=True,
            model_options=model_options,
        )
12
+
13
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the SD 2.x ViT-H encoder (1024-dim embeddings, 'clip_h' key)."""

    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=1024,
            embedding_key='clip_h',
            tokenizer_data=tokenizer_data,
        )
16
+
17
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-branch wrapper exposing the ViT-H tokenizer under 'h'."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="h",
            tokenizer=SD2ClipHTokenizer,
        )
20
+
21
class SD2ClipModel(sd1_clip.SD1ClipModel):
    """Single-branch text-encoder container for SD 2.x (ViT-H)."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(
            device=device,
            dtype=dtype,
            model_options=model_options,
            clip_name="h",
            clip_model=SD2ClipHModel,
            **kwargs,
        )
ComfyUI/comfy/text_encoders/sd2_clip_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 49407,
9
+ "hidden_act": "gelu",
10
+ "hidden_size": 1024,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float32",
22
+ "vocab_size": 49408
23
+ }
ComfyUI/comfy/text_encoders/sd3_clip.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy import sd1_clip
2
+ from comfy import sdxl_clip
3
+ from transformers import T5TokenizerFast
4
+ import comfy.text_encoders.t5
5
+ import torch
6
+ import os
7
+ import comfy.model_management
8
+ import logging
9
+
10
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder used by the SD3 family."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
        config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
        scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
        if scaled_fp8 is not None:
            # Promote the T5-specific scaled-fp8 flag to the generic key the base class reads.
            model_options = model_options.copy()
            model_options["scaled_fp8"] = scaled_fp8

        model_options = {**model_options, "model_name": "t5xxl"}
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.text_encoders.t5.T5,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
20
+
21
+
22
+ def t5_xxl_detect(state_dict, prefix=""):
23
+ out = {}
24
+ t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
25
+ if t5_key in state_dict:
26
+ out["dtype_t5"] = state_dict[t5_key].dtype
27
+
28
+ scaled_fp8_key = "{}scaled_fp8".format(prefix)
29
+ if scaled_fp8_key in state_dict:
30
+ out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
31
+
32
+ return out
33
+
34
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for SD3 (defaults to a 77-token minimum length)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(
            os.path.join(here, "t5_tokenizer"),
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=max_length,
            min_length=min_length,
            tokenizer_data=tokenizer_data,
        )
38
+
39
+
40
class SD3Tokenizer:
    """Triple tokenizer (CLIP-L, CLIP-G, T5-XXL) used by SD3-family models."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize with all three sub-tokenizers, keyed 'g' / 'l' / 't5xxl'."""
        return {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Map token/weight pairs back to text via the CLIP-G tokenizer."""
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        # Nothing tokenizer-side needs persisting for SD3.
        return {}
58
+
59
+ class SD3ClipModel(torch.nn.Module):
60
    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
        """Build any subset of the three SD3 encoders; disabled branches are None."""
        super().__init__()
        self.dtypes = set()
        if clip_l:
            # CLIP-L: penultimate hidden layer, raw (un-normalized) hidden state, no projected pooled output.
            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            # T5 may run at its own (often lower-precision) dtype.
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5_attention_mask = t5_attention_mask
            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
+
85
+ def set_clip_options(self, options):
86
+ if self.clip_l is not None:
87
+ self.clip_l.set_clip_options(options)
88
+ if self.clip_g is not None:
89
+ self.clip_g.set_clip_options(options)
90
+ if self.t5xxl is not None:
91
+ self.t5xxl.set_clip_options(options)
92
+
93
+ def reset_clip_options(self):
94
+ if self.clip_l is not None:
95
+ self.clip_l.reset_clip_options()
96
+ if self.clip_g is not None:
97
+ self.clip_g.reset_clip_options()
98
+ if self.t5xxl is not None:
99
+ self.t5xxl.reset_clip_options()
100
+
101
+ def encode_token_weights(self, token_weight_pairs):
102
+ token_weight_pairs_l = token_weight_pairs["l"]
103
+ token_weight_pairs_g = token_weight_pairs["g"]
104
+ token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
105
+ lg_out = None
106
+ pooled = None
107
+ out = None
108
+ extra = {}
109
+
110
+ if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
111
+ if self.clip_l is not None:
112
+ lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
113
+ else:
114
+ l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
115
+
116
+ if self.clip_g is not None:
117
+ g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
118
+ if lg_out is not None:
119
+ cut_to = min(lg_out.shape[1], g_out.shape[1])
120
+ lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
121
+ else:
122
+ lg_out = torch.nn.functional.pad(g_out, (768, 0))
123
+ else:
124
+ g_out = None
125
+ g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
126
+
127
+ if lg_out is not None:
128
+ lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
129
+ out = lg_out
130
+ pooled = torch.cat((l_pooled, g_pooled), dim=-1)
131
+
132
+ if self.t5xxl is not None:
133
+ t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
134
+ t5_out, t5_pooled = t5_output[:2]
135
+ if self.t5_attention_mask:
136
+ extra["attention_mask"] = t5_output[2]["attention_mask"]
137
+
138
+ if lg_out is not None:
139
+ out = torch.cat([lg_out, t5_out], dim=-2)
140
+ else:
141
+ out = t5_out
142
+
143
+ if out is None:
144
+ out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())
145
+
146
+ if pooled is None:
147
+ pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
148
+
149
+ return out, pooled, extra
150
+
151
+ def load_sd(self, sd):
152
+ if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
153
+ return self.clip_g.load_sd(sd)
154
+ elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
155
+ return self.clip_l.load_sd(sd)
156
+ else:
157
+ return self.t5xxl.load_sd(sd)
158
+
159
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
    """Return an SD3ClipModel subclass with the given configuration baked in.

    The returned class keeps the (device, dtype, model_options) constructor
    signature expected by the model loader, while fixing which sub-encoders
    are built, the T5 dtype and the attention-mask option.
    """
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Copy before modifying so the caller's dict is not mutated.
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
    return SD3ClipModel_
ComfyUI/comfy/text_encoders/t5.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from comfy.ldm.modules.attention import optimized_attention_for_device
4
+ import comfy.ops
5
+
6
class T5LayerNorm(torch.nn.Module):
    """T5-style RMSNorm: scales by the root-mean-square, with no mean
    subtraction and no bias (unlike standard LayerNorm)."""
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # cast_to_input brings the weight to x's dtype/device before scaling.
        return comfy.ops.cast_to_input(self.weight, x) * x
16
+
17
# Maps the T5 config's "dense_act_fn" name to the matching activation callable.
activations = {
    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
    "relu": torch.nn.functional.relu,
}
21
+
22
class T5DenseActDense(torch.nn.Module):
    """Non-gated T5 feed-forward projection: wi -> activation -> wo.

    T5 feed-forward layers use no biases; the reference implementation's
    dropout is intentionally omitted (inference-only code).
    """

    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
        super().__init__()
        self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
        # Activation is looked up from the module-level `activations` table.
        self.act = activations[ff_activation]

    def forward(self, x):
        hidden = self.act(self.wi(x))
        return self.wo(hidden)
35
+
36
class T5DenseGatedActDense(torch.nn.Module):
    """Gated T5 feed-forward (GEGLU-style): act(wi_0(x)) * wi_1(x) -> wo.

    No biases, and the reference implementation's dropout is intentionally
    omitted (inference-only code).
    """

    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
        super().__init__()
        self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
        # Activation is looked up from the module-level `activations` table.
        self.act = activations[ff_activation]

    def forward(self, x):
        gate = self.act(self.wi_0(x))
        return self.wo(gate * self.wi_1(x))
52
+
53
class T5LayerFF(torch.nn.Module):
    """Pre-norm feed-forward sub-layer with an in-place residual connection.

    ``gated_act`` selects between the gated and the plain dense variant.
    """

    def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
        super().__init__()
        ff_cls = T5DenseGatedActDense if gated_act else T5DenseActDense
        self.DenseReluDense = ff_cls(model_dim, ff_dim, ff_activation, dtype, device, operations)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
        # Reference implementation's dropout intentionally omitted.

    def forward(self, x):
        # Residual add is performed in place on x, matching the original code.
        x += self.DenseReluDense(self.layer_norm(x))
        return x
70
+
71
class T5Attention(torch.nn.Module):
    """T5 self-attention with optional learned relative position bias.

    Only layers constructed with ``relative_attention_bias=True`` own a bias
    embedding; other layers reuse the bias passed in via ``past_bias``.
    """
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            # Bucket counts fixed to the standard T5 configuration.
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device, dtype):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values.contiguous()

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        """Self-attend over x; returns (output, bias) so the bias can be
        reused by later layers that have no bias embedding of their own."""
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

        if past_bias is not None:
            # The position bias is folded into the additive attention mask.
            if mask is not None:
                mask = mask + past_bias
            else:
                mask = past_bias

        # k is pre-scaled by sqrt(d_head) — presumably to cancel the 1/sqrt(d_head)
        # factor applied inside optimized_attention, since T5 uses unscaled dot
        # products (see the Mesh-TF note above) — confirm against the attention impl.
        out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
        return self.o(out), past_bias
166
+
167
class T5LayerSelfAttention(torch.nn.Module):
    """Pre-norm self-attention sub-layer with an in-place residual connection."""
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
        # Reference implementation's dropout intentionally omitted.

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        """Returns (x, past_bias); the bias is threaded through to later layers."""
        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
        # Residual add performed in place on x (dropout omitted).
        x += output
        return x, past_bias
179
+
180
class T5Block(torch.nn.Module):
    """One encoder block: self-attention sub-layer followed by a feed-forward
    sub-layer, stored in a ModuleList to match the T5 checkpoint layout."""
    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        # layer[0] is attention (returns the bias for reuse), layer[-1] is the FF.
        x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
        x = self.layer[-1](x)
        return x, past_bias
191
+
192
class T5Stack(torch.nn.Module):
    """Stack of T5 encoder blocks plus the final RMS layer norm.

    ``relative_attention`` True (classic T5) gives only block 0 a relative
    position bias embedding, which is then reused by all later blocks; False
    (umt5) gives every block its own bias.
    """
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
        super().__init__()

        self.block = torch.nn.ModuleList(
            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
        # Reference implementation's dropout intentionally omitted.

    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        """Run the encoder; returns (final_hidden, intermediate_hidden_or_None).

        ``intermediate_output`` selects a block index (negative values count
        from the end) whose output is captured as a second return value.
        """
        mask = None
        if attention_mask is not None:
            # Convert the 0/1 attention mask into an additive mask: masked
            # positions get the most negative representable value.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        intermediate = None
        optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
        past_bias = None

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.block) + intermediate_output

        for i, l in enumerate(self.block):
            # past_bias is produced by block 0 and threaded through the rest.
            x, past_bias = l(x, mask, past_bias, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
224
+
225
class T5(torch.nn.Module):
    """T5 encoder-only model: shared token embedding plus a T5Stack.

    ``config_dict`` mirrors the HuggingFace T5 config JSON; ``model_type``
    "umt5" switches every layer to its own relative position bias.
    """
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        model_dim = config_dict["d_model"]
        inner_dim = config_dict["d_kv"] * config_dict["num_heads"]

        self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
        self.dtype = dtype
        # Token embedding shared with (absent) decoder in the original T5 layout.
        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, embeddings):
        self.shared = embeddings

    def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
        """Encode token ids (or precomputed ``embeds`` when input_ids is None)."""
        if input_ids is None:
            x = embeds
        else:
            x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
        if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            x = torch.nan_to_num(x) #Fix for fp8 T5 base
        return self.encoder(x, attention_mask=attention_mask, **kwargs)
ComfyUI/comfy/text_encoders/t5_config_base.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "d_ff": 3072,
3
+ "d_kv": 64,
4
+ "d_model": 768,
5
+ "decoder_start_token_id": 0,
6
+ "dropout_rate": 0.1,
7
+ "eos_token_id": 1,
8
+ "dense_act_fn": "relu",
9
+ "initializer_factor": 1.0,
10
+ "is_encoder_decoder": true,
11
+ "is_gated_act": false,
12
+ "layer_norm_epsilon": 1e-06,
13
+ "model_type": "t5",
14
+ "num_decoder_layers": 12,
15
+ "num_heads": 12,
16
+ "num_layers": 12,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "relative_attention_num_buckets": 32,
20
+ "tie_word_embeddings": false,
21
+ "vocab_size": 32128
22
+ }
ComfyUI/comfy/text_encoders/t5_config_xxl.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "d_ff": 10240,
3
+ "d_kv": 64,
4
+ "d_model": 4096,
5
+ "decoder_start_token_id": 0,
6
+ "dropout_rate": 0.1,
7
+ "eos_token_id": 1,
8
+ "dense_act_fn": "gelu_pytorch_tanh",
9
+ "initializer_factor": 1.0,
10
+ "is_encoder_decoder": true,
11
+ "is_gated_act": true,
12
+ "layer_norm_epsilon": 1e-06,
13
+ "model_type": "t5",
14
+ "num_decoder_layers": 24,
15
+ "num_heads": 64,
16
+ "num_layers": 24,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "relative_attention_num_buckets": 32,
20
+ "tie_word_embeddings": false,
21
+ "vocab_size": 32128
22
+ }
ComfyUI/comfy/text_encoders/t5_old_config_xxl.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "d_ff": 65536,
3
+ "d_kv": 128,
4
+ "d_model": 1024,
5
+ "decoder_start_token_id": 0,
6
+ "dropout_rate": 0.1,
7
+ "eos_token_id": 1,
8
+ "dense_act_fn": "relu",
9
+ "initializer_factor": 1.0,
10
+ "is_encoder_decoder": true,
11
+ "is_gated_act": false,
12
+ "layer_norm_epsilon": 1e-06,
13
+ "model_type": "t5",
14
+ "num_decoder_layers": 24,
15
+ "num_heads": 128,
16
+ "num_layers": 24,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "relative_attention_num_buckets": 32,
20
+ "tie_word_embeddings": false,
21
+ "vocab_size": 32128
22
+ }
ComfyUI/comfy/text_encoders/umt5_config_base.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "d_ff": 2048,
3
+ "d_kv": 64,
4
+ "d_model": 768,
5
+ "decoder_start_token_id": 0,
6
+ "dropout_rate": 0.1,
7
+ "eos_token_id": 1,
8
+ "dense_act_fn": "gelu_pytorch_tanh",
9
+ "initializer_factor": 1.0,
10
+ "is_encoder_decoder": true,
11
+ "is_gated_act": true,
12
+ "layer_norm_epsilon": 1e-06,
13
+ "model_type": "umt5",
14
+ "num_decoder_layers": 12,
15
+ "num_heads": 12,
16
+ "num_layers": 12,
17
+ "output_past": true,
18
+ "pad_token_id": 0,
19
+ "relative_attention_num_buckets": 32,
20
+ "tie_word_embeddings": false,
21
+ "vocab_size": 256384
22
+ }
ComfyUI/comfy/utils.py ADDED
@@ -0,0 +1,1104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Comfy
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+
20
+ import torch
21
+ import math
22
+ import struct
23
+ import comfy.checkpoint_pickle
24
+ import safetensors.torch
25
+ import numpy as np
26
+ from PIL import Image
27
+ import logging
28
+ import itertools
29
+ from torch.nn.functional import interpolate
30
+ from einops import rearrange
31
+ from comfy.cli_args import args
32
+
33
# Module-level flags taken from the CLI arguments.
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap

# True when this pytorch supports add_safe_globals (>= 2.4) and ckpt/pt files
# can therefore always be loaded with weights_only=True.
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
    # Stand-in for the pytorch_lightning class referenced by some checkpoints,
    # so they can be unpickled safely without pytorch_lightning installed.
    class ModelCheckpoint:
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    from numpy.core.multiarray import scalar
    from numpy import dtype
    from numpy.dtypes import Float64DType
    from _codecs import encode

    # Allow-list the globals commonly found in checkpoint pickles.
    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")
else:
    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
52
+
53
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    """Load a checkpoint from ``ckpt`` into a flat state dict.

    Supports .safetensors/.sft (via safetensors) and pickle-based torch files.
    ``safe_load`` forces weights_only loading on older pytorch; with pytorch
    >= 2.4 loading is always safe (see ALWAYS_SAFE_LOAD above).
    Returns ``sd`` or ``(sd, metadata)`` when ``return_metadata`` is True
    (metadata is only available for safetensors files).
    """
    if device is None:
        device = torch.device("cpu")
    metadata = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        try:
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {}
                for k in f.keys():
                    tensor = f.get_tensor(k)
                    if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
                        # Force a real copy onto the target device instead of an mmap-backed view.
                        tensor = tensor.to(device=device, copy=True)
                    sd[k] = tensor
                if return_metadata:
                    metadata = f.metadata()
        except Exception as e:
            # Translate safetensors' cryptic parse errors into actionable messages.
            if len(e.args) > 0:
                message = e.args[0]
                if "HeaderTooLarge" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
                if "MetadataIncompleteBuffer" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
            raise e
    else:
        torch_args = {}
        if MMAP_TORCH_FILES:
            torch_args["mmap"] = True

        if safe_load or ALWAYS_SAFE_LOAD:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
        else:
            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "state_dict" in pl_sd:
            # Lightning-style checkpoint: unwrap the nested state dict.
            sd = pl_sd["state_dict"]
        else:
            if len(pl_sd) == 1:
                # A single top-level key that holds a dict is treated as a wrapper.
                key = list(pl_sd.keys())[0]
                sd = pl_sd[key]
                if not isinstance(sd, dict):
                    sd = pl_sd
            else:
                sd = pl_sd
    return (sd, metadata) if return_metadata else sd
97
+
98
def save_torch_file(sd, ckpt, metadata=None):
    """Save state dict ``sd`` to ``ckpt`` as safetensors, with optional metadata."""
    kwargs = {} if metadata is None else {"metadata": metadata}
    safetensors.torch.save_file(sd, ckpt, **kwargs)
103
+
104
def calculate_parameters(sd, prefix=""):
    """Return the total number of elements across all tensors in ``sd``
    whose key starts with ``prefix``."""
    return sum(t.nelement() for k, t in sd.items() if k.startswith(prefix))
111
+
112
def weight_dtype(sd, prefix=""):
    """Return the dtype accounting for the most elements among tensors in
    ``sd`` under ``prefix``, or None when no key matches."""
    counts = {}
    for key, tensor in sd.items():
        if key.startswith(prefix):
            counts[tensor.dtype] = counts.get(tensor.dtype, 0) + tensor.numel()

    if not counts:
        return None

    # Dominant dtype by element count; ties resolve by insertion order.
    return max(counts, key=counts.get)
123
+
124
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename keys of ``state_dict`` in place per the ``keys_to_replace``
    old->new mapping; missing old keys are ignored. Returns ``state_dict``."""
    for old_key, new_key in keys_to_replace.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
129
+
130
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    """Rewrite key prefixes in ``state_dict`` per the ``replace_prefix``
    old->new mapping.

    Matching keys are popped from ``state_dict`` (it is mutated either way).
    With ``filter_keys`` False the renamed entries stay in ``state_dict``,
    which is returned; with True only the renamed entries are returned.
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        # Materialize matches first: we pop from state_dict while renaming.
        matching = [k for k in state_dict.keys() if k.startswith(old_prefix)]
        for key in matching:
            out[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)
    return out
141
+
142
+
143
def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename CLIP keys from the OpenAI/open_clip layout under ``prefix_from``
    to the HuggingFace transformers layout under ``prefix_to``.

    ``number`` is the count of transformer resblocks to convert. ``sd`` is
    mutated in place and also returned; keys absent from ``sd`` are skipped.
    """
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    # Per-resblock submodule renames (old name -> transformers name).
    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            # The fused qkv in_proj tensor is split into three equal chunks
            # along dim 0: q, k, v projections (in that order).
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]

    return sd
183
+
184
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    """Convert a CLIP text model state dict to the transformers layout
    (32 resblocks), including the text projection.

    A matrix-form ``text_projection`` (no ``.weight`` suffix) is transposed
    to match the transformers Linear weight convention.
    """
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    tp = "{}text_projection.weight".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp)

    tp = "{}text_projection".format(prefix_from)
    if tp in sd:
        sd["{}text_projection.weight".format(prefix_to)] = sd.pop(tp).transpose(0, 1).contiguous()
    return sd
195
+
196
+
197
# Key-name tables for converting between the original (LDM/CompVis-style)
# UNet state-dict layout and the diffusers layout. Used by unet_to_diffusers.

# Keys shared by every attention (SpatialTransformer) module; identical in
# both layouts, so this is a set rather than a mapping.
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

# Keys inside each transformer block; identical in both layouts.
TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

# ResNet-block key renames: original layout name -> diffusers name.
UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

# Top-level (original, diffusers) key pairs. Note label_emb maps to both
# class_embedding and add_embedding entries: which one applies depends on
# the model variant.
UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}
264
+
265
def unet_to_diffusers(unet_config):
    """Build a key map translating a diffusers UNet state dict to comfy/ldm names.

    Args:
        unet_config: dict describing the UNet layout ("num_res_blocks",
            "channel_mult", "transformer_depth", "transformer_depth_output",
            optionally "transformer_depth_middle").

    Returns:
        dict mapping diffusers key -> comfy/ldm key. Empty dict when the
        config does not describe a UNet (no "num_res_blocks").
    """
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    # copied because they are consumed via pop() while walking the blocks
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    # Down path: each level has resnets, optional attentions, then a downsampler.
    for x in range(num_blocks):
        # n indexes comfy's flat input_blocks list (index 0 is the input conv).
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        # downsampler sits in its own input_block right after this level's resnets
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    # Middle block: attention at middle_block.1 sandwiched by resnets 0 and 2.
    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    # Up path: levels walked in reversed channel order; one extra resnet per
    # level, with an upsampler appended to the level's last output_block.
    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            c = 0  # sub-module index inside output_blocks.{n}: resnet, [attention], upsampler
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    # top-level tensors (time embed, in/out convs, ...)
    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
330
+
331
def swap_scale_shift(weight):
    """Reorder a tensor packed as (shift, scale) along dim 0 into (scale, shift).

    Used when converting adaLN modulation weights between diffusers and comfy
    layouts, which pack the two halves in opposite order.
    """
    halves = weight.chunk(2, dim=0)
    return torch.cat((halves[1], halves[0]), dim=0)
335
+
336
# Top-level MMDiT (SD3-style) tensor pairs: (comfy key, diffusers key).
# A third element, when present, is a transform applied to the tensor during
# conversion (swap_scale_shift reorders the packed adaLN shift/scale halves).
MMDIT_MAP_BASIC = {
    ("context_embedder.bias", "context_embedder.bias"),
    ("context_embedder.weight", "context_embedder.weight"),
    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
}
355
+
356
# Per-joint-block MMDiT tensor pairs: (comfy key suffix, diffusers key suffix).
# context_block.* handles the text stream, x_block.* the image stream;
# the fused qkv tensors are handled separately in mmdit_to_diffusers.
MMDIT_MAP_BLOCK = {
    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
    ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
    ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
    ("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
    ("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
    ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
    ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
    ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
    ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
}
382
+
383
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    """Build a diffusers->comfy key map for an MMDiT (SD3-style) transformer.

    Map values are one of:
      - a plain comfy key string,
      - (key, (dim, offset, size)): the diffusers tensor is a slice of the
        fused comfy tensor at that dim/offset,
      - (key, None, transform_fn): the tensor needs transform_fn applied.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    for i in range(num_blocks):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}joint_blocks.{}".format(output_prefix, i)

        offset = depth * 64  # per-head size times heads; assumes hidden size == depth * 64 -- TODO confirm

        for end in ("weight", "bias"):
            # diffusers stores q/k/v separately; comfy fuses them into one qkv tensor
            k = "{}.attn.".format(block_from)
            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            k = "{}.attn2.".format(block_from)
            qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

        for k in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    # the last block's context adaLN weights are stored shift/scale-swapped
    map_basic = MMDIT_MAP_BASIC.copy()
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))

    for k in map_basic:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
426
+
427
# Top-level PixArt tensor pairs: (comfy key, diffusers key).
PIXART_MAP_BASIC = {
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("y_embedder.y_embedding", "caption_projection.y_embedding"),
    ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
    ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
    ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
    ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
    ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
    ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
    ("t_block.1.weight", "adaln_single.linear.weight"),
    ("t_block.1.bias", "adaln_single.linear.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.scale_shift_table", "scale_shift_table"),
}
453
+
454
# Per-block PixArt tensor pairs: (comfy key suffix, diffusers key suffix).
# Fused self-attention qkv and cross-attention kv are handled separately in
# pixart_to_diffusers.
PIXART_MAP_BLOCK = {
    ("scale_shift_table", "scale_shift_table"),
    ("attn.proj.weight", "attn1.to_out.0.weight"),
    ("attn.proj.bias", "attn1.to_out.0.bias"),
    ("mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("mlp.fc2.weight", "ff.net.2.weight"),
    ("mlp.fc2.bias", "ff.net.2.bias"),
    ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
    ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
465
+
466
def pixart_to_diffusers(mmdit_config, output_prefix=""):
    """Build a diffusers->comfy key map for a PixArt transformer.

    Map values are either a plain comfy key or (key, (dim, offset, size))
    marking a slice of a fused comfy tensor (self-attn qkv, cross-attn kv).
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    offset = mmdit_config.get("hidden_size", 1152)  # slice stride for fused q/k/v

    for i in range(depth):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}blocks.{}".format(output_prefix, i)

        for end in ("weight", "bias"):
            # self-attention: diffusers q/k/v are slices of comfy's fused qkv
            s = "{}.attn1.".format(block_from)
            qkv = "{}.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))

            # cross-attention: q is separate, k/v share one fused kv tensor
            s = "{}.attn2.".format(block_from)
            q = "{}.cross_attn.q_linear.{}".format(block_to, end)
            kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)

            key_map["{}to_q.{}".format(s, end)] = q
            key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
            key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))

        for k in PIXART_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    for k in PIXART_MAP_BASIC:
        key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
498
+
499
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    """Build a diffusers->comfy key map for an AuraFlow transformer.

    The first n_double_layers are joint (double) blocks; the remainder are
    single blocks with their own, smaller tensor layout.
    """
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    key_map = {}
    for i in range(n_layers):
        if i < n_double_layers:
            index = i
            prefix_from = "joint_transformer_blocks"
            prefix_to = "{}double_layers".format(output_prefix)
            # double blocks carry separate image (w2*/mlpX) and text (w1*/mlpC) streams
            block_map = {
                "attn.to_q.weight": "attn.w2q.weight",
                "attn.to_k.weight": "attn.w2k.weight",
                "attn.to_v.weight": "attn.w2v.weight",
                "attn.to_out.0.weight": "attn.w2o.weight",
                "attn.add_q_proj.weight": "attn.w1q.weight",
                "attn.add_k_proj.weight": "attn.w1k.weight",
                "attn.add_v_proj.weight": "attn.w1v.weight",
                "attn.to_add_out.weight": "attn.w1o.weight",
                "ff.linear_1.weight": "mlpX.c_fc1.weight",
                "ff.linear_2.weight": "mlpX.c_fc2.weight",
                "ff.out_projection.weight": "mlpX.c_proj.weight",
                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
                "norm1.linear.weight": "modX.1.weight",
                "norm1_context.linear.weight": "modC.1.weight",
            }
        else:
            # single blocks restart their index at 0 on both sides
            index = i - n_double_layers
            prefix_from = "single_transformer_blocks"
            prefix_to = "{}single_layers".format(output_prefix)

            block_map = {
                "attn.to_q.weight": "attn.w1q.weight",
                "attn.to_k.weight": "attn.w1k.weight",
                "attn.to_v.weight": "attn.w1v.weight",
                "attn.to_out.0.weight": "attn.w1o.weight",
                "norm1.linear.weight": "modCX.1.weight",
                "ff.linear_1.weight": "mlp.c_fc1.weight",
                "ff.linear_2.weight": "mlp.c_fc2.weight",
                "ff.out_projection.weight": "mlp.c_proj.weight"
            }

        for k in block_map:
            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])

    # top-level tensors; a third tuple element is a transform to apply
    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
567
+
568
def flux_to_diffusers(mmdit_config, output_prefix=""):
    """Build a diffusers->comfy key map for a Flux transformer.

    Map values are one of:
      - a plain comfy key string,
      - (key, (dim, offset, size)): the diffusers tensor is a slice of the
        fused comfy tensor (double-block qkv, single-block linear1),
      - (key, None, transform_fn): the tensor needs transform_fn applied.
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # image-stream q/k/v are separate in diffusers, fused in comfy
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

            # text-stream ("added") projections map to the txt_attn qkv
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
            "norm1.linear.weight": "img_mod.lin.weight",
            "norm1.linear.bias": "img_mod.lin.bias",
            "norm1_context.linear.weight": "txt_mod.lin.weight",
            "norm1_context.linear.bias": "txt_mod.lin.bias",
            "attn.to_add_out.weight": "txt_attn.proj.weight",
            "attn.to_add_out.bias": "txt_attn.proj.bias",
            "ff.net.0.proj.weight": "img_mlp.0.weight",
            "ff.net.0.proj.bias": "img_mlp.0.bias",
            "ff.net.2.weight": "img_mlp.2.weight",
            "ff.net.2.bias": "img_mlp.2.bias",
            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
            "ff_context.net.2.weight": "txt_mlp.2.weight",
            "ff_context.net.2.bias": "txt_mlp.2.bias",
            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # single blocks fuse q, k, v AND the MLP input projection into linear1
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

        block_map = {
            "norm.linear.weight": "modulation.lin.weight",
            "norm.linear.bias": "modulation.lin.bias",
            "proj_out.weight": "linear2.weight",
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    # top-level tensors; pos_embed_input.* only exists for controlnet variants
    MAP_BASIC = {
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.linear.weight", "proj_out.weight"),
        ("img_in.bias", "x_embedder.bias"),
        ("img_in.weight", "x_embedder.weight"),
        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
        ("txt_in.bias", "context_embedder.bias"),
        ("txt_in.weight", "context_embedder.weight"),
        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
673
+
674
def repeat_to_batch_size(tensor, batch_size, dim=0):
    """Trim or tile `tensor` along `dim` so that dimension equals `batch_size`.

    When the tensor is too large it is truncated; when too small it is tiled
    (repeating from the start) and then truncated to the exact size.
    Returns the tensor unchanged when the size already matches.
    """
    current = tensor.shape[dim]
    if current > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    if current < batch_size:
        reps = [1] * dim + [math.ceil(batch_size / current)] + [1] * (len(tensor.shape) - 1 - dim)
        return tensor.repeat(reps).narrow(dim, 0, batch_size)
    return tensor
680
+
681
def resize_to_batch_size(tensor, batch_size):
    """Resample `tensor` along dim 0 so it has exactly `batch_size` entries.

    Entries are picked at evenly spaced source indices (nearest neighbour);
    no value interpolation is performed. Returns the input unchanged when the
    batch already matches, and a plain slice when batch_size <= 1.
    """
    in_batch_size = tensor.shape[0]
    if in_batch_size == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    out = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    if batch_size < in_batch_size:
        # endpoints included: first and last source entries are always kept
        step = (in_batch_size - 1) / (batch_size - 1)
        for dst in range(batch_size):
            out[dst] = tensor[min(round(dst * step), in_batch_size - 1)]
    else:
        # upsampling: sample at the centre of each destination slot
        step = in_batch_size / batch_size
        for dst in range(batch_size):
            out[dst] = tensor[min(math.floor((dst + 0.5) * step), in_batch_size - 1)]
    return out
700
+
701
def resize_list_to_batch_size(l, batch_size):
    """List analogue of resize_to_batch_size: pick evenly spaced elements.

    Returns the list unchanged when it is empty or already the right length,
    a plain slice when batch_size <= 1, otherwise a new list of nearest
    sampled elements.
    """
    in_batch_size = len(l)
    if in_batch_size == batch_size or in_batch_size == 0:
        return l

    if batch_size <= 1:
        return l[:batch_size]

    if batch_size < in_batch_size:
        # endpoints included: first and last elements are always kept
        step = (in_batch_size - 1) / (batch_size - 1)
        return [l[min(round(i * step), in_batch_size - 1)] for i in range(batch_size)]

    # upsampling: sample at the centre of each destination slot
    step = in_batch_size / batch_size
    return [l[min(math.floor((i + 0.5) * step), in_batch_size - 1)] for i in range(batch_size)]
720
+
721
def convert_sd_to(state_dict, dtype):
    """Cast every tensor in `state_dict` to `dtype`, mutating it in place.

    Returns the same dict object for call-chaining convenience.
    """
    for key in list(state_dict.keys()):
        state_dict[key] = state_dict[key].to(dtype)
    return state_dict
726
+
727
def safetensors_header(safetensors_path, max_size=100*1024*1024):
    """Read the raw JSON header bytes of a .safetensors file.

    The safetensors format starts with an 8-byte little-endian unsigned
    length, followed by that many bytes of JSON header.

    Args:
        safetensors_path: path to the file.
        max_size: maximum accepted header length, guarding against corrupt
            or malicious files declaring absurd sizes.

    Returns:
        The header bytes, or None when the file is too short to contain a
        length prefix or the declared length exceeds max_size.
    """
    with open(safetensors_path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            # truncated file: struct.unpack would raise a confusing error
            return None
        length_of_header = struct.unpack('<Q', prefix)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)
734
+
735
def set_attr(obj, attr, value):
    """Set a dotted attribute path (e.g. "layer.conv.weight") on `obj`.

    Returns the value that was previously stored at that path.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    old = getattr(target, leaf)
    setattr(target, leaf, value)
    return old
742
+
743
def set_attr_param(obj, attr, value):
    """Like set_attr, but wraps `value` in a frozen (requires_grad=False)
    nn.Parameter first. Returns the previous value stored at `attr`."""
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
745
+
746
def copy_to_param(obj, attr, value):
    """Copy `value` into the tensor at dotted path `attr`, in place.

    Unlike set_attr_param this keeps the existing tensor object alive (and
    every reference to it valid) — only its data is overwritten.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    getattr(target, leaf).data.copy_(value)
753
+
754
def get_attr(obj, attr: str):
    """Retrieve a nested attribute from an object using dot notation.

    Args:
        obj: object to start from.
        attr: dotted attribute path, e.g. "model.layer.weight".

    Returns:
        The value at the end of the path.

    Example:
        weight = get_attr(model, "layer1.conv.weight")
        # equivalent to: model.layer1.conv.weight

    Important:
        Always prefer `comfy.model_patcher.ModelPatcher.get_model_object`
        when accessing nested model objects under `ModelPatcher.model`.
    """
    node = obj
    for part in attr.split("."):
        node = getattr(node, part)
    return node
777
+
778
def bislerp(samples, width, height):
    """Resize (N, C, H, W) `samples` to (height, width) using slerp.

    Works like bilinear interpolation, except each pair of neighbouring
    pixels is combined by spherical interpolation of their C-dim vectors
    (with the norms lerped separately), done first along width, then height.
    Computation happens in float32; the result is cast back to the input dtype.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        # For each output position: the two source indices to blend and the
        # fractional blend ratio, derived via 1-D bilinear interpolation of
        # the source coordinate grid.
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
        coords_2[:,:,:,-1] -= 1  # clamp the last neighbour index into range
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    orig_dtype = samples.dtype
    samples = samples.float()
    n,c,h,w = samples.shape
    h_new, w_new = (height, width)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    # gather the two neighbour columns and flatten to (pixels, C) for slerp
    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)
853
+
854
def lanczos(samples, width, height):
    """Resize a batch of channels-first images with PIL's Lanczos filter.

    Each image is converted to a uint8 PIL image (values scaled by 255 and
    clipped), resized, then converted back — so precision is limited to
    8 bits per channel. Assumes input values are in [0, 1] -- TODO confirm
    against callers.
    """
    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
    result = torch.stack(images)
    return result.to(samples.device, samples.dtype)
860
+
861
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize `samples` to (height, width) with the requested method.

    Args:
        samples: (N, C, H, W) tensor, or higher rank — extra middle dims are
            folded into the batch for resizing and restored afterwards.
        upscale_method: "bislerp", "lanczos", or any mode name accepted by
            torch.nn.functional.interpolate.
        crop: "center" to first centre-crop to the target aspect ratio;
            any other value resizes without cropping.
    """
    orig_shape = tuple(samples.shape)
    if len(orig_shape) > 4:
        # fold extra dims (e.g. video frames) into the batch dimension
        samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
        samples = samples.movedim(2, 1)
        samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
    if crop == "center":
        old_width = samples.shape[-1]
        old_height = samples.shape[-2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        # crop the longer axis so the aspect ratio matches the target
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
    else:
        s = samples

    if upscale_method == "bislerp":
        out = bislerp(s, width, height)
    elif upscale_method == "lanczos":
        out = lanczos(s, width, height)
    else:
        out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

    if len(orig_shape) == 4:
        return out

    # restore the folded dims at the new spatial resolution
    out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
    return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
894
+
895
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Number of tiles tiled_scale will process for a (width, height) input.

    An axis needs one tile when it fits entirely, otherwise tiles advance by
    (tile - overlap) per step.
    """
    def tiles_along(size, tile):
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))

    return tiles_along(height, tile_y) * tiles_along(width, tile_x)
899
+
900
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
    """Apply `function` to `samples` tile by tile and blend the results.

    Tiles overlap by `overlap` and are feathered (linear ramp) at their
    edges; overlapping contributions are accumulated and normalized by the
    summed weights.

    Args:
        samples: (B, C, *spatial) input tensor; len(tile) spatial dims.
        function: callable mapping a (1, C, *tile) tensor to the scaled output.
        tile: per-dimension tile size.
        overlap: per-dimension overlap (scalar broadcast to all dims).
        upscale_amount: per-dimension output scale; each entry may be a
            number or a callable mapping input size -> output size.
        out_channels: channel count of the output.
        output_device: device the output is accumulated on.
        downscale: when True, sizes are divided by upscale_amount instead of
            multiplied.
        index_formulas: optional separate scale used only for positioning
            tiles in the output (defaults to upscale_amount).
        pbar: optional ProgressBar, updated once per tile.
    """
    dims = len(tile)

    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims

    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims

    if index_formulas is None:
        index_formulas = upscale_amount

    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims

    # size/position helpers: multiply when upscaling, divide when downscaling;
    # callables are applied directly
    def get_upscale(dim, val):
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale(dim, val):
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    def get_upscale_pos(dim, val):
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale_pos(dim, val):
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos

    def mult_list_upscale(a):
        # scale a list of spatial sizes to output sizes
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]

        # handle entire input fitting in a single tile
        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
            output[b:b+1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        # accumulators: weighted sum of tile outputs, and sum of weights
        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)

        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

        for it in itertools.product(*positions):
            s_in = s
            upscaled = []

            # crop the tile (clamped to the input bounds) and record where
            # its output lands
            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)

            # feather: linear ramp over the (scaled) overlap at both ends of
            # every spatial dim
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            # accumulate this tile's weighted output into the right region
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        output[b:b+1] = out/out_div
    return output
1010
+
1011
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    """2D convenience wrapper: runs tiled_scale_multidim with a (height, width) tile."""
    tile = (tile_y, tile_x)
    return tiled_scale_multidim(samples, function, tile, overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
1013
+
1014
# Module-level on/off switch for progress reporting. It is only written here;
# consumers elsewhere in the codebase are expected to check it — TODO confirm.
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
    """Globally enable or disable progress bar reporting."""
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled
1018
+
1019
# Optional global progress callback. ProgressBar snapshots this at construction
# time and invokes it as hook(current, total, preview, node_id=...).
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
    """Register the global progress callback used by new ProgressBar instances."""
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function
1023
+
1024
class ProgressBar:
    """Per-task progress tracker that forwards updates to the globally
    registered hook (see set_progress_bar_global_hook)."""

    def __init__(self, total, node_id=None):
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        # Snapshot the hook now so later re-registration doesn't affect us.
        self.hook = PROGRESS_BAR_HOOK
        self.node_id = node_id

    def update_absolute(self, value, total=None, preview=None):
        """Set progress to an absolute value, optionally replacing the total.

        The value is clamped to the (possibly updated) total before the hook
        is notified.
        """
        if total is not None:
            self.total = total
        self.current = min(value, self.total)
        if self.hook is not None:
            self.hook(self.current, self.total, preview, node_id=self.node_id)

    def update(self, value):
        """Advance progress by a relative amount."""
        self.update_absolute(self.current + value)
1043
+
1044
def reshape_mask(input_mask, output_shape):
    """Resize a mask tensor so it matches output_shape (N, C, *spatial).

    The interpolation mode is chosen from the number of spatial dims, the
    channel dim is tiled up to output_shape[1] when needed, and the batch is
    expanded via repeat_to_batch_size.
    """
    spatial_dims = len(output_shape) - 2

    if spatial_dims == 1:
        # NOTE(review): 1D masks are interpolated as-is — assumes the input is
        # already (N, C, L); confirm against callers.
        scale_mode = "linear"
    elif spatial_dims == 2:
        input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "bilinear"
    elif spatial_dims == 3:
        if input_mask.ndim < 5:
            input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "trilinear"

    mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
    if mask.shape[1] < output_shape[1]:
        # Tile channels up, then trim in case the repeat overshot.
        reps = (1, output_shape[1]) + (1,) * spatial_dims
        mask = mask.repeat(reps)[:, :output_shape[1]]
    mask = repeat_to_batch_size(mask, output_shape[0])
    return mask
1064
+
1065
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
    """Rescale a DiT attention mask when the image token count changes.

    The mask is laid out as [txt_tokens + img_tokens] x [txt_tokens + img_tokens];
    the text-to-text quadrant is kept untouched while every quadrant touching
    image tokens is bilinearly interpolated from (hi, wi) to (ho, wo).

    Args:
        mask: attention mask of shape [b, q, k] or [q, k].
        img_size_in: (hi, wi) current image token grid.
        img_size_out: (ho, wo) target image token grid.

    Raises:
        ValueError: if the mask is not 2D or 3D.
    """
    hi, wi = img_size_in
    ho, wo = img_size_out
    # if it's already the correct size, no need to do anything
    if (hi, wi) == (ho, wo):
        return mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    if mask.ndim != 3:
        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
    # everything before the hi*wi image tokens is assumed to be text tokens
    txt_tokens = mask.shape[1] - (hi * wi)
    # quadrants of the mask
    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
    img_to_img = mask[:, txt_tokens:, txt_tokens:]
    img_to_txt = mask[:, txt_tokens:, :txt_tokens]

    # convert to 1d x 2d, interpolate, then back to 1d x 1d
    txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
    txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
    # this one is hard because we have to do it twice
    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
    img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    # swap query/key axes so the other side can be interpolated as 2D too
    img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
    # convert to 2d x 1d, interpolate, then back to 1d x 1d
    img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
    img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")

    # reassemble the mask from blocks
    out = torch.cat([
        torch.cat([txt_to_txt, txt_to_img], dim=2),
        torch.cat([img_to_txt, img_to_img], dim=2)],
        dim=1
    )
    return out
ComfyUI/comfy/weight_adapter/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter


# Every adapter implementation that can be tried when loading weight patches.
adapters: list[type[WeightAdapterBase]] = [
    LoRAAdapter,
    LoHaAdapter,
    LoKrAdapter,
    GLoRAAdapter,
    OFTAdapter,
    BOFTAdapter,
]
# Maps user-facing algorithm names to adapter classes; intentionally a subset
# of `adapters` (see disabled entries below).
adapter_maps: dict[str, type[WeightAdapterBase]] = {
    "LoRA": LoRAAdapter,
    "LoHa": LoHaAdapter,
    "LoKr": LoKrAdapter,
    "OFT": OFTAdapter,
    ## We disable not implemented algo for now
    # "GLoRA": GLoRAAdapter,
    # "BOFT": BOFTAdapter,
}


# Public API: base classes, the registries, and each adapter class by name.
__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
    "adapters",
    "adapter_maps",
] + [a.__name__ for a in adapters]
ComfyUI/comfy/weight_adapter/boft.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import comfy.model_management
6
+ from .base import WeightAdapterBase, weight_decompose
7
+
8
+
9
class BOFTAdapter(WeightAdapterBase):
    """Weight adapter for BOFT (butterfly orthogonal fine-tuning) patches.

    Holds the raw tensors loaded from a patch file and applies the butterfly
    orthogonal transform to a model weight in calculate_weight().
    """

    name = "boft"

    def __init__(self, loaded_keys, weights):
        # loaded_keys: set of state-dict keys consumed from the patch file.
        # weights: tuple (blocks, rescale, alpha, dora_scale) as built by load().
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["BOFTAdapter"]:
        """Build a BOFTAdapter for key prefix `x` from a patch dict.

        Returns None when "{x}.oft_blocks" is absent or not 4D (a non-4D
        blocks tensor belongs to plain OFT, not BOFT).
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            # BOFT blocks are 4D; reject other ranks so OFT can claim them.
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        # Optional per-output rescale tensor.
        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the BOFT transform to `weight` and return the patched tensor.

        Any exception during application is logged (not raised) and the
        original weight is returned unchanged in that case.
        """
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        # blocks shape: (boft_m butterfly stages, block_num, boft_b, boft_b)
        boft_m, block_num, boft_b, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T  (skew-symmetric part, so the Cayley transform
            # below yields orthogonal rotation blocks)
            q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0: # alpha in boft/bboft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            inp = org = weight

            r_b = boft_b//2
            for i in range(boft_m):
                bi = r[i]
                g = 2
                k = 2**i * r_b
                if strength != 1:
                    # blend each rotation with identity for partial strength
                    bi = bi * strength + (1-strength) * I
                # NOTE(review): butterfly stage — regroup rows into boft_b-sized
                # blocks with stride that doubles each stage, then rotate.
                inp = (
                    inp.unflatten(0, (-1, g, k))
                    .transpose(1, 2)
                    .flatten(0, 2)
                    .unflatten(0, (-1, boft_b))
                )
                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                # undo the regrouping so the next stage (or output) is in order
                inp = (
                    inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
                )

            if rescale is not None:
                inp = inp * rescale

            lora_diff = inp - org
            lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
ComfyUI/comfy_api/feature_flags.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feature flags module for ComfyUI WebSocket protocol negotiation.
3
+
4
+ This module handles capability negotiation between frontend and backend,
5
+ allowing graceful protocol evolution while maintaining backward compatibility.
6
+ """
7
+
8
+ from typing import Any, Dict
9
+
10
+ from comfy.cli_args import args
11
+
12
# Default server capabilities advertised to connecting frontends during
# feature-flag negotiation (see get_server_features()).
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
    "supports_preview_metadata": True,
    "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
}
17
+
18
+
19
def get_connection_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str,
    default: Any = False
) -> Any:
    """Look up a feature flag advertised by a specific connection.

    Args:
        sockets_metadata: Per-connection metadata keyed by session id.
        sid: Session id of the connection to inspect.
        feature_name: Feature flag key to look up.
        default: Value returned when the connection or flag is unknown.

    Returns:
        The advertised flag value, or `default` when absent.
    """
    if sid not in sockets_metadata:
        return default
    flags = sockets_metadata[sid].get("feature_flags", {})
    return flags.get(feature_name, default)
41
+
42
+
43
def supports_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str
) -> bool:
    """Check whether a connection explicitly supports a feature.

    Args:
        sockets_metadata: Per-connection metadata keyed by session id.
        sid: Session id of the connection to inspect.
        feature_name: Feature flag key to check.

    Returns:
        True only when the connection advertises the flag as exactly True.
    """
    flag_value = get_connection_feature(sockets_metadata, sid, feature_name, False)
    # Strict identity check: truthy-but-not-True values do not count.
    return flag_value is True
60
+
61
+
62
def get_server_features() -> Dict[str, Any]:
    """Return the server's feature flags.

    Returns:
        A shallow copy of SERVER_FEATURE_FLAGS, so callers cannot mutate
        the module-level dictionary.
    """
    return dict(SERVER_FEATURE_FLAGS)
ComfyUI/comfy_api_nodes/README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI API Nodes
2
+
3
+ ## Introduction
4
+
5
+ Below is a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
6
+
7
+ ## Development
8
+
9
+ While developing, you should be testing against the Staging environment. To test against staging:
10
+
11
+ **Install ComfyUI_frontend**
12
+
13
+ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
14
+
15
+ > **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
16
+
17
+ ```bash
18
+ python main.py --comfy-api-base https://stagingapi.comfy.org
19
+ ```
20
+
21
+ To authenticate to staging, please log in and then ask one of the Comfy Org team to whitelist you for access to staging.
22
+
23
+ API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
24
+
25
+ ### Redocly Instructions
26
+
27
+ **Tip**
28
+ When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
29
+
30
+ Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
31
+
32
+ ```bash
33
+ # Download the OpenAPI file from staging server.
34
+ curl -o openapi.yaml https://stagingapi.comfy.org/openapi
35
+
36
+ # Filter out unneeded API definitions.
37
+ npm install -g @redocly/cli
38
+ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
39
+
40
+ # Generate the pydantic datamodels for validation.
41
+ datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
42
+
43
+ ```
44
+
45
+
46
+ # Merging to Master
47
+
48
+ Before merging to comfyanonymous/ComfyUI master, follow these steps:
49
+
50
+ 1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
51
+ 1. Make sure the ComfyUI API is deployed to prod with your changes.
52
+ 1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
53
+
54
+ ```bash
55
+ # Download the OpenAPI file from prod server.
56
+ curl -o openapi.yaml https://api.comfy.org/openapi
57
+
58
+ # Filter out unneeded API definitions.
59
+ npm install -g @redocly/cli
60
+ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
61
+
62
+ # Generate the pydantic datamodels for validation.
63
+ datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
64
+
65
+ ```