Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- ComfyUI/comfy/image_encoders/dino2.py +141 -0
- ComfyUI/comfy/k_diffusion/sa_solver.py +121 -0
- ComfyUI/comfy/k_diffusion/sampling.py +1761 -0
- ComfyUI/comfy/ldm/common_dit.py +16 -0
- ComfyUI/comfy/model_detection.py +910 -0
- ComfyUI/comfy/model_patcher.py +1215 -0
- ComfyUI/comfy/ops.py +441 -0
- ComfyUI/comfy/patcher_extension.py +157 -0
- ComfyUI/comfy/sample.py +52 -0
- ComfyUI/comfy/samplers.py +1143 -0
- ComfyUI/comfy/sd1_clip.py +687 -0
- ComfyUI/comfy/sd1_clip_config.json +25 -0
- ComfyUI/comfy/sd1_tokenizer/merges.txt +0 -0
- ComfyUI/comfy/sd1_tokenizer/tokenizer_config.json +34 -0
- ComfyUI/comfy/sd1_tokenizer/vocab.json +0 -0
- ComfyUI/comfy/supported_models.py +1235 -0
- ComfyUI/comfy/supported_models_base.py +119 -0
- ComfyUI/comfy/t2i_adapter/adapter.py +299 -0
- ComfyUI/comfy/taesd/taesd.py +79 -0
- ComfyUI/comfy/text_encoders/ace.py +153 -0
- ComfyUI/comfy/text_encoders/ace_text_cleaners.py +395 -0
- ComfyUI/comfy/text_encoders/aura_t5.py +22 -0
- ComfyUI/comfy/text_encoders/bert.py +143 -0
- ComfyUI/comfy/text_encoders/cosmos.py +42 -0
- ComfyUI/comfy/text_encoders/flux.py +70 -0
- ComfyUI/comfy/text_encoders/genmo.py +38 -0
- ComfyUI/comfy/text_encoders/hidream.py +155 -0
- ComfyUI/comfy/text_encoders/hunyuan_video.py +159 -0
- ComfyUI/comfy/text_encoders/hydit.py +81 -0
- ComfyUI/comfy/text_encoders/hydit_clip.json +35 -0
- ComfyUI/comfy/text_encoders/llama.py +358 -0
- ComfyUI/comfy/text_encoders/long_clipl.py +27 -0
- ComfyUI/comfy/text_encoders/lt.py +18 -0
- ComfyUI/comfy/text_encoders/lumina2.py +39 -0
- ComfyUI/comfy/text_encoders/mt5_config_xl.json +22 -0
- ComfyUI/comfy/text_encoders/omnigen2.py +44 -0
- ComfyUI/comfy/text_encoders/pixart_t5.py +42 -0
- ComfyUI/comfy/text_encoders/sd2_clip.py +23 -0
- ComfyUI/comfy/text_encoders/sd2_clip_config.json +23 -0
- ComfyUI/comfy/text_encoders/sd3_clip.py +166 -0
- ComfyUI/comfy/text_encoders/t5.py +249 -0
- ComfyUI/comfy/text_encoders/t5_config_base.json +22 -0
- ComfyUI/comfy/text_encoders/t5_config_xxl.json +22 -0
- ComfyUI/comfy/text_encoders/t5_old_config_xxl.json +22 -0
- ComfyUI/comfy/text_encoders/umt5_config_base.json +22 -0
- ComfyUI/comfy/utils.py +1104 -0
- ComfyUI/comfy/weight_adapter/__init__.py +34 -0
- ComfyUI/comfy/weight_adapter/boft.py +115 -0
- ComfyUI/comfy_api/feature_flags.py +69 -0
- ComfyUI/comfy_api_nodes/README.md +65 -0
ComfyUI/comfy/image_encoders/dino2.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from comfy.text_encoders.bert import BertAttention
|
| 3 |
+
import comfy.model_management
|
| 4 |
+
from comfy.ldm.modules.attention import optimized_attention_for_device
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Dino2AttentionOutput(torch.nn.Module):
    """Output projection applied after attention in a DINOv2 block.

    Note: ``layer_norm_eps`` is accepted for signature compatibility with the
    other sub-modules but is not used here.
    """

    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        # A single linear projection; normalization happens outside this module.
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        projected = self.dense(x)
        return projected
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Dino2AttentionBlock(torch.nn.Module):
    """Self-attention followed by its output projection (DINOv2 style)."""

    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        # Reuses the BERT attention implementation from the text encoders.
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        attn_out = self.attention(x, mask, optimized_attention)
        return self.output(attn_out)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LayerScale(torch.nn.Module):
    """Learnable per-channel scaling of a residual branch (DINOv2's lambda)."""

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        # torch.empty: the value is expected to be loaded from a checkpoint.
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        scale = comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
        return x * scale
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SwiGLUFFN(torch.nn.Module):
    """SwiGLU feed-forward network as used by larger DINOv2 variants.

    The hidden width is 2/3 of the usual 4*dim MLP width, rounded up to a
    multiple of 8, matching the reference DINOv2 implementation.
    """

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        hidden = int(dim * 4)
        hidden = (int(hidden * 2 / 3) + 7) // 8 * 8

        # weights_in produces both the gate and the value halves at once.
        self.weights_in = operations.Linear(dim, 2 * hidden, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden, dim, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        gate, value = self.weights_in(x).chunk(2, dim=-1)
        activated = torch.nn.functional.silu(gate) * value
        return self.weights_out(activated)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Dino2Block(torch.nn.Module):
    """One pre-norm transformer block with layer-scaled residual branches.

    Structure: x + LS1(Attn(LN1(x))) followed by x + LS2(MLP(LN2(x))).
    Sub-module names/order must stay fixed so checkpoint state_dicts load.
    """

    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention):
        # Attention branch: pre-norm -> attention (no mask) -> learned scale.
        x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
        # MLP branch, same pre-norm + layer-scale pattern.
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Dino2Encoder(torch.nn.Module):
    """A stack of Dino2Block layers with optional intermediate-layer capture."""

    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList(
            Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations)
            for _ in range(num_layers)
        )

    def forward(self, x, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        # Negative layer indices count from the end, like Python list indexing.
        if intermediate_output is not None and intermediate_output < 0:
            intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for idx, block in enumerate(self.layer):
            x = block(x, optimized_attention)
            if idx == intermediate_output:
                # Clone so later blocks cannot mutate the captured activation.
                intermediate = x.clone()
        return x, intermediate
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class Dino2PatchEmbeddings(torch.nn.Module):
    """Splits an image into non-overlapping patches and embeds each one.

    Note: ``image_size`` is accepted for signature compatibility but unused.
    """

    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        # A conv whose kernel equals its stride is an exact patchifier.
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device
        )

    def forward(self, pixel_values):
        patches = self.projection(pixel_values)   # (B, dim, H', W')
        return patches.flatten(2).transpose(1, 2)  # (B, H'*W', dim)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Dino2Embeddings(torch.nn.Module):
    """Patch + positional embeddings with a prepended CLS token (DINOv2)."""

    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        # Fixed DINOv2 pretraining geometry: 518x518 images, 14x14 patches.
        patch_size = 14
        image_size = 518

        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
        # +1 position for the CLS token.
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
        # Kept so checkpoints load; not applied in forward (see TODO below).
        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))

    def forward(self, pixel_values):
        x = self.patch_embeddings(pixel_values)
        # TODO: mask_token?
        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
        return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Dinov2Model(torch.nn.Module):
    """DINOv2 vision transformer: embeddings -> encoder -> final LayerNorm.

    forward returns a 4-tuple matching ComfyUI's vision-model interface:
    (last hidden state, requested intermediate hidden state, pooled CLS
    output, None).
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        # attention_mask is accepted for interface compatibility but unused.
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]  # CLS token embedding
        return x, i, pooled_output, None
|
ComfyUI/comfy/k_diffusion/sa_solver.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
|
| 2 |
+
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
|
| 3 |
+
# Codebase ref: https://github.com/scxue/SA-Solver
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Union, Callable
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
    """Compute (1 + tau^2) * integral_s^t exp((1 + tau^2) x) x^p dx, with the
    common factor exp((1 + tau^2) t) divided out, for p = 0 .. solver_order-1.

    Uses integration by parts:
        I_p = P_p - (p / (1 + tau^2)) * I_{p-1},  with  I_0 = P_0,
    where P_p = x^p exp((1 + tau^2) x) / (1 + tau^2) evaluated s -> t. The
    recursion is unrolled into a lower-triangular matrix applied to the P_p
    terms, so all orders are produced in one matmul.

    Args:
        s: Start time.
        t: End time.
        solver_order: Current order of the solver.
        tau_t: Stochastic strength parameter of the SDE.

    Returns:
        Tensor of shape (solver_order,), coefficients ordered p=0..order-1.
    """
    tau_mul = 1 + tau_t ** 2
    h = t - s
    powers = torch.arange(solver_order, dtype=s.dtype, device=s.device)

    # P_p with exp((1 + tau^2) t) factored out; the (1 + tau^2) prefactor
    # from outside the integral cancels the 1/(1 + tau^2) in P_p.
    boundary_terms = t ** powers - s ** powers * (-tau_mul * h).exp()

    # Entry (i, j) accumulates the product of recursion factors from order i
    # down to order j; computed in log space via factorial ratios.
    depth = powers.unsqueeze(1) - powers.unsqueeze(0)
    log_fact = (powers + 1).lgamma()
    log_coeff = log_fact.unsqueeze(1) - log_fact.unsqueeze(0)
    if tau_t > 0:
        log_coeff = log_coeff - depth * math.log(tau_mul)
    alternating = torch.where(depth % 2 == 0, 1.0, -1.0)
    recursion_matrix = (log_coeff.exp() * alternating).tril()

    return recursion_matrix @ boundary_terms
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
    """Simple order-2 b coefficients from the SA-Solver paper (Appendix D,
    Implementation Details). Returned stacked as [b_2, b_1]."""
    tau_mul = 1 + tau_t ** 2
    h = lambda_t - lambda_s
    alpha_t = sigma_next * lambda_t.exp()
    # 1 - exp(-(1 + tau^2) h), shared by both branches.
    one_minus_decay = (-h * tau_mul).expm1().neg()
    if is_corrector_step:
        # Simplified single-history (order-2) corrector.
        b_1 = alpha_t * (0.5 * tau_mul * h)
        b_2 = alpha_t * one_minus_decay - b_1
    else:
        # Simplified two-step predictor; uses the previous lambda point.
        b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
        b_1 = alpha_t * one_minus_decay - b_2
    return torch.stack([b_2, b_1])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the b_i coefficients for SA-Solver (paper eqs. 15 and 18).

    The solver order equals the number of supplied lambda (half-logSNR)
    points.

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda points building the Lagrange basis, shape (N,).
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter of the SDE.
        simple_order_2: Use the simplified order-2 scheme when N == 2.
        is_corrector_step: Corrector-step flag for the simple order-2 mode.

    Returns:
        Coefficients of shape (N,), where N is the solver order.
    """
    order = curr_lambdas.shape[0]

    if simple_order_2 and order == 2:
        return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)

    # General case: solve a Vandermonde system so the exponentially weighted
    # integrals are expressed in the Lagrange basis over curr_lambdas.
    weighted_integrals = compute_exponential_coeffs(lambda_s, lambda_t, order, tau_t)
    vander_T = torch.vander(curr_lambdas, order, increasing=True).T
    lagrange_integrals = torch.linalg.solve(vander_T, weighted_integrals)

    # sigma_t * exp(-tau^2 lambda_t) * exp((1 + tau^2) lambda_t)
    # = sigma_t * exp(lambda_t) = alpha_t; the exponential factor was pulled
    # out of the integral inside compute_exponential_coeffs.
    alpha_t = sigma_next * lambda_t.exp()
    return alpha_t * lagrange_integrals
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
    """Build the function controlling SA-Solver's stochastic strength.

    With eta = 0 SA-Solver behaves as an ODE solver. The official code gates
    the SDE interval on time t; here the interval is expressed in sigma.

    See:
        https://github.com/scxue/SA-Solver/blob/main/README.md
    """

    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
        if eta <= 0:
            return 0.0  # no stochasticity -> pure ODE behaviour

        value = sigma.item() if isinstance(sigma, torch.Tensor) else sigma
        return eta if end_sigma <= value <= start_sigma else 0.0

    return tau_func
|
ComfyUI/comfy/k_diffusion/sampling.py
ADDED
|
@@ -0,0 +1,1761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from scipy import integrate
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torchsde
|
| 8 |
+
from tqdm.auto import trange, tqdm
|
| 9 |
+
|
| 10 |
+
from . import utils
|
| 11 |
+
from . import deis
|
| 12 |
+
from . import sa_solver
|
| 13 |
+
import comfy.model_patcher
|
| 14 |
+
import comfy.model_sampling
|
| 15 |
+
|
| 16 |
+
def append_zero(x):
    """Append a single zero element to the end of a 1-D tensor."""
    zero = x.new_zeros([1])  # matches x's dtype and device
    return torch.cat([x, zero])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022).

    Interpolates linearly in sigma^(1/rho) space, then undoes the warp and
    appends a trailing zero.
    """
    ramp = torch.linspace(0, 1, n, device=device)
    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + ramp * (lo - hi)) ** rho
    return append_zero(sigmas).to(device)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule (linear in log-sigma)."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial-in-log-sigma noise schedule."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    log_min = math.log(sigma_min)
    log_max = math.log(sigma_max)
    sigmas = torch.exp(ramp * (log_max - log_min) + log_min)
    return append_zero(sigmas)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP (variance preserving) noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    # sigma(t) = sqrt(exp(beta_d t^2 / 2 + beta_min t) - 1)
    log_var = beta_d * t ** 2 / 2 + beta_min * t
    return append_zero(torch.special.expm1(log_var).sqrt())
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
    """Constructs the noise schedule proposed by Tiankai et al. (2024).

    Note: unlike the other schedule constructors, this one does not append a
    trailing zero.
    """
    epsilon = 1e-5  # avoid log(0) at the endpoints
    x = torch.linspace(0, 1, n, device=device)
    centered = 0.5 - x
    # Inverse CDF of a Laplace(mu, beta) distribution over log-sigma.
    lmb = mu - beta * torch.sign(centered) * torch.log(1 - 2 * centered.abs() + epsilon)
    return torch.exp(lmb).clamp(min=sigma_min, max=sigma_max)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    # Broadcast sigma over the sample dimensions before dividing.
    sigma_broadcast = utils.append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_broadcast
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        # Deterministic step: no extra noise is injected.
        return sigma_to, 0.
    variance_transfer = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
    # Cap sigma_up at sigma_to so sigma_down stays real.
    sigma_up = min(sigma_to, eta * variance_transfer ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def default_noise_sampler(x, seed=None):
    """Return a noise sampler producing white Gaussian noise shaped like x.

    When a seed is given, a dedicated generator makes the noise stream
    reproducible independently of the global RNG state.
    """
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)

    def sampler(sigma, sigma_next):
        # sigma / sigma_next are ignored: the noise is i.i.d. per call.
        return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)

    return sampler
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # When True (default), the underlying trees live on the CPU and their
        # samples are moved to x's device/dtype on demand in __call__.
        self.cpu_tree = True
        if "cpu" in kwargs:
            self.cpu_tree = kwargs.pop("cpu")
        # torchsde requires t0 < t1; remember the sign so increments sampled
        # over a swapped interval can be flipped back.
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item; len() raises
            # TypeError for a plain int, which selects the single-seed path.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        if self.cpu_tree:
            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
        else:
            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        # Return (lo, hi, sign) where sign records whether the pair was swapped.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        # Sample the Brownian increment W(t1) - W(t0) for every tree,
        # correcting the sign for any interval reversal (construction or call).
        t0, t1, sign = self.sort(t0, t1)
        if self.cpu_tree:
            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
        else:
            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)

        # Batched mode returns one increment per batch item; otherwise the
        # single tree's increment is returned unstacked.
        return w if self.batched else w[0]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        # Sample the Brownian increment over [t0, t1] and normalize it by
        # sqrt(|t1 - t0|) so the result has unit variance.
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def sigma_to_half_log_snr(sigma, model_sampling):
    """Convert sigma to half-logSNR log(alpha_t / sigma_t)."""
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # Rectified flow: log((1 - sigma) / sigma) == -logit(sigma).
        return sigma.logit().neg()
    # Otherwise alpha_t is treated as 1, so the half-logSNR is -log(sigma).
    return sigma.log().neg()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def half_log_snr_to_sigma(half_log_snr, model_sampling):
    """Convert half-logSNR log(alpha_t / sigma_t) to sigma."""
    neg_lambda = half_log_snr.neg()
    if isinstance(model_sampling, comfy.model_sampling.CONST):
        # Inverse of -logit(sigma): sigma = 1 / (1 + exp(half_log_snr)).
        return neg_lambda.sigmoid()
    # Inverse of -log(sigma): sigma = exp(-half_log_snr).
    return neg_lambda.exp()
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
    """Adjust the first sigma to avoid invalid logSNR.

    For CONST (rectified flow) model sampling, sigma >= 1 makes the logSNR
    diverge; in that case the first sigma is replaced by the sigma slightly
    inside the schedule given by percent_offset. The input is not mutated.
    """
    if len(sigmas) <= 1:
        return sigmas
    if isinstance(model_sampling, comfy.model_sampling.CONST) and sigmas[0] >= 1:
        sigmas = sigmas.clone()
        sigmas[0] = model_sampling.percent_to_sigma(percent_offset)
    return sigmas
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    When s_churn > 0, noise is "churned" back in before each step inside
    [s_tmin, s_tmax], temporarily raising the level to sigma_hat.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        gamma = 0
        sigma_hat = sigmas[i]
        if s_churn > 0 and s_tmin <= sigmas[i] <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
            sigma_hat = sigmas[i] * (gamma + 1)

        if gamma > 0:
            # Inject fresh noise so the level rises from sigmas[i] to sigma_hat.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler method: single step from sigma_hat down to the next sigma.
        x = x + d * (sigmas[i + 1] - sigma_hat)
    return x
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    # Rectified-flow style (CONST) model sampling uses a different ancestral update.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
    """Ancestral sampling with Euler method steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigma_down == 0:
            # Final step: the model output is the clean sample.
            x = denoised
        else:
            # Deterministic Euler step to sigma_down, then re-inject sigma_up noise.
            d = to_d(x, sigmas[i], denoised)
            x = x + d * (sigma_down - sigmas[i]) + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps.

    Rectified-flow variant: assumes alpha_t = 1 - sigma_t (CONST model
    sampling), so the deterministic update interpolates between x and the
    denoised prediction, and the renoise coefficient is derived from the
    alpha/sigma algebra below.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        if sigmas[i + 1] == 0:
            # Final step: the model output is the clean sample.
            x = denoised
        else:
            # eta interpolates between a fully deterministic step (eta=0,
            # sigma_down == sigmas[i+1]) and a fully ancestral one.
            downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
            sigma_down = sigmas[i + 1] * downstep_ratio
            alpha_ip1 = 1 - sigmas[i + 1]
            alpha_down = 1 - sigma_down
            # Noise scale chosen so the rescaled x plus noise has variance
            # sigmas[i+1]**2 again.
            renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
            # Euler method
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
            if eta > 0:
                x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
|
| 251 |
+
|
| 252 |
+
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Each step takes an Euler step and, unless the target sigma is 0, corrects
    it with a second derivative evaluation (trapezoidal / Heun update).
    Fix: removed a leftover unconditional `sigma_hat = sigmas[i] * (gamma + 1)`
    that duplicated the if/else assignment above it and needlessly rebuilt the
    tensor on every iteration (including the gamma == 0 path).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]

        if gamma > 0:
            # Churn: add noise so the effective level rises to sigma_hat.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method (no second evaluation possible at sigma 0).
            x = x + d * dt
        else:
            # Heun's method: average the derivative at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        gamma = 0
        sigma_hat = sigmas[i]
        if s_churn > 0 and s_tmin <= sigmas[i] <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
            sigma_hat = sigmas[i] * (gamma + 1)

        if gamma > 0:
            # Churn: add noise so the effective level rises to sigma_hat.
            noise = torch.randn_like(x) * s_noise
            x = x + noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method for the final step.
            x = x + d * (sigmas[i + 1] - sigma_hat)
        else:
            # DPM-Solver-2: second evaluation at the log-space midpoint sigma.
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            x_mid = x + d * (sigma_mid - sigma_hat)
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            d_mid = to_d(x_mid, sigma_mid, denoised_mid)
            x = x + d_mid * (sigmas[i + 1] - sigma_hat)
    return x
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    # Rectified-flow style (CONST) model sampling needs the RF variant.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)

    """Ancestral sampling with DPM-Solver second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Plain Euler step straight to sigma_down for the final transition.
            x = x + d * (sigma_down - sigmas[i])
        else:
            # DPM-Solver-2: second evaluation at the log-space midpoint sigma.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            x_mid = x + d * (sigma_mid - sigmas[i])
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            d_mid = to_d(x_mid, sigma_mid, denoised_mid)
            x = x + d_mid * (sigma_down - sigmas[i])
            # Ancestral noise injection back up to sigmas[i + 1].
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 354 |
+
|
| 355 |
+
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps.

    Rectified-flow variant: assumes alpha_t = 1 - sigma_t (CONST model
    sampling); the renoise coefficient below keeps the post-noise variance at
    sigmas[i+1]**2 after the deterministic step down to sigma_down.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # eta interpolates between deterministic (eta=0) and fully ancestral steps.
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2: second derivative evaluation at the log-space midpoint.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            # Rescale by the alpha ratio and re-inject noise up to sigmas[i+1].
            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
    return x
|
| 388 |
+
|
| 389 |
+
def linear_multistep_coeff(order, t, i, j):
    """Compute the j-th linear multistep coefficient for step i.

    Integrates the j-th Lagrange basis polynomial over [t[i], t[i + 1]],
    using the last `order` sigma values as interpolation nodes.

    Raises:
        ValueError: if fewer than `order` previous steps are available.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        # Lagrange basis polynomial l_j(tau) over nodes t[i - k], k != j.
        factors = [(tau - t[i - k]) / (t[i - j] - t[i - k]) for k in range(order) if k != j]
        return math.prod(factors)

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler: extrapolates from up to `order` past derivatives."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigmas_cpu = sigmas.detach().cpu().numpy()
    history = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        history.append(to_d(x, sigmas[i], denoised))
        # Keep at most `order` past derivatives.
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # Use as many derivatives as are available so far.
            cur_order = min(i + 1, order)
            coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
            x = x + sum(c * d for c, d in zip(coeffs, reversed(history)))
    return x
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control.

    Keeps the last three inverse error estimates and multiplies the step size
    h by a limited PID factor after every proposal.
    """

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # PID gains expressed as exponents on the last three inverse errors.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        # Smoothly cap the growth factor (atan saturates large x).
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # First proposal: seed the history with the current error.
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        raw_factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(raw_factor)
        accept = factor >= self.accept_safety
        if accept:
            # Shift the error history only on accepted steps.
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927."""

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        # model: denoiser called as model(x, sigma, **extra_args), returning
        #     the denoised sample.
        # extra_args: keyword arguments forwarded to every model call.
        # eps_callback: invoked (no arguments) after each model evaluation.
        # info_callback: invoked with a progress dict by the sampling loops.
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        # Solver time is negative log-sigma.
        return -sigma.log()

    def sigma(self, t):
        # Inverse of t(): sigma = exp(-t).
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Return the noise estimate at (x, t), reusing eps_cache[key] if present.

        Returns (eps, updated_cache); the input cache is not mutated in place.
        """
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        # eps = (x - denoised) / sigma: the model's normalized noise prediction.
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """First-order (exponential Euler) DPM-Solver step from t to t_next."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Second-order DPM-Solver step; r1 places the intermediate evaluation."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        # Intermediate first-order step to s1, then a correction term.
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Third-order DPM-Solver step with intermediate points at r1 and r2."""
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Fixed-step sampling using nfe model evaluations split into 3-1-2 order blocks."""
        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        # m segments of (mostly) third-order steps covering the nfe budget.
        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                # Ancestral split: step further down, then add back su noise.
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Adaptive-step sampling: embedded (order-1, order) pair + PID step control.

        Returns (x, info) where info counts steps, model evaluations, accepts
        and rejects.
        """
        noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            # Clamp the proposed step so it never overshoots t_end.
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            # Lower- and higher-order solutions share cached eps evaluations;
            # their difference is the local error estimate.
            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate the solver's time-domain info dict into sigma terms.
            def relay(info):
                callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        return solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            # Translate the solver's time-domain info dict into sigma terms.
            def relay(info):
                callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})
            solver.info_callback = relay
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        x, info = solver.dpm_solver_adaptive(x, t_start, t_end, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
        if return_info:
            return x, info
        return x
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    # Rectified-flow style (CONST) model sampling is handled by the RF variant.
    if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
        return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)

    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x, seed=extra_args.get("seed", None)) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        # Map solver time (negative log-sigma) back to sigma.
        return t.neg().exp()

    def t_fn(sigma):
        # Map sigma to solver time.
        return sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Plain Euler step for the final transition.
            x = x + to_d(x, sigmas[i], denoised) * (sigma_down - sigmas[i])
        else:
            # DPM-Solver++(2S): midpoint evaluation in log-sigma space.
            t_cur, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            h = t_next - t_cur
            s_mid = t_cur + h / 2
            x_mid = (sigma_fn(s_mid) / sigma_fn(t_cur)) * x - (-h / 2).expm1() * denoised
            denoised_mid = model(x_mid, sigma_fn(s_mid) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * denoised_mid
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Rectified-flow variant: assumes alpha_t = 1 - sigma_t (CONST model
    sampling), working in half-logSNR lambda = log((1 - sigma) / sigma) space.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # sigma from half-logSNR and its inverse, specialized to alpha = 1 - sigma.
    sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
    lambda_fn = lambda sigma: ((1-sigma)/sigma).log()

    # logged_x = x.unsqueeze(0)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        # eta interpolates between deterministic (eta=0) and fully ancestral steps.
        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
        sigma_down = sigmas[i+1] * downstep_ratio
        alpha_ip1 = 1 - sigmas[i+1]
        alpha_down = 1 - sigma_down
        # Noise scale restoring the variance to sigmas[i+1]**2 after renoising.
        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
        # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S)
            if sigmas[i] == 1.0:
                # lambda_fn is undefined at sigma == 1; use a value just below it.
                sigma_s = 0.9999
            else:
                # Midpoint of the step in half-logSNR space.
                t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
                r = 1 / 2
                h = t_down - t_i
                s = t_i + r * h
                sigma_s = sigma_fn(s)
            # sigma_s = sigmas[i+1]
            # Interpolate toward the denoised prediction for the midpoint sample.
            sigma_s_i_ratio = sigma_s / sigmas[i]
            u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
            D_i = model(u, sigma_s * s_in, **extra_args)
            sigma_down_i_ratio = sigma_down / sigmas[i]
            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
            # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
            # Noise addition
            if sigmas[i + 1] > 0 and eta > 0:
                x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
        # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
    return x
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic), expressed in half-log-SNR space.

    Takes an ancestral-style intermediate step at ratio ``r`` of each
    log-SNR interval, then combines the two model evaluations (2S scheme).
    ``eta`` scales the injected noise; ``eta=0`` gives the deterministic solver.
    """
    if len(sigmas) <= 1:
        return x

    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    seed = extra_args.get("seed", None)
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    # Conversions between sigma and half-log-SNR come from the model's sampling object.
    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step: take the model prediction directly.
            x = denoised
        else:
            # DPM-Solver++(2S) in lambda (half-log-SNR) space.
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)

            sigma_s_1 = sigma_fn(lambda_s_1)

            # alpha = sigma * exp(lambda) for each of the three time points.
            alpha_s = sigmas[i] * lambda_s.exp()
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # Step 1: move to the intermediate point, optionally renoised.
            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta)
            lambda_s_1_ = sd.log().neg()
            h_ = lambda_s_1_ - lambda_s
            x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised
            if eta > 0 and s_noise > 0:
                x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2: full step using the blended denoised estimate.
            sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_t.neg().exp(), eta)
            lambda_t_ = sd.log().neg()
            h_ = lambda_t_ - lambda_s
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d
            if eta > 0 and s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su
    return x
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M): deterministic second-order multistep solver.

    Reuses the previous step's denoised prediction to form a second-order
    update; the first step (and the final step to sigma 0) fall back to the
    first-order DPM-Solver++ update.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Work in t = -log(sigma) space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            # First-order update (no history available, or final step).
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            # Second-order update: extrapolate using the previous prediction.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE in half-log-SNR space.

    ``solver_type`` selects the second-order correction ('heun' or 'midpoint');
    ``eta`` controls the amount of stochastic renoising per step.
    """
    if len(sigmas) <= 1:
        return x

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    old_denoised = None
    h, h_last = None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step: output the model prediction directly.
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)

            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # First-order SDE update.
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

            # Second-order correction from the previous denoised prediction.
            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised)

            # Stochastic renoising.
            if eta > 0 and s_noise > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE in half-log-SNR space.

    Uses up to two history points (third order); falls back to the (2M) SDE
    update with one history point and to first order on the initial step.
    """
    if len(sigmas) <= 1:
        return x

    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    # History of denoised predictions and step sizes (most recent first).
    denoised_1, denoised_2 = None, None
    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final step: output the model prediction directly.
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)

            alpha_t = sigmas[i + 1] * lambda_t.exp()

            # First-order SDE update.
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Third-order correction from two history points.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2
            elif h_1 is not None:
                # Second-order correction from one history point.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + (alpha_t * phi_2) * d

            # Stochastic renoising.
            if eta > 0 and s_noise > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        # Shift history.
        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE with the Brownian tree noise kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # cpu=False: generate Brownian noise on the same device as x.
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE with the Brownian tree noise kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # cpu=False: generate Brownian noise on the same device as x.
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic) with the Brownian tree noise kept on the GPU."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # cpu=False: generate Brownian noise on the same device as x.
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False)
    return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
    """One ancestral DDPM transition expressed via (sigma, sigma_prev).

    `noise` is the predicted epsilon at the current step; `noise_sampler`
    supplies fresh Gaussian noise for the stochastic part (only called when
    sigma_prev > 0). Returns the posterior mean plus scaled fresh noise.
    """
    # Map sigmas to the DDPM cumulative alpha products.
    alpha_cumprod = 1 / ((sigma * sigma) + 1)
    alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
    alpha = alpha_cumprod / alpha_cumprod_prev

    # Posterior mean of the previous latent given x and the predicted noise.
    mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
    if sigma_prev > 0:
        # Posterior standard deviation times fresh noise (skipped on the last step).
        mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
    return mu
|
| 957 |
+
|
| 958 |
+
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
    """Drive a per-step transition function over the sigma schedule.

    `step_function` receives (x_rescaled, sigma, sigma_prev, eps, noise_sampler)
    where x is rescaled to the DDPM variance-preserving parametrization before
    each step and rescaled back afterwards (except on the final step).
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Pass the VP-rescaled latent and the implied epsilon to the step function.
        x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
        if sigmas[i + 1] != 0:
            # Undo the VP rescaling for the next iteration.
            x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
    return x
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
@torch.no_grad()
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Ancestral DDPM sampling: generic_step_sampler driven by DDPMSampler_step."""
    return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
| 977 |
+
|
| 978 |
+
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """LCM sampling: take the model prediction, then renoise to the next sigma."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        x = denoised
        if sigmas[i + 1] > 0:
            # Renoise the prediction using the model's own noise-scaling rule.
            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
    return x
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun++ sampler.

    From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
    Uses Euler on the last step, Heun on the second-to-last step, and a
    three-evaluation weighted derivative ("Heun++") everywhere else.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    s_end = sigmas[-1]
    for i in trange(len(sigmas) - 1, disable=disable):
        # Optional churn: temporarily raise sigma and add matching noise.
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == s_end:
            # Euler method on the final step.
            x = x + d * dt
        elif sigmas[i + 2] == s_end:
            # Heun's method on the second-to-last step, with sigma-based weights.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)

            w = 2 * sigmas[0]
            w2 = sigmas[i + 1] / w
            w1 = 1 - w2

            d_prime = d * w1 + d_2 * w2

            x = x + d_prime * dt
        else:
            # Heun++: a third evaluation one step further ahead.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            dt_2 = sigmas[i + 2] - sigmas[i + 1]

            x_3 = x_2 + d_2 * dt_2
            denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
            d_3 = to_d(x_3, sigmas[i + 2], denoised_3)

            w = 3 * sigmas[0]
            w2 = sigmas[i + 1] / w
            w3 = sigmas[i + 2] / w
            w1 = 1 - w2 - w3

            d_prime = w1 * d + w2 * d_2 + w3 * d_3
            x = x + d_prime * dt
    return x
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    """Improved PNDM sampler (Adams-Bashforth multistep, order up to 4).

    Fix: added @torch.no_grad() for consistency with the other samplers in
    this file; without it, calling the sampler outside an existing no-grad
    context would build an autograd graph across all steps.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x

    # History of derivative evaluations (most recent last), up to max_order - 1.
    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if t_next == 0:  # Denoising step
            x_next = denoised
        elif order == 1:  # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:  # Use one history point.
            x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
        elif order == 3:  # Use two history points.
            x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
        elif order == 4:  # Use three history points.
            x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24

        # Maintain the rolling history buffer.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur
        else:
            buffer_model.append(d_cur)

    return x_next
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
    """Variable-step improved PNDM sampler (order up to 4).

    Like sample_ipndm but with coefficients derived from the actual (uneven)
    step sizes. Fix: added @torch.no_grad() for consistency with the other
    samplers in this file; without it, calling the sampler outside an existing
    no-grad context would build an autograd graph across all steps.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    # History of derivative evaluations (most recent last), up to max_order - 1.
    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if t_next == 0:  # Denoising step
            x_next = denoised
        elif order == 1:  # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:  # Use one history point.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            coeff1 = (2 + (h_n / h_n_1)) / 2
            coeff2 = -(h_n / h_n_1) / 2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
        elif order == 3:  # Use two history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
            coeff3 = temp * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
        elif order == 4:  # Use three history points.
            h_n = (t_next - t_cur)
            h_n_1 = (t_cur - t_steps[i-1])
            h_n_2 = (t_steps[i-1] - t_steps[i-2])
            h_n_3 = (t_steps[i-2] - t_steps[i-3])
            temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
            temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
                   * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
            coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
            coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
            coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
            coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
            x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])

        # Maintain the rolling history buffer.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
    """DEIS sampler: exponential-integrator multistep with precomputed coefficients."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    x_next = x
    t_steps = sigmas

    # Per-step combination coefficients for the whole schedule, computed up front.
    coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)

    # History of derivative evaluations (most recent last), up to max_order - 1.
    buffer_model = []
    for i in trange(len(sigmas) - 1, disable=disable):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]

        x_cur = x_next

        denoised = model(x_cur, t_cur * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

        d_cur = (x_cur - denoised) / t_cur

        order = min(max_order, i+1)
        if t_next <= 0:
            # Force a plain Euler step into sigma == 0.
            order = 1

        if order == 1:  # First Euler step.
            x_next = x_cur + (t_next - t_cur) * d_cur
        elif order == 2:  # Use one history point.
            coeff_cur, coeff_prev1 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
        elif order == 3:  # Use two history points.
            coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
        elif order == 4:  # Use three history points.
            coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
            x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]

        # Maintain the rolling history buffer.
        if len(buffer_model) == max_order - 1:
            for k in range(max_order - 2):
                buffer_model[k] = buffer_model[k+1]
            buffer_model[-1] = d_cur.detach()
        else:
            buffer_model.append(d_cur.detach())

    return x_next
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps (CFG++).

    Captures the unconditional prediction via a post-CFG hook and uses it for
    the noise direction while keeping the guided prediction for the endpoint.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)

    uncond_denoised = None

    def post_cfg_function(args):
        # Stash the unconditional prediction for use in the update below.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp()
            alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp()
            d = to_d(x, sigmas[i], alpha_s * uncond_denoised)  # to noise

            # DDIM stochastic sampling
            sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta)
            sigma_down = alpha_t * sigma_down

            # Euler method
            x = alpha_t * denoised + sigma_down * d
            if eta > 0 and s_noise > 0:
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 1255 |
+
|
| 1256 |
+
|
| 1257 |
+
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Euler method steps (CFG++): the ancestral variant with all stochasticity disabled."""
    return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None)
|
| 1261 |
+
|
| 1262 |
+
|
| 1263 |
+
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps (CFG++)."""
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

    # temp[0] holds the most recent unconditional prediction from the CFG hook.
    temp = [0]
    def post_cfg_function(args):
        temp[0] = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], temp[0])
            x = denoised + d * sigma_down
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            # CFG++ correction: shift x by (guided - unconditional) before each substep.
            x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
|
| 1305 |
+
|
| 1306 |
+
@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M) with CFG++: the multistep update is driven by the
    unconditional prediction captured via a post-CFG hook."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    t_fn = lambda sigma: sigma.log().neg()

    old_uncond_denoised = None
    uncond_denoised = None
    def post_cfg_function(args):
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    model_options = extra_args.get("model_options", {}).copy()
    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_uncond_denoised is None or sigmas[i + 1] == 0:
            # First-order update (no history available, or final step).
            denoised_mix = -torch.exp(-h) * uncond_denoised
        else:
            # Second-order update using the previous unconditional prediction.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised)
        x = denoised + denoised_mix + torch.exp(-h) * x
        old_uncond_denoised = uncond_denoised
    return x
|
| 1338 |
+
|
| 1339 |
+
@torch.no_grad()
def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
    """Second-order residual multistep sampler (shared backend).

    Backs the ``sample_res_multistep*`` entry points. ``eta > 0`` enables
    ancestral (stochastic) steps via ``get_ancestral_step``; ``cfg_pp=True``
    switches to the CFG++ formulation, which steps using the unconditional
    prediction captured through a post-CFG hook. Second-order update follows
    https://arxiv.org/pdf/2308.02157.

    Args:
        model: denoiser; called as ``model(x, sigma * s_in, **extra_args)``.
        x: current latent/sample tensor.
        sigmas: 1-D tensor of noise levels, high to low.
        extra_args: forwarded to the model (may carry "seed"/"model_options").
        callback: optional per-step progress callback.
        disable: disables the tqdm progress bar.
        s_noise: scale on the injected ancestral noise.
        noise_sampler: noise source; defaults to ``default_noise_sampler``.
        eta: ancestral-ness (0 = deterministic).
        cfg_pp: use CFG++ stepping.

    Returns:
        The final sample tensor.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    # Conversions between sigma and the log-SNR-style time variable t = -log(sigma).
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    # Exponential-integrator phi functions: phi1(t) = (e^t - 1)/t, phi2 = (phi1 - 1)/t.
    phi1_fn = lambda t: torch.expm1(t) / t
    phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t

    old_sigma_down = None
    old_denoised = None
    uncond_denoised = None
    def post_cfg_function(args):
        # Capture the unconditional prediction for CFG++ stepping; pass the
        # CFG-combined prediction through unchanged.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        # Copy before mutating so the caller's model_options dict is untouched.
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigma_down == 0 or old_denoised is None:
            # Euler method (first step, or final step to sigma 0)
            if cfg_pp:
                d = to_d(x, sigmas[i], uncond_denoised)
                x = denoised + d * sigma_down
            else:
                d = to_d(x, sigmas[i], denoised)
                dt = sigma_down - sigmas[i]
                x = x + d * dt
        else:
            # Second order multistep method in https://arxiv.org/pdf/2308.02157
            t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
            h = t_next - t
            c2 = (t_prev - t_old) / h

            phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
            # nan_to_num guards the c2 == 0 case (repeated sigma).
            b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
            b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)

            if cfg_pp:
                # CFG++: shift x by the guidance residual, then step on the
                # unconditional prediction.
                x = x + (denoised - uncond_denoised)
                x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
            else:
                x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)

        # Noise addition (ancestral step; no-op when eta == 0 since sigma_up == 0)
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up

        if cfg_pp:
            old_denoised = uncond_denoised
        else:
            old_denoised = denoised
        old_sigma_down = sigma_down
    return x
|
| 1402 |
+
|
| 1403 |
+
@torch.no_grad()
def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
    """Deterministic res_multistep sampling (eta=0, no CFG++)."""
    return res_multistep(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        eta=0.,
        cfg_pp=False,
    )
|
| 1406 |
+
|
| 1407 |
+
@torch.no_grad()
def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
    """Deterministic res_multistep sampling (eta=0) with CFG++ stepping."""
    return res_multistep(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        eta=0.,
        cfg_pp=True,
    )
|
| 1410 |
+
|
| 1411 |
+
@torch.no_grad()
def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral (stochastic) res_multistep sampling, no CFG++."""
    return res_multistep(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        eta=eta,
        cfg_pp=False,
    )
|
| 1414 |
+
|
| 1415 |
+
@torch.no_grad()
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral (stochastic) res_multistep sampling with CFG++ stepping."""
    return res_multistep(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        eta=eta,
        cfg_pp=True,
    )
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
@torch.no_grad()
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
    """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK

    Takes an Euler step each iteration and, from the second step on, adds a
    correction term built from the difference between the current and previous
    derivative estimates, weighted by ``ge_gamma``. ``cfg_pp=True`` steps on
    the unconditional prediction captured via a post-CFG hook.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_d = None  # derivative from the previous step, for the correction term

    uncond_denoised = None
    def post_cfg_function(args):
        # Capture the unconditional prediction for CFG++ stepping.
        nonlocal uncond_denoised
        uncond_denoised = args["uncond_denoised"]
        return args["denoised"]

    if cfg_pp:
        # Copy before mutating so the caller's model_options dict is untouched.
        model_options = extra_args.get("model_options", {}).copy()
        extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if cfg_pp:
            d = to_d(x, sigmas[i], uncond_denoised)
        else:
            d = to_d(x, sigmas[i], denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        dt = sigmas[i + 1] - sigmas[i]
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # Euler method
            if cfg_pp:
                x = denoised + d * sigmas[i + 1]
            else:
                x = x + d * dt

            if i >= 1:
                # Gradient estimation: correct with the change in derivative.
                d_bar = (ge_gamma - 1) * (d - old_d)
                x = x + d_bar * dt
        old_d = d
    return x
|
| 1462 |
+
|
| 1463 |
+
|
| 1464 |
+
@torch.no_grad()
def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
    """CFG++ variant of the gradient-estimation sampler."""
    return sample_gradient_estimation(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        ge_gamma=ge_gamma,
        cfg_pp=True,
    )
|
| 1467 |
+
|
| 1468 |
+
|
| 1469 |
+
@torch.no_grad()
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
    Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.

    Runs up to three solver stages per step (Euler, plus first- and
    second-difference corrections), limited by ``max_stage`` and by how many
    previous evaluations are available. ``noise_scaler`` customizes the
    ER-SDE scaling function; ``s_noise`` scales the injected noise.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    def default_er_sde_noise_scaler(x):
        # Default scaling function from the reference implementation.
        return x * ((x ** 0.3).exp() + 10.0)

    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
    # Number of points for the numerical integration used by stages 2 and 3.
    num_integration_points = 200.0
    point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
    er_lambdas = half_log_snrs.neg().exp() # er_lambda_t = sigma_t / alpha_t

    old_denoised = None
    old_denoised_d = None  # first difference from the previous step (stage 3 input)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Can use at most i + 1 stages (history of previous evaluations).
        stage_used = min(max_stage, i + 1)
        if sigmas[i + 1] == 0:
            # Final denoising step.
            x = denoised
        else:
            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
            alpha_s = sigmas[i] / er_lambda_s
            alpha_t = sigmas[i + 1] / er_lambda_t
            r_alpha = alpha_t / alpha_s
            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)

            # Stage 1 Euler
            x = r_alpha * r * x + alpha_t * (1 - r) * denoised

            if stage_used >= 2:
                dt = er_lambda_t - er_lambda_s
                lambda_step_size = -dt / num_integration_points
                lambda_pos = er_lambda_t + point_indice * lambda_step_size
                scaled_pos = noise_scaler(lambda_pos)

                # Stage 2: first-difference correction via numerical integration.
                s = torch.sum(1 / scaled_pos) * lambda_step_size
                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d

                if stage_used >= 3:
                    # Stage 3: second-difference correction.
                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
                old_denoised_d = denoised_d

            if s_noise > 0:
                # Inject noise; nan_to_num guards a negative value under sqrt.
                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
        old_denoised = denoised
    return x
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
    """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
    arXiv: https://arxiv.org/abs/2305.14267

    Two model evaluations per step: an intermediate evaluation at the
    ``r``-fraction point of the half-log-SNR interval, then a combined update.
    ``eta > 0`` with ``s_noise > 0`` enables stochastic (SDE) stepping.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    # Conversions between sigma and lambda = half-log-SNR, for this model's schedule.
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final denoising step.
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            # Intermediate point at fraction r of the interval.
            lambda_s_1 = lambda_s + r * h
            fac = 1 / (2 * r)
            sigma_s_1 = sigma_fn(lambda_s_1)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r < 1
                noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])

            # Step 1: evaluate at the intermediate point.
            x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2: combine the two predictions and step to sigma[i + 1].
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
    return x
|
| 1589 |
+
|
| 1590 |
+
|
| 1591 |
+
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
    """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
    arXiv: https://arxiv.org/abs/2305.14267

    Three model evaluations per step: intermediates at the ``r_1`` and ``r_2``
    fractions of the half-log-SNR interval, then a combined update.
    ``eta > 0`` with ``s_noise > 0`` enables stochastic (SDE) stepping.
    """
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    inject_noise = eta > 0 and s_noise > 0

    model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
    # Conversions between sigma and lambda = half-log-SNR, for this model's schedule.
    sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
    lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Final denoising step.
            x = denoised
        else:
            lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
            h = lambda_t - lambda_s
            h_eta = h * (eta + 1)
            lambda_s_1 = lambda_s + r_1 * h
            lambda_s_2 = lambda_s + r_2 * h
            sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)

            # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
            alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
            alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
            alpha_t = sigmas[i + 1] * lambda_t.exp()

            coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
            if inject_noise:
                # 0 < r_1 < r_2 < 1
                noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
                noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
                noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
                noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])

            # Step 1: evaluate at the first intermediate point.
            x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
            if inject_noise:
                x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
            denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

            # Step 2: evaluate at the second intermediate point.
            x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
            if inject_noise:
                x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
            denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)

            # Step 3: combined update to sigma[i + 1].
            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
            if inject_noise:
                x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
    return x
|
| 1652 |
+
|
| 1653 |
+
|
| 1654 |
+
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
    """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023).

    Each iteration evaluates the model at the predicted state, runs an
    Adams-type corrector over the stored prediction history, then a predictor
    for the next sigma. ``tau_func`` maps sigma to the stochasticity tau;
    ``use_pece`` re-evaluates the model after correction (PECE mode).
    Solver orders are capped by the available history and lowered near the
    end of the schedule when the final sigma is 0.
    """
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

    if tau_func is None:
        # Use default interval for stochastic sampling
        start_sigma = model_sampling.percent_to_sigma(0.2)
        end_sigma = model_sampling.percent_to_sigma(0.8)
        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)

    max_used_order = max(predictor_order, corrector_order)
    x_pred = x # x: current state, x_pred: predicted next state

    # Carried across iterations: step size h, stochasticity tau_t, and the
    # noise injected by the previous predictor (reused by the corrector).
    h = 0.0
    tau_t = 0.0
    noise = 0.0
    pred_list = []

    # Lower order near the end to improve stability
    lower_order_to_end = sigmas[-1].item() == 0

    for i in trange(len(sigmas) - 1, disable=disable):
        # Evaluation
        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        pred_list.append(denoised)
        # Keep only as much history as the highest order needs.
        pred_list = pred_list[-max_used_order:]

        predictor_order_used = min(predictor_order, len(pred_list))
        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
            corrector_order_used = 0
        else:
            corrector_order_used = min(corrector_order, len(pred_list))

        if lower_order_to_end:
            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)

        # Corrector
        if corrector_order_used == 0:
            # Update by the predicted state
            x = x_pred
        else:
            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i],
                curr_lambdas,
                lambdas[i - 1],
                lambdas[i],
                tau_t,
                simple_order_2,
                is_corrector_step=True,
            )
            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...)
            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res

            if tau_t > 0 and s_noise > 0:
                # The noise from the previous predictor step
                x = x + noise

        if use_pece:
            # Evaluate the corrected state
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            pred_list[-1] = denoised

        # Predictor
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            tau_t = tau_func(sigmas[i + 1])
            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i + 1],
                curr_lambdas,
                lambdas[i],
                lambdas[i + 1],
                tau_t,
                simple_order_2,
                is_corrector_step=False,
            )
            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...)
            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...)
            h = lambdas[i + 1] - lambdas[i]
            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res

            if tau_t > 0 and s_noise > 0:
                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                x_pred = x_pred + noise
    return x
|
| 1756 |
+
|
| 1757 |
+
|
| 1758 |
+
@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
    """Stochastic Adams Solver in PECE (Predict-Evaluate-Correct-Evaluate) mode (NeurIPS 2023)."""
    return sample_sa_solver(
        model,
        x,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        tau_func=tau_func,
        s_noise=s_noise,
        noise_sampler=noise_sampler,
        predictor_order=predictor_order,
        corrector_order=corrector_order,
        use_pece=True,
        simple_order_2=simple_order_2,
    )
|
ComfyUI/comfy/ldm/common_dit.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.rmsnorm
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the trailing spatial dims of ``img`` up to multiples of ``patch_size``.

    Pads only on the high side of each spatial dimension (dims 2 and on).
    Under torch.jit tracing/scripting, "circular" padding is swapped for
    "reflect".
    """
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"

    # F.pad expects pad amounts for the LAST dim first, so walk the spatial
    # dims in reverse; each dim contributes (left=0, right=needed).
    amounts = []
    for i in reversed(range(img.ndim - 2)):
        remainder = img.shape[i + 2] % patch_size[i]
        amounts.extend((0, (patch_size[i] - remainder) % patch_size[i]))

    return torch.nn.functional.pad(img, tuple(amounts), mode=padding_mode)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Backward-compatibility re-export: the implementation lives in comfy.rmsnorm,
# but existing code imports rms_norm from this module.
rms_norm = comfy.rmsnorm.rms_norm
|
ComfyUI/comfy/model_detection.py
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import comfy.supported_models
|
| 3 |
+
import comfy.supported_models_base
|
| 4 |
+
import comfy.utils
|
| 5 |
+
import math
|
| 6 |
+
import logging
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def count_blocks(state_dict_keys, prefix_string):
    """Count how many consecutively-numbered blocks exist in a state dict.

    Args:
        state_dict_keys: iterable of state dict key strings.
        prefix_string: format string with one "{}" placeholder for the block
            index, e.g. "input_blocks.{}.".

    Returns:
        The first index n (starting from 0) for which no key starts with
        ``prefix_string.format(n)`` — i.e. the number of numbered blocks.
    """
    count = 0
    # Idiomatic replacement for the original manual flag loop: probe indices
    # upward until one has no matching key.
    while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
        count += 1
    return count
|
| 21 |
+
|
| 22 |
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
    """Probe a UNet block's transformer sub-keys under ``prefix``.

    Returns a tuple (transformer_depth, context_dim, use_linear_in_transformer,
    time_stack, time_stack_cross) when transformer blocks are present, or
    None when this block has no transformer.
    """
    transformer_prefix = prefix + "1.transformer_blocks."
    transformer_keys = sorted(k for k in state_dict_keys if k.startswith(transformer_prefix))
    if not transformer_keys:
        return None

    depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
    # Cross-attention key projection input width gives the context dim.
    context_dim = state_dict[transformer_prefix + '0.attn2.to_k.weight'].shape[1]
    # A 2-D proj_in weight means a Linear projection rather than a 1x1 conv.
    use_linear = len(state_dict[prefix + '1.proj_in.weight'].shape) == 2
    # Temporal (video) attention layers may appear under either name.
    has_time_stack = (
        prefix + '1.time_stack.0.attn1.to_q.weight' in state_dict
        or prefix + '1.time_mix_blocks.0.attn1.to_q.weight' in state_dict
    )
    has_time_stack_cross = (
        prefix + '1.time_stack.0.attn2.to_q.weight' in state_dict
        or prefix + '1.time_mix_blocks.0.attn2.to_q.weight' in state_dict
    )
    return depth, context_dim, use_linear, has_time_stack, has_time_stack_cross
|
| 36 |
+
|
| 37 |
+
def detect_unet_config(state_dict, key_prefix, metadata=None):
    """Infer a model/unet config dict from the weights in a state dict.

    Each supported architecture is identified by a signature key that only it
    contains; checks run in a fixed order, so earlier, more specific probes
    must stay before later ones (e.g. the PixArt-diffusers bail-out before the
    ltxv check, which shares a key). Falls through to classic LDM/SGM UNet
    detection, and returns None for unrecognized layouts.

    Args:
        state_dict: model weights (keys -> tensors).
        key_prefix: prefix under which the model's keys live (e.g. "model.diffusion_model.").
        metadata: optional safetensors metadata; a "config" entry can override
            detected values (used by ltxv).
    """
    state_dict_keys = list(state_dict.keys())

    if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
        unet_config = {}
        unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
        patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
        unet_config["patch_size"] = patch_size
        final_layer = '{}final_layer.linear.weight'.format(key_prefix)
        if final_layer in state_dict:
            # final linear emits out_channels per patch pixel
            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)

        # mmdit convention: hidden size is depth * 64 -- TODO confirm for all variants
        unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
        unet_config["input_size"] = None
        y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
        if y_key in state_dict_keys:
            unet_config["adm_in_channels"] = state_dict[y_key].shape[1]

        context_key = '{}context_embedder.weight'.format(key_prefix)
        if context_key in state_dict_keys:
            in_features = state_dict[context_key].shape[1]
            out_features = state_dict[context_key].shape[0]
            unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
        num_patches_key = '{}pos_embed'.format(key_prefix)
        if num_patches_key in state_dict_keys:
            num_patches = state_dict[num_patches_key].shape[1]
            unet_config["num_patches"] = num_patches
            # positional grid assumed square; side length = sqrt(num patches)
            unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))

        rms_qk = '{}joint_blocks.0.context_block.attn.ln_q.weight'.format(key_prefix)
        if rms_qk in state_dict_keys:
            unet_config["qk_norm"] = "rms"

        unet_config["pos_embed_scaling_factor"] = None #unused for inference
        context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
        if context_processor in state_dict_keys:
            unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
        unet_config["x_block_self_attn_layers"] = []
        # collect indices of joint blocks that carry an extra self-attention (attn2)
        for key in state_dict_keys:
            if key.startswith('{}joint_blocks.'.format(key_prefix)) and key.endswith('.x_block.attn2.qkv.weight'):
                layer = key[len('{}joint_blocks.'.format(key_prefix)):-len('.x_block.attn2.qkv.weight')]
                unet_config["x_block_self_attn_layers"].append(int(layer))
        return unet_config

    if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
        unet_config = {}
        text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
        if text_mapper_name in state_dict_keys:
            unet_config['stable_cascade_stage'] = 'c'
            w = state_dict[text_mapper_name]
            if w.shape[0] == 1536: #stage c lite
                unet_config['c_cond'] = 1536
                unet_config['c_hidden'] = [1536, 1536]
                unet_config['nhead'] = [24, 24]
                unet_config['blocks'] = [[4, 12], [12, 4]]
            elif w.shape[0] == 2048: #stage c full
                unet_config['c_cond'] = 2048
        elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
            unet_config['stable_cascade_stage'] = 'b'
            w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
            if w.shape[-1] == 640:
                unet_config['c_hidden'] = [320, 640, 1280, 1280]
                unet_config['nhead'] = [-1, -1, 20, 20]
                unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
                unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
            elif w.shape[-1] == 576: #stage b lite
                unet_config['c_hidden'] = [320, 576, 1152, 1152]
                unet_config['nhead'] = [-1, 9, 18, 18]
                unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
                unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
        return unet_config

    if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
        unet_config = {}
        unet_config["audio_model"] = "dit1.0"
        return unet_config

    if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
        unet_config = {}
        unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
        unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
        double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
        single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
        unet_config["n_double_layers"] = double_layers
        unet_config["n_layers"] = double_layers + single_layers
        return unet_config

    if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
        unet_config = {}
        unet_config["image_model"] = "hydit"
        unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
        unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
        if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
            unet_config["mlp_ratio"] = 4.3637
        # wider extra_embedder input distinguishes the v1.x checkpoint layout
        if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
            unet_config["size_cond"] = True
            unet_config["use_style_cond"] = True
            unet_config["image_model"] = "hydit1"
        return unet_config

    if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
        dit_config = {}
        dit_config["image_model"] = "hunyuan_video"
        dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
        dit_config["patch_size"] = [1, 2, 2]
        dit_config["out_channels"] = 16
        dit_config["vec_in_dim"] = 768
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 24
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 256
        dit_config["qkv_bias"] = True
        # guidance-distilled checkpoints carry a guidance_in embedder
        guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
        dit_config["guidance_embed"] = len(guidance_keys) > 0
        return dit_config

    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
        dit_config = {}
        dit_config["image_model"] = "flux"
        dit_config["in_channels"] = 16
        patch_size = 2
        dit_config["patch_size"] = patch_size
        in_key = "{}img_in.weight".format(key_prefix)
        if in_key in state_dict_keys:
            # img_in consumes patch_size^2 pixels per latent channel
            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
        dit_config["out_channels"] = 16
        vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
        if vec_in_key in state_dict_keys:
            dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
        dit_config["context_in_dim"] = 4096
        dit_config["hidden_size"] = 3072
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 24
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["axes_dim"] = [16, 56, 56]
        dit_config["theta"] = 10000
        dit_config["qkv_bias"] = True
        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
            dit_config["image_model"] = "chroma"
            dit_config["in_channels"] = 64
            dit_config["out_channels"] = 64
            dit_config["in_dim"] = 64
            dit_config["out_dim"] = 3072
            dit_config["hidden_dim"] = 5120
            dit_config["n_layers"] = 5
        else:
            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
        dit_config = {}
        dit_config["image_model"] = "mochi_preview"
        dit_config["depth"] = 48
        dit_config["patch_size"] = 2
        dit_config["num_heads"] = 24
        dit_config["hidden_size_x"] = 3072
        dit_config["hidden_size_y"] = 1536
        dit_config["mlp_ratio_x"] = 4.0
        dit_config["mlp_ratio_y"] = 4.0
        dit_config["learn_sigma"] = False
        dit_config["in_channels"] = 12
        dit_config["qk_norm"] = True
        dit_config["qkv_bias"] = False
        dit_config["out_bias"] = True
        dit_config["attn_drop"] = 0.0
        dit_config["patch_embed_bias"] = True
        dit_config["posenc_preserve_area"] = True
        dit_config["timestep_mlp_bias"] = True
        dit_config["attend_to_padding"] = False
        dit_config["timestep_scale"] = 1000.0
        dit_config["use_t5"] = True
        dit_config["t5_feat_dim"] = 4096
        dit_config["t5_token_length"] = 256
        dit_config["rope_theta"] = 10000.0
        return dit_config

    # NOTE: must stay before the ltxv check below -- both share the
    # adaln_single timestep key; pos_embed.proj.bias marks the diffusers layout.
    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
        # PixArt diffusers
        return None

    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
        dit_config = {}
        dit_config["image_model"] = "ltxv"
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
        shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
        dit_config["attention_head_dim"] = shape[0] // 32
        dit_config["cross_attention_dim"] = shape[1]
        # safetensors metadata may ship an explicit transformer config; it wins
        if metadata is not None and "config" in metadata:
            dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
        return dit_config

    if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
        dit_config = {}
        dit_config["audio_model"] = "ace"
        dit_config["attention_head_dim"] = 128
        dit_config["in_channels"] = 8
        dit_config["inner_dim"] = 2560
        dit_config["max_height"] = 16
        dit_config["max_position"] = 32768
        dit_config["max_width"] = 32768
        dit_config["mlp_ratio"] = 2.5
        dit_config["num_attention_heads"] = 20
        dit_config["num_layers"] = 24
        dit_config["out_channels"] = 8
        dit_config["patch_size"] = [16, 1]
        dit_config["rope_theta"] = 1000000.0
        dit_config["speaker_embedding_dim"] = 512
        dit_config["text_embedding_dim"] = 768

        dit_config["ssl_encoder_depths"] = [8, 8]
        dit_config["ssl_latent_dims"] = [1024, 768]
        dit_config["ssl_names"] = ["mert", "m-hubert"]
        dit_config["lyric_encoder_vocab_size"] = 6693
        dit_config["lyric_hidden_size"] = 1024
        return dit_config

    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
        patch_size = 2
        dit_config = {}
        dit_config["num_heads"] = 16
        dit_config["patch_size"] = patch_size
        dit_config["hidden_size"] = 1152
        dit_config["in_channels"] = 4
        dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

        y_key = "{}y_embedder.y_embedding".format(key_prefix)
        if y_key in state_dict_keys:
            dit_config["model_max_length"] = state_dict[y_key].shape[0]

        pe_key = "{}pos_embed".format(key_prefix)
        if pe_key in state_dict_keys:
            dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
            dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess

        # aspect-ratio embedder only exists in PixArt-alpha checkpoints
        ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
        if ar_key in state_dict_keys:
            dit_config["image_model"] = "pixart_alpha"
            dit_config["micro_condition"] = True
        else:
            dit_config["image_model"] = "pixart_sigma"
            dit_config["micro_condition"] = False
        return dit_config

    if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
        dit_config = {}
        dit_config["image_model"] = "cosmos"
        dit_config["max_img_h"] = 240
        dit_config["max_img_w"] = 240
        dit_config["max_frames"] = 128
        concat_padding_mask = True
        # x_embedder input includes the padding-mask channel; subtract it back out
        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
        dit_config["out_channels"] = 16
        dit_config["patch_spatial"] = 2
        dit_config["patch_temporal"] = 1
        dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
        dit_config["block_config"] = "FA-CA-MLP"
        dit_config["concat_padding_mask"] = concat_padding_mask
        dit_config["pos_emb_cls"] = "rope3d"
        dit_config["pos_emb_learnable"] = False
        dit_config["pos_emb_interpolation"] = "crop"
        dit_config["block_x_format"] = "THWBD"
        dit_config["affline_emb_norm"] = True
        dit_config["use_adaln_lora"] = True
        dit_config["adaln_lora_dim"] = 256

        if dit_config["model_channels"] == 4096:
            # 7B
            dit_config["num_blocks"] = 28
            dit_config["num_heads"] = 32
            dit_config["extra_per_block_abs_pos_emb"] = True
            dit_config["rope_h_extrapolation_ratio"] = 1.0
            dit_config["rope_w_extrapolation_ratio"] = 1.0
            dit_config["rope_t_extrapolation_ratio"] = 2.0
            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        else: # 5120
            # 14B
            dit_config["num_blocks"] = 36
            dit_config["num_heads"] = 40
            dit_config["extra_per_block_abs_pos_emb"] = True
            dit_config["rope_h_extrapolation_ratio"] = 2.0
            dit_config["rope_w_extrapolation_ratio"] = 2.0
            dit_config["rope_t_extrapolation_ratio"] = 2.0
            dit_config["extra_h_extrapolation_ratio"] = 2.0
            dit_config["extra_w_extrapolation_ratio"] = 2.0
            dit_config["extra_t_extrapolation_ratio"] = 2.0
            dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
        return dit_config

    if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
        dit_config = {}
        dit_config["image_model"] = "lumina2"
        dit_config["patch_size"] = 2
        dit_config["in_channels"] = 16
        dit_config["dim"] = 2304
        dit_config["cap_feat_dim"] = 2304
        dit_config["n_layers"] = 26
        dit_config["n_heads"] = 24
        dit_config["n_kv_heads"] = 8
        dit_config["qk_norm"] = True
        dit_config["axes_dims"] = [32, 32, 32]
        dit_config["axes_lens"] = [300, 512, 512]
        return dit_config

    if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
        dit_config = {}
        dit_config["image_model"] = "wan2.1"
        dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
        out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
        dit_config["dim"] = dim
        dit_config["out_dim"] = out_dim
        # Wan uses a fixed head dim of 128
        dit_config["num_heads"] = dim // 128
        dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
        dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
        dit_config["patch_size"] = (1, 2, 2)
        dit_config["freq_dim"] = 256
        dit_config["window_size"] = (-1, -1)
        dit_config["qk_norm"] = True
        dit_config["cross_attn_norm"] = True
        dit_config["eps"] = 1e-6
        dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
        # variant dispatch: VACE > camera-control > image/text-to-video
        if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "vace"
            dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
            dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
        elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
            dit_config["model_type"] = "camera"
        else:
            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                dit_config["model_type"] = "i2v"
            else:
                dit_config["model_type"] = "t2v"
            # first-last-frame variant stores a learned positional embedding
            flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
            if flf_weight is not None:
                dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
        return dit_config

    if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
        in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
        dit_config = {}
        dit_config["image_model"] = "hunyuan3d2"
        dit_config["in_channels"] = in_shape[1]
        dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
        dit_config["hidden_size"] = in_shape[0]
        dit_config["mlp_ratio"] = 4.0
        dit_config["num_heads"] = 16
        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
        dit_config["qkv_bias"] = True
        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        return dit_config

    if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
        dit_config = {}
        dit_config["image_model"] = "hidream"
        dit_config["attention_head_dim"] = 128
        dit_config["axes_dims_rope"] = [64, 32, 32]
        dit_config["caption_channels"] = [4096, 4096]
        dit_config["max_resolution"] = [128, 128]
        dit_config["in_channels"] = 16
        dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
        dit_config["num_attention_heads"] = 20
        dit_config["num_routed_experts"] = 4
        dit_config["num_activated_experts"] = 2
        dit_config["num_layers"] = 16
        dit_config["num_single_layers"] = 32
        dit_config["out_channels"] = 16
        dit_config["patch_size"] = 2
        dit_config["text_emb_dim"] = 2048
        return dit_config

    if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
        dit_config = {}
        dit_config["image_model"] = "cosmos_predict2"
        dit_config["max_img_h"] = 240
        dit_config["max_img_w"] = 240
        dit_config["max_frames"] = 128
        concat_padding_mask = True
        dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
        dit_config["out_channels"] = 16
        dit_config["patch_spatial"] = 2
        dit_config["patch_temporal"] = 1
        dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
        dit_config["concat_padding_mask"] = concat_padding_mask
        dit_config["crossattn_emb_channels"] = 1024
        dit_config["pos_emb_cls"] = "rope3d"
        dit_config["pos_emb_learnable"] = True
        dit_config["pos_emb_interpolation"] = "crop"
        dit_config["min_fps"] = 1
        dit_config["max_fps"] = 30

        dit_config["use_adaln_lora"] = True
        dit_config["adaln_lora_dim"] = 256
        # model size from channel width (2B vs 14B)
        if dit_config["model_channels"] == 2048:
            dit_config["num_blocks"] = 28
            dit_config["num_heads"] = 16
        elif dit_config["model_channels"] == 5120:
            dit_config["num_blocks"] = 36
            dit_config["num_heads"] = 40

        # rope ratios depend on text-to-video (16ch) vs image-to-video (17ch)
        if dit_config["in_channels"] == 16:
            dit_config["extra_per_block_abs_pos_emb"] = False
            dit_config["rope_h_extrapolation_ratio"] = 4.0
            dit_config["rope_w_extrapolation_ratio"] = 4.0
            dit_config["rope_t_extrapolation_ratio"] = 1.0
        elif dit_config["in_channels"] == 17: # img to video
            if dit_config["model_channels"] == 2048:
                dit_config["extra_per_block_abs_pos_emb"] = False
                dit_config["rope_h_extrapolation_ratio"] = 3.0
                dit_config["rope_w_extrapolation_ratio"] = 3.0
                dit_config["rope_t_extrapolation_ratio"] = 1.0
            elif dit_config["model_channels"] == 5120:
                dit_config["rope_h_extrapolation_ratio"] = 2.0
                dit_config["rope_w_extrapolation_ratio"] = 2.0
                dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334

        dit_config["extra_h_extrapolation_ratio"] = 1.0
        dit_config["extra_w_extrapolation_ratio"] = 1.0
        dit_config["extra_t_extrapolation_ratio"] = 1.0
        dit_config["rope_enable_fps_modulation"] = False

        return dit_config

    if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
        dit_config = {}
        dit_config["image_model"] = "omnigen2"
        dit_config["axes_dim_rope"] = [40, 40, 40]
        dit_config["axes_lens"] = [1024, 1664, 1664]
        dit_config["ffn_dim_multiplier"] = None
        dit_config["hidden_size"] = 2520
        dit_config["in_channels"] = 16
        dit_config["multiple_of"] = 256
        dit_config["norm_eps"] = 1e-05
        dit_config["num_attention_heads"] = 21
        dit_config["num_kv_heads"] = 7
        dit_config["num_layers"] = 32
        dit_config["num_refiner_layers"] = 2
        dit_config["out_channels"] = None
        dit_config["patch_size"] = 2
        dit_config["text_feat_dim"] = 2048
        dit_config["timestep_scale"] = 1000.0
        return dit_config

    # Fallback: classic LDM/SGM UNet layout; bail out if the first conv is absent.
    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None

    unet_config = {
        "use_checkpoint": False,
        "image_size": 32,
        "use_spatial_transformer": True,
        "legacy": False
    }

    y_input = '{}label_emb.0.0.weight'.format(key_prefix)
    if y_input in state_dict_keys:
        unet_config["num_classes"] = "sequential"
        unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
    else:
        unet_config["adm_in_channels"] = None

    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
    in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]

    out_key = '{}out.2.weight'.format(key_prefix)
    if out_key in state_dict:
        out_channels = state_dict[out_key].shape[0]
    else:
        out_channels = 4

    num_res_blocks = []
    channel_mult = []
    transformer_depth = []
    transformer_depth_output = []
    context_dim = None
    use_linear_in_transformer = False

    video_model = False
    video_model_cross = False

    current_res = 1
    count = 0

    last_res_blocks = 0
    last_channel_mult = 0

    # Walk input blocks in order; output blocks mirror them in reverse,
    # so both sides are inspected in the same pass.
    input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
    for count in range(input_block_count):
        prefix = '{}input_blocks.{}.'.format(key_prefix, count)
        prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)

        block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
        if len(block_keys) == 0:
            break

        block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))

        if "{}0.op.weight".format(prefix) in block_keys: #new layer
            # downsample op marks a resolution boundary: flush the counters
            num_res_blocks.append(last_res_blocks)
            channel_mult.append(last_channel_mult)

            current_res *= 2
            last_res_blocks = 0
            last_channel_mult = 0
            out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
            if out is not None:
                transformer_depth_output.append(out[0])
            else:
                transformer_depth_output.append(0)
        else:
            res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
            if res_block_prefix in block_keys:
                last_res_blocks += 1
                last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels

                out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
                if out is not None:
                    transformer_depth.append(out[0])
                    # first transformer seen determines global attention params
                    if context_dim is None:
                        context_dim = out[1]
                        use_linear_in_transformer = out[2]
                        video_model = out[3]
                        video_model_cross = out[4]
                else:
                    transformer_depth.append(0)

            res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
            if res_block_prefix in block_keys_output:
                out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
                if out is not None:
                    transformer_depth_output.append(out[0])
                else:
                    transformer_depth_output.append(0)


    num_res_blocks.append(last_res_blocks)
    channel_mult.append(last_channel_mult)
    # middle block: transformer (count its depth), plain resblock (-1), or absent (-2)
    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
        transformer_depth_middle = -1
    else:
        transformer_depth_middle = -2

    unet_config["in_channels"] = in_channels
    unet_config["out_channels"] = out_channels
    unet_config["model_channels"] = model_channels
    unet_config["num_res_blocks"] = num_res_blocks
    unet_config["transformer_depth"] = transformer_depth
    unet_config["transformer_depth_output"] = transformer_depth_output
    unet_config["channel_mult"] = channel_mult
    unet_config["transformer_depth_middle"] = transformer_depth_middle
    unet_config['use_linear_in_transformer'] = use_linear_in_transformer
    unet_config["context_dim"] = context_dim

    if video_model:
        # SVD-style temporal UNet options
        unet_config["extra_ff_mix_layer"] = True
        unet_config["use_spatial_context"] = True
        unet_config["merge_strategy"] = "learned_with_images"
        unet_config["merge_factor"] = 0.0
        unet_config["video_kernel_size"] = [3, 1, 1]
        unet_config["use_temporal_resblock"] = True
        unet_config["use_temporal_attention"] = True
        unet_config["disable_temporal_crossattention"] = not video_model_cross
    else:
        unet_config["use_temporal_resblock"] = False
        unet_config["use_temporal_attention"] = False

    return unet_config
|
| 609 |
+
|
| 610 |
+
def model_config_from_unet_config(unet_config, state_dict=None):
    """Find and instantiate the supported model config matching unet_config.

    Returns the instantiated config object, or None (after logging an error)
    when nothing in comfy.supported_models.models matches.
    """
    matched = next(
        (cfg for cfg in comfy.supported_models.models if cfg.matches(unet_config, state_dict)),
        None,
    )
    if matched is not None:
        return matched(unet_config)

    logging.error("no match {}".format(unet_config))
    return None
|
| 617 |
+
|
| 618 |
+
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
    """Detect and build the model config for the weights in state_dict.

    Args:
        state_dict: model weights; the scaled-fp8 marker tensor is popped from
            it as a side effect when present and a config was found.
        unet_key_prefix: prefix under which the diffusion model keys live.
        use_base_if_no_match: fall back to the generic BASE config when no
            supported model matches.
        metadata: optional safetensors metadata forwarded to detection.

    Returns:
        A model config object, or None when the architecture is unknown
        (and no fallback was requested).
    """
    unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
    if unet_config is None:
        return None
    model_config = model_config_from_unet_config(unet_config, state_dict)
    if model_config is None and use_base_if_no_match:
        model_config = comfy.supported_models_base.BASE(unet_config)

    # Fix: original code set attributes on model_config unconditionally here,
    # raising AttributeError when no config matched (model_config is None)
    # but the checkpoint still carried a scaled_fp8 marker.
    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
    if scaled_fp8_key in state_dict and model_config is not None:
        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
        model_config.scaled_fp8 = scaled_fp8_weight.dtype
        if model_config.scaled_fp8 == torch.float32:
            # float32 marker means the default scaled fp8 format
            model_config.scaled_fp8 = torch.float8_e4m3fn
        # a 2-element marker tensor disables the fp8 optimization path
        model_config.optimizations["fp8"] = scaled_fp8_weight.nelement() != 2

    return model_config
|
| 638 |
+
|
| 639 |
+
def unet_prefix_from_state_dict(state_dict):
    """Guess the key prefix under which the diffusion model lives in *state_dict*.

    Counts keys starting with each known candidate prefix; a candidate needs
    more than 5 hits to win, otherwise the generic "model." prefix is returned.
    """
    candidates = ("model.diffusion_model.",  # ldm/sgm models
                  "model.model.",            # audio models
                  "net.",                    # cosmos
                  )
    counts = dict.fromkeys(candidates, 0)
    for key in state_dict:
        for prefix in candidates:
            if key.startswith(prefix):
                counts[prefix] += 1
                break

    best = max(counts, key=counts.get)
    if counts[best] > 5:
        return best
    return "model."  # aura flow and others
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def convert_config(unet_config):
    """Normalize a legacy ldm unet config into the per-block form.

    Expands an int num_res_blocks into a per-level list and converts
    "attention_resolutions" + transformer_depth into explicit per-block
    transformer_depth / transformer_depth_output / transformer_depth_middle
    lists. Returns a new dict; the input is not modified.
    """
    new_config = unet_config.copy()
    num_res_blocks = new_config.get("num_res_blocks", None)
    channel_mult = new_config.get("channel_mult", None)

    if isinstance(num_res_blocks, int):
        # one entry per resolution level
        num_res_blocks = [num_res_blocks] * len(channel_mult)

    if "attention_resolutions" in new_config:
        attention_resolutions = new_config.pop("attention_resolutions")
        transformer_depth = new_config.get("transformer_depth", None)
        transformer_depth_middle = new_config.get("transformer_depth_middle", None)

        if isinstance(transformer_depth, int):
            transformer_depth = [transformer_depth] * len(channel_mult)
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1]

        depths_in = []
        depths_out = []
        downsample = 1  # current downsample factor, doubled per level
        for level, res_count in enumerate(num_res_blocks):
            level_depth = transformer_depth[level] if downsample in attention_resolutions else 0
            # input side has one entry per res block; output side one extra (upsample block)
            depths_in.extend([level_depth] * res_count)
            depths_out.extend([level_depth] * (res_count + 1))
            downsample *= 2

        new_config["transformer_depth"] = depths_in
        new_config["transformer_depth_output"] = depths_out
        new_config["transformer_depth_middle"] = transformer_depth_middle

    new_config["num_res_blocks"] = num_res_blocks
    return new_config
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def unet_config_from_diffusers_unet(state_dict, dtype=None):
    # Fingerprint a diffusers-format SD unet state_dict and return the matching
    # known unet config (run through convert_config), or None if unrecognized.
    if "conv_in.weight" not in state_dict:
        return None

    match = {}
    transformer_depth = []

    attn_res = 1
    down_blocks = count_blocks(state_dict, "down_blocks.{}")
    for i in range(down_blocks):
        attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
        res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
        for ab in range(attn_blocks):
            transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
            transformer_depth.append(transformer_count)
            if transformer_count > 0:
                # cross-attention key projection input dim == context dim
                match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]

        attn_res *= 2
        if attn_blocks == 0:
            # NOTE(review): this inner loop shadows the outer `i`; harmless here
            # because `i` is reassigned at the top of the next outer iteration.
            for i in range(res_blocks):
                transformer_depth.append(0)

    match["transformer_depth"] = transformer_depth

    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
    match["in_channels"] = state_dict["conv_in.weight"].shape[1]
    match["adm_in_channels"] = None
    if "class_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
    elif "add_embedding.linear_1.weight" in state_dict:
        match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]

    # Known model fingerprints. Each dict below is a full unet config; the keys
    # collected in `match` above are compared against the same keys here.
    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
            'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
            'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
                    'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
            'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
            'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
                    'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
            'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
            'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
            'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                     'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
                     'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
                     'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                       'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                       'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
                       'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
                       'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                              'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                           'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
                           'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
                           'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
                           'use_temporal_attention': False, 'use_temporal_resblock': False}

    SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
              'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
              'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
              'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                    'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
                    'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                    'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                  'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                  'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}

    KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
                'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
                'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}

    SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
             'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
             'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
             'use_temporal_attention': False, 'use_temporal_resblock': False}

    SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
                              'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
                              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
                              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                              'use_temporal_attention': False, 'use_temporal_resblock': False}

    LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
              'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
              'use_temporal_attention': False, 'use_temporal_resblock': False}

    supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

    # First config whose values agree with the fingerprint wins; order matters.
    for unet_config in supported_models:
        matches = True
        for k in match:
            if match[k] != unet_config[k]:
                matches = False
                break
        if matches:
            return convert_config(unet_config)
    return None
|
| 848 |
+
|
| 849 |
+
def model_config_from_diffusers_unet(state_dict):
    """Detect a diffusers-format unet in *state_dict* and return its model config, or None."""
    detected = unet_config_from_diffusers_unet(state_dict)
    if detected is None:
        return None
    return model_config_from_unet_config(detected)
|
| 854 |
+
|
| 855 |
+
def convert_diffusers_mmdit(state_dict, output_prefix=""):
    # Convert a diffusers-format MMDiT-family state_dict (AuraFlow, PixArt,
    # Flux, SD3 -- detected by marker keys) into the native key layout.
    # Returns the converted dict, or None when no architecture is recognized.
    # Consumed keys are popped from state_dict in place.
    out_sd = {}

    if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
    elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
    elif 'x_embedder.weight' in state_dict: #Flux
        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
        hidden_size = state_dict["x_embedder.bias"].shape[0]
        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
    else:
        return None

    for k in sd_map:
        weight = state_dict.get(k, None)
        if weight is not None:
            t = sd_map[k]

            if not isinstance(t, str):
                # t is a tuple (target_key, offset, [transform]): the source
                # tensor maps into a slice of a possibly larger destination
                # tensor (e.g. fused qkv weights assembled from q/k/v parts).
                if len(t) > 2:
                    fun = t[2]
                else:
                    fun = lambda a: a
                offset = t[1]  # (dim, start, length) or None
                if offset is not None:
                    old_weight = out_sd.get(t[0], None)
                    if old_weight is None:
                        old_weight = torch.empty_like(weight)
                    if old_weight.shape[offset[0]] < offset[1] + offset[2]:
                        # grow the destination so the slice fits
                        exp = list(weight.shape)
                        exp[offset[0]] = offset[1] + offset[2]
                        new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
                        # NOTE(review): copies existing data along dim 0 only;
                        # presumably offset[0] is always 0 in sd_map -- confirm.
                        new[:old_weight.shape[0]] = old_weight
                        old_weight = new

                    w = old_weight.narrow(offset[0], offset[1], offset[2])
                else:
                    old_weight = weight
                    w = weight
                w[:] = fun(weight)
                t = t[0]
                out_sd[t] = old_weight
            else:
                # simple rename
                out_sd[t] = weight
            state_dict.pop(k)

    return out_sd
|
ComfyUI/comfy/model_patcher.py
ADDED
|
@@ -0,0 +1,1215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Comfy
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import collections
|
| 22 |
+
import copy
|
| 23 |
+
import inspect
|
| 24 |
+
import logging
|
| 25 |
+
import math
|
| 26 |
+
import uuid
|
| 27 |
+
from typing import Callable, Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
import comfy.float
|
| 32 |
+
import comfy.hooks
|
| 33 |
+
import comfy.lora
|
| 34 |
+
import comfy.model_management
|
| 35 |
+
import comfy.patcher_extension
|
| 36 |
+
import comfy.utils
|
| 37 |
+
from comfy.comfy_types import UnetWrapperFunction
|
| 38 |
+
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def string_to_seed(data):
    """Hash *data* (a str, or any iterable of ints/chars) into a deterministic
    32-bit seed using the standard CRC-32 (reflected polynomial 0xEDB88320)."""
    POLY = 0xEDB88320
    crc = 0xFFFFFFFF
    for item in data:
        value = ord(item) if isinstance(item, str) else item
        crc ^= value
        for _ in range(8):
            low_bit = crc & 1
            crc >>= 1
            if low_bit:
                crc ^= POLY
    return crc ^ 0xFFFFFFFF
|
| 53 |
+
|
| 54 |
+
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    """Register *patch* as the replacement for block (block_name, number[,
    transformer_index]) under *name* in transformer_options["patches_replace"].

    The nested dicts are shallow-copied before modification so other
    model_options sharing the same structure are not affected. Returns the
    (mutated) model_options.
    """
    transformer_options = model_options["transformer_options"].copy()
    replacements = transformer_options.get("patches_replace", {}).copy()
    per_name = replacements.get(name, {}).copy()

    if transformer_index is None:
        block_key = (block_name, number)
    else:
        block_key = (block_name, number, transformer_index)
    per_name[block_key] = patch

    replacements[name] = per_name
    transformer_options["patches_replace"] = replacements
    model_options["transformer_options"] = transformer_options
    return model_options
|
| 74 |
+
|
| 75 |
+
def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    """Append *post_cfg_function* to the sampler post-CFG function list,
    optionally disabling the cfg==1 optimization. Returns model_options."""
    existing = model_options.get("sampler_post_cfg_function", [])
    model_options["sampler_post_cfg_function"] = existing + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
|
| 80 |
+
|
| 81 |
+
def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    """Append *pre_cfg_function* to the sampler pre-CFG function list,
    optionally disabling the cfg==1 optimization. Returns model_options."""
    existing = model_options.get("sampler_pre_cfg_function", [])
    model_options["sampler_pre_cfg_function"] = existing + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options
|
| 86 |
+
|
| 87 |
+
def create_model_options_clone(orig_model_options: dict):
    # Clone model_options by copying the nested dict structure (leaf values are
    # shared) so the clone can be patched without mutating the original.
    return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
| 89 |
+
|
| 90 |
+
def create_hook_patches_clone(orig_hook_patches):
    """Clone a {hook_ref: {key: [patches]}} mapping.

    The per-key patch lists are shallow-copied so appending to the clone does
    not affect the original; the patch objects themselves are shared.
    """
    return {
        hook_ref: {key: patch_list[:] for key, patch_list in hook_map.items()}
        for hook_ref, hook_map in orig_hook_patches.items()
    }
|
| 97 |
+
|
| 98 |
+
def wipe_lowvram_weight(m):
    """Undo lowvram patching on module *m*: restore the saved cast-weights flag
    and clear any attached weight/bias patch function lists."""
    if hasattr(m, "prev_comfy_cast_weights"):
        m.comfy_cast_weights = m.prev_comfy_cast_weights
        del m.prev_comfy_cast_weights

    for attr in ("weight_function", "bias_function"):
        if hasattr(m, attr):
            setattr(m, attr, [])
|
| 108 |
+
|
| 109 |
+
def move_weight_functions(m, device):
    """Move any movable weight/bias patch functions on module *m* to *device*.

    Returns the accumulated memory reported by each function's move_to();
    returns 0 immediately when device is None.
    """
    if device is None:
        return 0

    total = 0
    for attr in ("weight_function", "bias_function"):
        for fn in getattr(m, attr, []):
            if hasattr(fn, "move_to"):
                total += fn.move_to(device=device)
    return total
|
| 124 |
+
|
| 125 |
+
class LowVramPatch:
    """Callable that applies the LoRA-style patches for one weight on the fly
    (lowvram mode keeps base weights unpatched and patches at cast time)."""
    def __init__(self, key, patches):
        self.key = key          # state-dict key of the weight being patched
        self.patches = patches  # mapping: key -> list of patches for that weight

    def __call__(self, weight):
        intermediate_dtype = weight.dtype
        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
            intermediate_dtype = torch.float32
            # Fix: this return belongs inside the branch; as previously laid out
            # it ran unconditionally, making the plain-dtype return unreachable.
            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))

        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
| 136 |
+
|
| 137 |
+
def get_key_weight(model, key):
    """Resolve *key* (a dot path into *model*) to its weight.

    Returns ``(weight, set_func, convert_func)`` where ``set_func`` /
    ``convert_func`` are optional ``set_<name>`` / ``convert_<name>`` hooks
    exposed by the module that owns the parameter; either may be None.
    """
    set_func = None
    convert_func = None
    op_keys = key.rsplit('.', 1)
    if len(op_keys) < 2:
        # Key has no owning module component: plain attribute lookup.
        weight = comfy.utils.get_attr(model, key)
    else:
        op = comfy.utils.get_attr(model, op_keys[0])
        try:
            set_func = getattr(op, "set_{}".format(op_keys[1]))
        except AttributeError:
            pass

        try:
            convert_func = getattr(op, "convert_{}".format(op_keys[1]))
        except AttributeError:
            pass

        weight = getattr(op, op_keys[1])
        if convert_func is not None:
            # NOTE(review): when a convert hook exists the weight is
            # re-fetched via get_attr — presumably to obtain the raw stored
            # form rather than the module attribute; confirm against
            # comfy.utils.get_attr semantics.
            weight = comfy.utils.get_attr(model, key)

    return weight, set_func, convert_func
|
| 160 |
+
|
| 161 |
+
class AutoPatcherEjector:
    """Context manager that temporarily ejects a ModelPatcher's injections.

    On enter, ejects the model if it is currently injected; on exit the
    prior state is restored. With ``skip_and_inject_on_exit_only`` set,
    injection attempts inside the block are suppressed and a single
    ``inject_model`` happens on exit instead.
    """
    def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
        self.model = model
        self.was_injected = False           # whether __enter__ performed an eject
        self.prev_skip_injection = False    # model.skip_injection saved at entry
        self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only

    def __enter__(self):
        self.was_injected = False
        self.prev_skip_injection = self.model.skip_injection
        if self.skip_and_inject_on_exit_only:
            # Block any injection performed inside the with-body.
            self.model.skip_injection = True
        if self.model.is_injected:
            self.model.eject_model()
            self.was_injected = True

    def __exit__(self, *args):
        if self.skip_and_inject_on_exit_only:
            # Restore the flag *before* injecting so inject_model is not skipped.
            self.model.skip_injection = self.prev_skip_injection
            self.model.inject_model()
        if self.was_injected and not self.model.skip_injection:
            self.model.inject_model()
        self.model.skip_injection = self.prev_skip_injection
|
| 184 |
+
|
| 185 |
+
class MemoryCounter:
    """Tracks a dwindling memory budget (in bytes) with a hard floor."""
    def __init__(self, initial: int, minimum=0):
        self.value = initial
        self.minimum = minimum
        # TODO: add a safe limit besides 0

    def use(self, weight: torch.Tensor):
        """Reserve the byte size of *weight* if the budget allows it.

        Returns True when the reservation succeeded, False otherwise.
        """
        needed = weight.nelement() * weight.element_size()
        if not self.is_useable(needed):
            return False
        self.decrement(needed)
        return True

    def is_useable(self, used: int):
        """True when spending *used* bytes keeps the budget above the floor."""
        return self.value - used > self.minimum

    def decrement(self, used: int):
        """Spend *used* bytes from the budget."""
        self.value -= used
|
| 203 |
+
|
| 204 |
+
class ModelPatcher:
|
| 205 |
+
    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
        """Wrap *model* with patch, hook and load/offload management.

        Args:
            model: the torch module to manage.
            load_device: device weights are moved to for inference.
            offload_device: device weights rest on when unloaded.
            size: known model size in bytes; 0 means compute lazily.
            weight_inplace_update: patch weights in place instead of replacing them.
        """
        self.size = size
        self.model = model
        # Guarantee the wrapped model always carries a usable .device.
        if not hasattr(self.model, 'device'):
            logging.debug("Model doesn't have a device attribute.")
            self.model.device = offload_device
        elif self.model.device is None:
            self.model.device = offload_device

        self.patches = {}                   # key -> list of weight patches
        self.backup = {}                    # key -> pristine weight backup
        self.object_patches = {}
        self.object_patches_backup = {}
        self.weight_wrapper_patches = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device
        self.weight_inplace_update = weight_inplace_update
        self.force_cast_weights = False
        self.patches_uuid = uuid.uuid4()    # identifies the current patch set
        self.parent = None

        self.attachments: dict[str] = {}
        self.additional_models: dict[str, list[ModelPatcher]] = {}
        self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
        self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers()

        self.is_injected = False
        self.skip_injection = False
        self.injections: dict[str, list[PatcherInjection]] = {}

        self.hook_patches: dict[comfy.hooks._HookRef] = {}
        self.hook_patches_backup: dict[comfy.hooks._HookRef] = None
        self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {}
        self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {}
        self.current_hooks: Optional[comfy.hooks.HookGroup] = None
        self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time
        self.is_clip = False
        self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed

        # Shared bookkeeping lives on the model object itself so every clone
        # of this patcher observes the same load state.
        if not hasattr(self.model, 'model_loaded_weight_memory'):
            self.model.model_loaded_weight_memory = 0

        if not hasattr(self.model, 'lowvram_patch_counter'):
            self.model.lowvram_patch_counter = 0

        if not hasattr(self.model, 'model_lowvram'):
            self.model.model_lowvram = False

        if not hasattr(self.model, 'current_weight_patches_uuid'):
            self.model.current_weight_patches_uuid = None
|
| 257 |
+
|
| 258 |
+
    def model_size(self):
        """Return the model size in bytes, computing and caching it on first use."""
        if self.size > 0:
            return self.size
        self.size = comfy.model_management.module_size(self.model)
        return self.size
|
| 263 |
+
|
| 264 |
+
    def loaded_size(self):
        """Bytes of model weights currently loaded (tracked on the model)."""
        return self.model.model_loaded_weight_memory
|
| 266 |
+
|
| 267 |
+
    def lowvram_patch_counter(self):
        """Number of on-the-fly lowvram patches currently attached to the model."""
        return self.model.lowvram_patch_counter
|
| 269 |
+
|
| 270 |
+
    def clone(self):
        """Create a new patcher over the same underlying model.

        Patch lists, options, callbacks, wrappers and hooks are copied so the
        clone can diverge; backups are shared because they describe the single
        underlying model. ON_CLONE callbacks fire on the result.
        """
        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            # Shallow-copy each patch list so adds on the clone don't leak back.
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid

        n.object_patches = self.object_patches.copy()
        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        # Backups are intentionally shared, not copied.
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n.parent = self

        n.force_cast_weights = self.force_cast_weights

        # attachments
        n.attachments = {}
        for k in self.attachments:
            if hasattr(self.attachments[k], "on_model_patcher_clone"):
                n.attachments[k] = self.attachments[k].on_model_patcher_clone()
            else:
                n.attachments[k] = self.attachments[k]
        # additional models
        for k, c in self.additional_models.items():
            n.additional_models[k] = [x.clone() for x in c]
        # callbacks
        for k, c in self.callbacks.items():
            n.callbacks[k] = {}
            for k1, c1 in c.items():
                n.callbacks[k][k1] = c1.copy()
        # sample wrappers
        for k, w in self.wrappers.items():
            n.wrappers[k] = {}
            for k1, w1 in w.items():
                n.wrappers[k][k1] = w1.copy()
        # injection
        n.is_injected = self.is_injected
        n.skip_injection = self.skip_injection
        for k, i in self.injections.items():
            n.injections[k] = i.copy()
        # hooks
        n.hook_patches = create_hook_patches_clone(self.hook_patches)
        n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup
        for group in self.cached_hook_patches:
            n.cached_hook_patches[group] = {}
            for k in self.cached_hook_patches[group]:
                n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k]
        n.hook_backup = self.hook_backup
        n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks
        n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
        n.is_clip = self.is_clip
        n.hook_mode = self.hook_mode

        for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
            callback(self, n)
        return n
|
| 327 |
+
|
| 328 |
+
def is_clone(self, other):
|
| 329 |
+
if hasattr(other, 'model') and self.model is other.model:
|
| 330 |
+
return True
|
| 331 |
+
return False
|
| 332 |
+
|
| 333 |
+
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
| 334 |
+
if not self.is_clone(clone):
|
| 335 |
+
return False
|
| 336 |
+
|
| 337 |
+
if self.current_hooks != clone.current_hooks:
|
| 338 |
+
return False
|
| 339 |
+
if self.forced_hooks != clone.forced_hooks:
|
| 340 |
+
return False
|
| 341 |
+
if self.hook_patches.keys() != clone.hook_patches.keys():
|
| 342 |
+
return False
|
| 343 |
+
if self.attachments.keys() != clone.attachments.keys():
|
| 344 |
+
return False
|
| 345 |
+
if self.additional_models.keys() != clone.additional_models.keys():
|
| 346 |
+
return False
|
| 347 |
+
for key in self.callbacks:
|
| 348 |
+
if len(self.callbacks[key]) != len(clone.callbacks[key]):
|
| 349 |
+
return False
|
| 350 |
+
for key in self.wrappers:
|
| 351 |
+
if len(self.wrappers[key]) != len(clone.wrappers[key]):
|
| 352 |
+
return False
|
| 353 |
+
if self.injections.keys() != clone.injections.keys():
|
| 354 |
+
return False
|
| 355 |
+
|
| 356 |
+
if len(self.patches) == 0 and len(clone.patches) == 0:
|
| 357 |
+
return True
|
| 358 |
+
|
| 359 |
+
if self.patches_uuid == clone.patches_uuid:
|
| 360 |
+
if len(self.patches) != len(clone.patches):
|
| 361 |
+
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
|
| 362 |
+
else:
|
| 363 |
+
return True
|
| 364 |
+
|
| 365 |
+
    def memory_required(self, input_shape):
        """Delegate to the wrapped model's memory estimate for *input_shape*."""
        return self.model.memory_required(input_shape=input_shape)
|
| 367 |
+
|
| 368 |
+
    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        """Install a custom CFG combine function in model_options.

        Legacy 3-parameter callables (cond, uncond, cond_scale) are adapted
        to the newer single-dict-argument convention.
        """
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True
|
| 375 |
+
|
| 376 |
+
    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        """Register a function to run after CFG has been applied."""
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
| 378 |
+
|
| 379 |
+
    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        """Register a function to run before CFG is applied."""
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)
|
| 381 |
+
|
| 382 |
+
    def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function):
        """Override how conditioning batches are evaluated during sampling."""
        self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function
|
| 384 |
+
|
| 385 |
+
    def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
        """Wrap the diffusion model's forward call with *unet_wrapper_function*."""
        self.model_options["model_function_wrapper"] = unet_wrapper_function
|
| 387 |
+
|
| 388 |
+
    def set_model_denoise_mask_function(self, denoise_mask_function):
        """Install a callable that transforms the denoise mask each step."""
        self.model_options["denoise_mask_function"] = denoise_mask_function
|
| 390 |
+
|
| 391 |
+
    def set_model_patch(self, patch, name):
        """Append *patch* to transformer_options["patches"][*name*]."""
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]
|
| 396 |
+
|
| 397 |
+
    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        """Register a replacement patch for a specific transformer block."""
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)
|
| 399 |
+
|
| 400 |
+
    def set_model_attn1_patch(self, patch):
        """Add a self-attention (attn1) patch."""
        self.set_model_patch(patch, "attn1_patch")
|
| 402 |
+
|
| 403 |
+
    def set_model_attn2_patch(self, patch):
        """Add a cross-attention (attn2) patch."""
        self.set_model_patch(patch, "attn2_patch")
|
| 405 |
+
|
| 406 |
+
    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        """Replace a specific self-attention (attn1) block with *patch*."""
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
| 408 |
+
|
| 409 |
+
    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        """Replace a specific cross-attention (attn2) block with *patch*."""
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
| 411 |
+
|
| 412 |
+
    def set_model_attn1_output_patch(self, patch):
        """Add a patch applied to attn1 output."""
        self.set_model_patch(patch, "attn1_output_patch")
|
| 414 |
+
|
| 415 |
+
    def set_model_attn2_output_patch(self, patch):
        """Add a patch applied to attn2 output."""
        self.set_model_patch(patch, "attn2_output_patch")
|
| 417 |
+
|
| 418 |
+
    def set_model_input_block_patch(self, patch):
        """Add a patch applied to UNet input blocks."""
        self.set_model_patch(patch, "input_block_patch")
|
| 420 |
+
|
| 421 |
+
    def set_model_input_block_patch_after_skip(self, patch):
        """Add a patch applied to input blocks after the skip connection."""
        self.set_model_patch(patch, "input_block_patch_after_skip")
|
| 423 |
+
|
| 424 |
+
    def set_model_output_block_patch(self, patch):
        """Add a patch applied to UNet output blocks."""
        self.set_model_patch(patch, "output_block_patch")
|
| 426 |
+
|
| 427 |
+
    def set_model_emb_patch(self, patch):
        """Add a patch applied to the timestep embedding."""
        self.set_model_patch(patch, "emb_patch")
|
| 429 |
+
|
| 430 |
+
    def set_model_forward_timestep_embed_patch(self, patch):
        """Add a patch applied during forward_timestep_embed."""
        self.set_model_patch(patch, "forward_timestep_embed_patch")
|
| 432 |
+
|
| 433 |
+
    def add_object_patch(self, name, obj):
        """Schedule attribute path *name* to be replaced by *obj* at patch time."""
        self.object_patches[name] = obj
|
| 435 |
+
|
| 436 |
+
    def set_model_compute_dtype(self, dtype):
        """Override the model's manual cast dtype; forces weight casting when set."""
        self.add_object_patch("manual_cast_dtype", dtype)
        if dtype is not None:
            self.force_cast_weights = True
        # New uuid forces weights to be re-applied on next load.
        self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
|
| 441 |
+
|
| 442 |
+
    def add_weight_wrapper(self, name, function):
        """Append a wrapper *function* to run on weight *name* at cast time."""
        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
        self.patches_uuid = uuid.uuid4()
|
| 445 |
+
|
| 446 |
+
def get_model_object(self, name: str) -> torch.nn.Module:
|
| 447 |
+
"""Retrieves a nested attribute from an object using dot notation considering
|
| 448 |
+
object patches.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
The value of the requested attribute
|
| 455 |
+
|
| 456 |
+
Example:
|
| 457 |
+
patcher = ModelPatcher()
|
| 458 |
+
weight = patcher.get_model_object("layer1.conv.weight")
|
| 459 |
+
"""
|
| 460 |
+
if name in self.object_patches:
|
| 461 |
+
return self.object_patches[name]
|
| 462 |
+
else:
|
| 463 |
+
if name in self.object_patches_backup:
|
| 464 |
+
return self.object_patches_backup[name]
|
| 465 |
+
else:
|
| 466 |
+
return comfy.utils.get_attr(self.model, name)
|
| 467 |
+
|
| 468 |
+
    def model_patches_to(self, device):
        """Move movable patch callables and the model-function wrapper to *device*."""
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
| 488 |
+
|
| 489 |
+
    def model_dtype(self):
        """Return the model dtype via get_dtype(); None when unsupported."""
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()
|
| 492 |
+
|
| 493 |
+
    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Register weight patches (e.g. LoRA deltas) for matching model keys.

        *patches* maps a key — either a state-dict name, or a tuple
        ``(name, offset[, function])`` — to patch data. Only keys present in
        the model's state dict are accepted. Returns the list of matched keys.
        """
        with self.use_ejected():
            p = set()
            model_sd = self.model.state_dict()
            for k in patches:
                offset = None
                function = None
                if isinstance(k, str):
                    key = k
                else:
                    offset = k[1]
                    key = k[0]
                    if len(k) > 2:
                        function = k[2]

                if key in model_sd:
                    p.add(k)
                    current_patches = self.patches.get(key, [])
                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                    self.patches[key] = current_patches

            # New uuid marks the weight set as changed.
            self.patches_uuid = uuid.uuid4()
            return list(p)
|
| 516 |
+
|
| 517 |
+
    def get_key_patches(self, filter_prefix=None):
        """Return {key: [(base_weight, convert_func), *patches]} for each
        state-dict key (optionally restricted to *filter_prefix*).

        Backed-up (pre-patch) weights are preferred so already-applied
        patches are not double counted; hook backups take final precedence.
        """
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            bk = self.backup.get(k, None)
            hbk = self.hook_backup.get(k, None)
            weight, set_func, convert_func = get_key_weight(self.model, k)
            if bk is not None:
                weight = bk.weight
            if hbk is not None:
                weight = hbk[0]
            if convert_func is None:
                # Identity conversion keeps the returned tuples uniform.
                convert_func = lambda a, **kwargs: a

            if k in self.patches:
                p[k] = [(weight, convert_func)] + self.patches[k]
            else:
                p[k] = [(weight, convert_func)]
        return p
|
| 539 |
+
|
| 540 |
+
    def model_state_dict(self, filter_prefix=None):
        """Return the model state dict, optionally keeping only keys that
        start with *filter_prefix*."""
        with self.use_ejected():
            sd = self.model.state_dict()
            keys = list(sd.keys())
            if filter_prefix is not None:
                for k in keys:
                    if not k.startswith(filter_prefix):
                        sd.pop(k)
            return sd
|
| 549 |
+
|
| 550 |
+
    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        """Apply all registered patches for *key* to its weight.

        The pristine weight is backed up (on the offload device) the first
        time, the patch math runs in float32, and the result is written back
        — in place when requested or configured, via set_func when the owning
        op provides one.
        """
        if key not in self.patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        inplace_update = self.weight_inplace_update or inplace_update

        if key not in self.backup:
            # Keep the unpatched weight so unpatch_model can restore it.
            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
        if set_func is None:
            # Round back to the storage dtype with a deterministic per-key seed.
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            if inplace_update:
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                comfy.utils.set_attr_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
| 576 |
+
|
| 577 |
+
    def _load_list(self):
        """List loadable leaf modules as (size_bytes, name, module, param_names)."""
        loading = []
        for n, m in self.model.named_modules():
            params = []
            skip = False
            for name, param in m.named_parameters(recurse=False):
                params.append(name)
            for name, param in m.named_parameters(recurse=True):
                if name not in params:
                    skip = True # skip random weights in non leaf modules
                    break
            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                loading.append((comfy.model_management.module_size(m), n, m, params))
        return loading
|
| 591 |
+
|
| 592 |
+
    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        """Move weights to *device_to*, applying patches along the way.

        Biggest modules are considered first; modules are loaded fully until
        the *lowvram_model_memory* budget is hit, after which the remainder
        stays offloaded with on-the-fly (lowvram) patches attached instead.
        """
        with self.use_ejected():
            self.unpatch_hooks()
            mem_counter = 0
            patch_counter = 0
            lowvram_counter = 0
            loading = self._load_list()

            load_completely = []
            loading.sort(reverse=True)  # biggest modules first
            for x in loading:
                n = x[1]
                m = x[2]
                params = x[3]
                module_mem = x[0]

                lowvram_weight = False

                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)

                if not full_load and hasattr(m, "comfy_cast_weights"):
                    if mem_counter + module_mem >= lowvram_model_memory:
                        lowvram_weight = True
                        lowvram_counter += 1
                        if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                            continue

                cast_weight = self.force_cast_weights
                if lowvram_weight:
                    if hasattr(m, "comfy_cast_weights"):
                        m.weight_function = []
                        m.bias_function = []

                    if weight_key in self.patches:
                        if force_patch_weights:
                            self.patch_weight_to_device(weight_key)
                        else:
                            # Patch lazily at cast time instead of baking in.
                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
                            patch_counter += 1
                    if bias_key in self.patches:
                        if force_patch_weights:
                            self.patch_weight_to_device(bias_key)
                        else:
                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
                            patch_counter += 1

                    cast_weight = True
                else:
                    if hasattr(m, "comfy_cast_weights"):
                        wipe_lowvram_weight(m)

                    if full_load or mem_counter + module_mem < lowvram_model_memory:
                        mem_counter += module_mem
                        load_completely.append((module_mem, n, m, params))

                if cast_weight and hasattr(m, "comfy_cast_weights"):
                    m.prev_comfy_cast_weights = m.comfy_cast_weights
                    m.comfy_cast_weights = True

                if weight_key in self.weight_wrapper_patches:
                    m.weight_function.extend(self.weight_wrapper_patches[weight_key])

                if bias_key in self.weight_wrapper_patches:
                    m.bias_function.extend(self.weight_wrapper_patches[bias_key])

                mem_counter += move_weight_functions(m, device_to)

            load_completely.sort(reverse=True)
            for x in load_completely:
                n = x[1]
                m = x[2]
                params = x[3]
                if hasattr(m, "comfy_patched_weights"):
                    if m.comfy_patched_weights == True:
                        continue

                for param in params:
                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)

                logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                m.comfy_patched_weights = True

            for x in load_completely:
                x[2].to(device_to)

            if lowvram_counter > 0:
                logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
                self.model.model_lowvram = True
            else:
                logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                self.model.model_lowvram = False
                if full_load:
                    self.model.to(device_to)
                    mem_counter = self.model_size()

            self.model.lowvram_patch_counter += patch_counter
            self.model.device = device_to
            self.model.model_loaded_weight_memory = mem_counter
            self.model.current_weight_patches_uuid = self.patches_uuid

            for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
                callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)

            self.apply_hooks(self.forced_hooks, force_apply=True)
|
| 697 |
+
|
| 698 |
+
    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
        """Apply object patches and (optionally) load + patch weights.

        Returns the underlying model; a zero lowvram budget means full load.
        """
        with self.use_ejected():
            for k in self.object_patches:
                old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
                if k not in self.object_patches_backup:
                    self.object_patches_backup[k] = old

            if lowvram_model_memory == 0:
                full_load = True
            else:
                full_load = False

            if load_weights:
                self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
        self.inject_model()
        return self.model
|
| 714 |
+
|
| 715 |
+
    def unpatch_model(self, device_to=None, unpatch_weights=True):
        """Revert object patches and (optionally) weight patches.

        Restores backed-up weights, clears lowvram state and moves the model
        to *device_to* when given. Object patches are always reverted.
        """
        self.eject_model()
        if unpatch_weights:
            self.unpatch_hooks()
            if self.model.model_lowvram:
                for m in self.model.modules():
                    move_weight_functions(m, device_to)
                    wipe_lowvram_weight(m)

                self.model.model_lowvram = False
                self.model.lowvram_patch_counter = 0

            keys = list(self.backup.keys())

            for k in keys:
                bk = self.backup[k]
                if bk.inplace_update:
                    comfy.utils.copy_to_param(self.model, k, bk.weight)
                else:
                    comfy.utils.set_attr_param(self.model, k, bk.weight)

            self.model.current_weight_patches_uuid = None
            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.model.device = device_to
            self.model.model_loaded_weight_memory = 0

            for m in self.model.modules():
                if hasattr(m, "comfy_patched_weights"):
                    del m.comfy_patched_weights

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup.clear()
|
| 753 |
+
|
| 754 |
+
    def partially_unload(self, device_to, memory_to_free=0):
        """Free at least *memory_to_free* bytes by offloading modules.

        Smallest modules go first: their backed-up weights are restored, the
        module moves to *device_to*, and lowvram (on-the-fly) patching is
        re-attached so results stay correct. Returns bytes actually freed.
        """
        with self.use_ejected():
            hooks_unpatched = False
            memory_freed = 0
            patch_counter = 0
            unload_list = self._load_list()
            unload_list.sort()  # smallest modules first
            for unload in unload_list:
                if memory_to_free < memory_freed:
                    break
                module_mem = unload[0]
                n = unload[1]
                m = unload[2]
                params = unload[3]

                lowvram_possible = hasattr(m, "comfy_cast_weights")
                if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                    move_weight = True
                    for param in params:
                        key = "{}.{}".format(n, param)
                        bk = self.backup.get(key, None)
                        if bk is not None:
                            if not lowvram_possible:
                                # Cannot re-patch lazily; keep this module loaded.
                                move_weight = False
                                break

                            if not hooks_unpatched:
                                self.unpatch_hooks()
                                hooks_unpatched = True

                            if bk.inplace_update:
                                comfy.utils.copy_to_param(self.model, key, bk.weight)
                            else:
                                comfy.utils.set_attr_param(self.model, key, bk.weight)
                            self.backup.pop(key)

                    weight_key = "{}.weight".format(n)
                    bias_key = "{}.bias".format(n)
                    if move_weight:
                        cast_weight = self.force_cast_weights
                        m.to(device_to)
                        module_mem += move_weight_functions(m, device_to)
                        if lowvram_possible:
                            if weight_key in self.patches:
                                m.weight_function.append(LowVramPatch(weight_key, self.patches))
                                patch_counter += 1
                            if bias_key in self.patches:
                                m.bias_function.append(LowVramPatch(bias_key, self.patches))
                                patch_counter += 1
                            cast_weight = True

                        if cast_weight:
                            m.prev_comfy_cast_weights = m.comfy_cast_weights
                            m.comfy_cast_weights = True
                        m.comfy_patched_weights = False
                        memory_freed += module_mem
                        logging.debug("freed {}".format(n))

            self.model.model_lowvram = True
            self.model.lowvram_patch_counter += patch_counter
            self.model.model_loaded_weight_memory -= memory_freed
            return memory_freed
|
| 816 |
+
|
| 817 |
+
    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
        """Load up to *extra_memory* additional bytes of weights onto *device_to*.

        Returns the number of bytes actually added to the loaded set.
        """
        with self.use_ejected(skip_and_inject_on_exit_only=True):
            # Re-apply weights only when the on-model patch uuid disagrees
            # with ours (or a re-patch is forced).
            unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
            # TODO: force_patch_weights should not unload + reload full model
            used = self.model.model_loaded_weight_memory
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
            if unpatch_weights:
                # Memory reclaimed by unloading can be re-spent below.
                extra_memory += (used - self.model.model_loaded_weight_memory)

            self.patch_model(load_weights=False)
            full_load = False
            if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
                # Already fully loaded and patched; just make sure hooks apply.
                self.apply_hooks(self.forced_hooks, force_apply=True)
                return 0
            if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
                full_load = True
            current_used = self.model.model_loaded_weight_memory
            try:
                self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
            except Exception as e:
                # Leave the patcher detached rather than half-loaded.
                self.detach()
                raise e

            return self.model.model_loaded_weight_memory - current_used
|
| 841 |
+
|
| 842 |
+
    def detach(self, unpatch_all=True):
        """Eject injections, move patches to the offload device and (optionally)
        fully unpatch; fires ON_DETACH callbacks and returns the model."""
        self.eject_model()
        self.model_patches_to(self.offload_device)
        if unpatch_all:
            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
        for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
            callback(self, unpatch_all)
        return self.model
|
| 850 |
+
|
| 851 |
+
    def current_loaded_device(self):
        """Device the model is currently on (tracked on the model object)."""
        return self.model.device
|
| 853 |
+
|
| 854 |
+
    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
        """Deprecated shim kept for compatibility; use comfy.lora.calculate_weight."""
        logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
        return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
| 857 |
+
|
| 858 |
+
    def cleanup(self):
        """Release per-run state: clean hooks, drop the model's current_patcher
        back-reference, and fire ON_CLEANUP callbacks."""
        self.clean_hooks()
        if hasattr(self.model, "current_patcher"):
            self.model.current_patcher = None
        for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP):
            callback(self)
|
| 864 |
+
|
| 865 |
+
    def add_callback(self, call_type: str, callback: Callable):
        """Register *callback* for *call_type* under the default (None) key."""
        self.add_callback_with_key(call_type, None, callback)
|
| 867 |
+
|
| 868 |
+
def add_callback_with_key(self, call_type: str, key: str, callback: Callable):
|
| 869 |
+
c = self.callbacks.setdefault(call_type, {}).setdefault(key, [])
|
| 870 |
+
c.append(callback)
|
| 871 |
+
|
| 872 |
+
def remove_callbacks_with_key(self, call_type: str, key: str):
|
| 873 |
+
c = self.callbacks.get(call_type, {})
|
| 874 |
+
if key in c:
|
| 875 |
+
c.pop(key)
|
| 876 |
+
|
| 877 |
+
def get_callbacks(self, call_type: str, key: str):
|
| 878 |
+
return self.callbacks.get(call_type, {}).get(key, [])
|
| 879 |
+
|
| 880 |
+
def get_all_callbacks(self, call_type: str):
|
| 881 |
+
c_list = []
|
| 882 |
+
for c in self.callbacks.get(call_type, {}).values():
|
| 883 |
+
c_list.extend(c)
|
| 884 |
+
return c_list
|
| 885 |
+
|
| 886 |
+
def add_wrapper(self, wrapper_type: str, wrapper: Callable):
|
| 887 |
+
self.add_wrapper_with_key(wrapper_type, None, wrapper)
|
| 888 |
+
|
| 889 |
+
def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable):
|
| 890 |
+
w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
|
| 891 |
+
w.append(wrapper)
|
| 892 |
+
|
| 893 |
+
def remove_wrappers_with_key(self, wrapper_type: str, key: str):
|
| 894 |
+
w = self.wrappers.get(wrapper_type, {})
|
| 895 |
+
if key in w:
|
| 896 |
+
w.pop(key)
|
| 897 |
+
|
| 898 |
+
def get_wrappers(self, wrapper_type: str, key: str):
|
| 899 |
+
return self.wrappers.get(wrapper_type, {}).get(key, [])
|
| 900 |
+
|
| 901 |
+
def get_all_wrappers(self, wrapper_type: str):
|
| 902 |
+
w_list = []
|
| 903 |
+
for w in self.wrappers.get(wrapper_type, {}).values():
|
| 904 |
+
w_list.extend(w)
|
| 905 |
+
return w_list
|
| 906 |
+
|
| 907 |
+
def set_attachments(self, key: str, attachment):
|
| 908 |
+
self.attachments[key] = attachment
|
| 909 |
+
|
| 910 |
+
def remove_attachments(self, key: str):
|
| 911 |
+
if key in self.attachments:
|
| 912 |
+
self.attachments.pop(key)
|
| 913 |
+
|
| 914 |
+
def get_attachment(self, key: str):
|
| 915 |
+
return self.attachments.get(key, None)
|
| 916 |
+
|
| 917 |
+
def set_injections(self, key: str, injections: list[PatcherInjection]):
|
| 918 |
+
self.injections[key] = injections
|
| 919 |
+
|
| 920 |
+
def remove_injections(self, key: str):
|
| 921 |
+
if key in self.injections:
|
| 922 |
+
self.injections.pop(key)
|
| 923 |
+
|
| 924 |
+
def get_injections(self, key: str):
|
| 925 |
+
return self.injections.get(key, None)
|
| 926 |
+
|
| 927 |
+
def set_additional_models(self, key: str, models: list['ModelPatcher']):
|
| 928 |
+
self.additional_models[key] = models
|
| 929 |
+
|
| 930 |
+
def remove_additional_models(self, key: str):
|
| 931 |
+
if key in self.additional_models:
|
| 932 |
+
self.additional_models.pop(key)
|
| 933 |
+
|
| 934 |
+
def get_additional_models_with_key(self, key: str):
|
| 935 |
+
return self.additional_models.get(key, [])
|
| 936 |
+
|
| 937 |
+
def get_additional_models(self):
|
| 938 |
+
all_models = []
|
| 939 |
+
for models in self.additional_models.values():
|
| 940 |
+
all_models.extend(models)
|
| 941 |
+
return all_models
|
| 942 |
+
|
| 943 |
+
def get_nested_additional_models(self):
|
| 944 |
+
def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
|
| 945 |
+
'''Make sure circular references do not cause infinite recursion.'''
|
| 946 |
+
next_models = []
|
| 947 |
+
for model in prev_models:
|
| 948 |
+
candidates = model.get_additional_models()
|
| 949 |
+
for c in candidates:
|
| 950 |
+
if c not in cache_set:
|
| 951 |
+
next_models.append(c)
|
| 952 |
+
cache_set.add(c)
|
| 953 |
+
if len(next_models) == 0:
|
| 954 |
+
return prev_models
|
| 955 |
+
return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
|
| 956 |
+
|
| 957 |
+
all_models = self.get_additional_models()
|
| 958 |
+
models_set = set(all_models)
|
| 959 |
+
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
|
| 960 |
+
return real_all_models
|
| 961 |
+
|
| 962 |
+
def use_ejected(self, skip_and_inject_on_exit_only=False):
|
| 963 |
+
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
|
| 964 |
+
|
| 965 |
+
def inject_model(self):
|
| 966 |
+
if self.is_injected or self.skip_injection:
|
| 967 |
+
return
|
| 968 |
+
for injections in self.injections.values():
|
| 969 |
+
for inj in injections:
|
| 970 |
+
inj.inject(self)
|
| 971 |
+
self.is_injected = True
|
| 972 |
+
if self.is_injected:
|
| 973 |
+
for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL):
|
| 974 |
+
callback(self)
|
| 975 |
+
|
| 976 |
+
def eject_model(self):
|
| 977 |
+
if not self.is_injected:
|
| 978 |
+
return
|
| 979 |
+
for injections in self.injections.values():
|
| 980 |
+
for inj in injections:
|
| 981 |
+
inj.eject(self)
|
| 982 |
+
self.is_injected = False
|
| 983 |
+
for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL):
|
| 984 |
+
callback(self)
|
| 985 |
+
|
| 986 |
+
def pre_run(self):
|
| 987 |
+
if hasattr(self.model, "current_patcher"):
|
| 988 |
+
self.model.current_patcher = self
|
| 989 |
+
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
| 990 |
+
callback(self)
|
| 991 |
+
|
| 992 |
+
def prepare_state(self, timestep):
|
| 993 |
+
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
| 994 |
+
callback(self, timestep)
|
| 995 |
+
|
| 996 |
+
def restore_hook_patches(self):
|
| 997 |
+
if self.hook_patches_backup is not None:
|
| 998 |
+
self.hook_patches = self.hook_patches_backup
|
| 999 |
+
self.hook_patches_backup = None
|
| 1000 |
+
|
| 1001 |
+
def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
|
| 1002 |
+
self.hook_mode = hook_mode
|
| 1003 |
+
|
| 1004 |
+
    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
        """Advance each hook's keyframe to match the current timestep.

        When a hook's keyframe changes, any cached patched weights whose group
        contains that hook are stale, so those cache entries are dropped; if
        the currently-applied hooks include the changed hook, they are reset
        so weights get recalculated on the next apply.
        """
        curr_t = t[0]  # assumes batch timesteps are identical; first entry used - TODO confirm
        reset_current_hooks = False
        transformer_options = model_options.get("transformer_options", {})
        for hook in hook_group.hooks:
            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
            # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
            # this will cause the weights to be recalculated when sampling
            if changed:
                # reset current_hooks if contains hook that changed
                if self.current_hooks is not None:
                    for current_hook in self.current_hooks.hooks:
                        if current_hook == hook:
                            reset_current_hooks = True
                            break
                # list() copy: we mutate cached_hook_patches while scanning its keys
                for cached_group in list(self.cached_hook_patches.keys()):
                    if cached_group.contains(hook):
                        self.cached_hook_patches.pop(cached_group)
        if reset_current_hooks:
            self.patch_hooks(None)
|
| 1024 |
+
|
| 1025 |
+
    def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
                                  registered: comfy.hooks.HookGroup = None):
        """Register weight patches for every WeightHook in *hooks*.

        Hooks whose hook_ref is already present in self.hook_patches are just
        added to *registered*; new ones get their patches added after the
        current hook_patches are backed up. Fires ON_REGISTER_ALL_HOOK_PATCHES
        callbacks and returns the HookGroup of registered hooks.
        """
        self.restore_hook_patches()
        if registered is None:
            registered = comfy.hooks.HookGroup()
        # handle WeightHooks
        weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
        for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight):
            if hook.hook_ref not in self.hook_patches:
                weight_hooks_to_register.append(hook)
            else:
                registered.add(hook)
        if len(weight_hooks_to_register) > 0:
            # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
            self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
            for hook in weight_hooks_to_register:
                hook.add_hook_patches(self, model_options, target_dict, registered)
        for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
            callback(self, hooks, target_dict, model_options, registered)
        return registered
|
| 1045 |
+
|
| 1046 |
+
    def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
        """Store weight *patches* for *hook*, keyed by its hook_ref.

        Mirrors add_patches: only keys present in the model's state dict are
        accepted. Returns the list of patch keys that matched.
        """
        with self.use_ejected():
            # NOTE: this mirrors behavior of add_patches func
            current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {})
            p = set()
            model_sd = self.model.state_dict()
            for k in patches:
                offset = None
                function = None
                if isinstance(k, str):
                    key = k
                else:
                    # non-str keys are tuples: (key, offset[, function])
                    offset = k[1]
                    key = k[0]
                    if len(k) > 2:
                        function = k[2]

                if key in model_sd:
                    p.add(k)
                    current_patches: list[tuple] = current_hook_patches.get(key, [])
                    current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                    current_hook_patches[key] = current_patches
            self.hook_patches[hook.hook_ref] = current_hook_patches
            # since should care about these patches too to determine if same model, reroll patches_uuid
            self.patches_uuid = uuid.uuid4()
            return list(p)
|
| 1072 |
+
|
| 1073 |
+
    def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
        """Merge every hook's patch lists into one dict keyed by weight name.

        A hook with strength (close to) 1.0 contributes its patches as-is;
        otherwise each patch's strength_patch entry is scaled by the hook
        strength. Returns {} when *hooks* is None.
        """
        # combined_patches will contain weights of all relevant hooks, per key
        combined_patches = {}
        if hooks is not None:
            for hook in hooks.hooks:
                hook_patches: dict = self.hook_patches.get(hook.hook_ref, {})
                for key in hook_patches.keys():
                    current_patches: list[tuple] = combined_patches.get(key, [])
                    if math.isclose(hook.strength, 1.0):
                        current_patches.extend(hook_patches[key])
                    else:
                        # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model)
                        for patch in hook_patches[key]:
                            new_patch = list(patch)
                            new_patch[0] *= hook.strength
                            current_patches.append(tuple(new_patch))
                    combined_patches[key] = current_patches
        return combined_patches
|
| 1091 |
+
|
| 1092 |
+
    def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
        """Apply *hooks* to the model's weights if they differ from the current ones.

        Returns transformer_options derived from the hooks in either case.
        """
        # TODO: return transformer_options dict with any additions from hooks
        # NOTE(review): this condition also short-circuits forced re-apply when
        # this is not a clip model and hooks is None - confirm intended precedence.
        if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
            return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
        self.patch_hooks(hooks=hooks)
        for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
            callback(self, hooks)
        return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
|
| 1100 |
+
|
| 1101 |
+
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
| 1102 |
+
with self.use_ejected():
|
| 1103 |
+
if hooks is not None:
|
| 1104 |
+
model_sd_keys = list(self.model_state_dict().keys())
|
| 1105 |
+
memory_counter = None
|
| 1106 |
+
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
| 1107 |
+
# TODO: minimum_counter should have a minimum that conforms to loaded model requirements
|
| 1108 |
+
memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device),
|
| 1109 |
+
minimum=comfy.model_management.minimum_inference_memory()*2)
|
| 1110 |
+
# if have cached weights for hooks, use it
|
| 1111 |
+
cached_weights = self.cached_hook_patches.get(hooks, None)
|
| 1112 |
+
if cached_weights is not None:
|
| 1113 |
+
model_sd_keys_set = set(model_sd_keys)
|
| 1114 |
+
for key in cached_weights:
|
| 1115 |
+
if key not in model_sd_keys:
|
| 1116 |
+
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
| 1117 |
+
continue
|
| 1118 |
+
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
| 1119 |
+
model_sd_keys_set.remove(key)
|
| 1120 |
+
self.unpatch_hooks(model_sd_keys_set)
|
| 1121 |
+
else:
|
| 1122 |
+
self.unpatch_hooks()
|
| 1123 |
+
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
| 1124 |
+
original_weights = None
|
| 1125 |
+
if len(relevant_patches) > 0:
|
| 1126 |
+
original_weights = self.get_key_patches()
|
| 1127 |
+
for key in relevant_patches:
|
| 1128 |
+
if key not in model_sd_keys:
|
| 1129 |
+
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
|
| 1130 |
+
continue
|
| 1131 |
+
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
| 1132 |
+
memory_counter=memory_counter)
|
| 1133 |
+
else:
|
| 1134 |
+
self.unpatch_hooks()
|
| 1135 |
+
self.current_hooks = hooks
|
| 1136 |
+
|
| 1137 |
+
    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
        """Swap in a cached hook-patched weight for *key*, backing up the original first.

        The backup records both the copied tensor and the device the live
        weight was on, so unpatch_hooks can restore it exactly.
        """
        if key not in self.hook_backup:
            weight: torch.Tensor = comfy.utils.get_attr(self.model, key)
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                # keep the backup on the weight's own device while the memory budget allows
                used = memory_counter.use(weight)
                if used:
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1]))
|
| 1147 |
+
|
| 1148 |
+
def clear_cached_hook_weights(self):
|
| 1149 |
+
self.cached_hook_patches.clear()
|
| 1150 |
+
self.patch_hooks(None)
|
| 1151 |
+
|
| 1152 |
+
    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
        """Recalculate the weight for *key* with the combined hook patches applied.

        Backs up the original weight on first touch, computes the patched
        weight in float32, writes it into the model (stochastic rounding when
        no custom set function exists), and in MaxSpeed mode caches the result
        per HookGroup for reuse. Mutates *original_weights* by deleting the
        consumed entry.
        """
        if key not in combined_patches:
            return

        weight, set_func, convert_func = get_key_weight(self.model, key)
        weight: torch.Tensor
        if key not in self.hook_backup:
            target_device = self.offload_device
            if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
                # keep the backup on the weight's own device while the memory budget allows
                used = memory_counter.use(weight)
                if used:
                    target_device = weight.device
            self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
        # TODO: properly handle LowVramPatch, if it ends up an issue
        temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
        if convert_func is not None:
            temp_weight = convert_func(temp_weight, inplace=True)

        out_weight = comfy.lora.calculate_weight(combined_patches[key],
                                                 temp_weight,
                                                 key, original_weights=original_weights)
        del original_weights[key]
        if set_func is None:
            # quantize back to the weight's dtype with key-seeded stochastic rounding
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            comfy.utils.copy_to_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
        if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
            # TODO: disable caching if not enough system RAM to do so
            target_device = self.offload_device
            used = memory_counter.use(weight)
            if used:
                target_device = weight.device
            self.cached_hook_patches.setdefault(hooks, {})
            self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device)
        del temp_weight
        del out_weight
        del weight
|
| 1190 |
+
|
| 1191 |
+
    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
        """Restore original (pre-hook) weights from hook_backup.

        With *whitelist_keys_set*, only those keys are restored and removed
        from the backup; other backups and current_hooks are left intact.
        Without it, everything is restored and both the backup and
        current_hooks are cleared.
        """
        with self.use_ejected():
            if len(self.hook_backup) == 0:
                self.current_hooks = None
                return
            keys = list(self.hook_backup.keys())
            if whitelist_keys_set:
                for k in keys:
                    if k in whitelist_keys_set:
                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
                        self.hook_backup.pop(k)
            else:
                for k in keys:
                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

                # NOTE(review): indentation reconstructed - clear() and the reset are
                # scoped to the full-restore branch so whitelist restores keep state.
                self.hook_backup.clear()
                self.current_hooks = None
|
| 1208 |
+
|
| 1209 |
+
def clean_hooks(self):
|
| 1210 |
+
self.unpatch_hooks()
|
| 1211 |
+
self.clear_cached_hook_weights()
|
| 1212 |
+
|
| 1213 |
+
    def __del__(self):
        # Offload patches without unpatching weights; __del__ may run during
        # interpreter shutdown, so keep teardown minimal.
        self.detach(unpatch_all=False)
|
| 1215 |
+
|
ComfyUI/comfy/ops.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Stability AI
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import logging
|
| 21 |
+
import comfy.model_management
|
| 22 |
+
from comfy.cli_args import args, PerformanceFeature
|
| 23 |
+
import comfy.float
|
| 24 |
+
import comfy.rmsnorm
|
| 25 |
+
import contextlib
|
| 26 |
+
|
| 27 |
+
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
| 28 |
+
|
| 29 |
+
def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast *weight* to the dtype and device of *input*."""
    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
| 31 |
+
|
| 32 |
+
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Return (weight, bias) of module *s* cast to the requested dtype/device.

    When *input* is given, missing dtype/device/bias_dtype default to the
    input's. The cast may be scheduled on an offload stream when one is
    available; any weight_function/bias_function entries are applied to the
    cast tensors inside that stream's context, and the stream is synced
    before returning.
    """
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    offload_stream = comfy.model_management.get_offload_stream(device)
    if offload_stream is not None:
        wf_context = offload_stream
    else:
        wf_context = contextlib.nullcontext()

    bias = None
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    if s.bias is not None:
        has_function = len(s.bias_function) > 0
        # copy only when a function may modify the tensor downstream
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

        if has_function:
            with wf_context:
                for f in s.bias_function:
                    bias = f(bias)

    has_function = len(s.weight_function) > 0
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
    if has_function:
        with wf_context:
            for f in s.weight_function:
                weight = f(weight)

    comfy.model_management.sync_stream(device, offload_stream)
    return weight, bias
|
| 67 |
+
|
| 68 |
+
class CastWeightBiasOp:
    """Mixin flags/state for the cast-on-forward layer classes below."""
    # when True, forward() routes through forward_comfy_cast_weights
    comfy_cast_weights = False
    # NOTE: class-level lists are shared between instances until an instance
    # assigns its own list; callers are expected to replace, not mutate, them.
    weight_function = []
    bias_function = []
|
| 72 |
+
|
| 73 |
+
class disable_weight_init:
    """Namespace of torch.nn layer subclasses with weight initialization disabled.

    Each subclass overrides reset_parameters() to a no-op so that constructing
    a layer does not spend time initializing weights (they are loaded from a
    state dict anyway). forward() dispatches to forward_comfy_cast_weights()
    whenever comfy_cast_weights is set or a weight/bias function is attached;
    that path casts weight and bias to the input's dtype/device on the fly via
    cast_bias_weight().
    """
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            # no-op: weights are loaded, not initialized
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # elementwise_affine=False layers have no weight/bias to cast
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            # RMSNorm has no bias; define the attribute so cast_bias_weight works
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            # Embedding has no bias; define the attribute so cast_bias_weight works
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # fp16/bf16 weights: embed in the weight dtype, cast the result afterwards
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                # plain torch.nn.Embedding does not accept out_dtype
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        # factory for 2D/3D convolutions by dimensionality
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class manual_cast(disable_weight_init):
    """Same layers as disable_weight_init but with on-the-fly casting always on.

    comfy_cast_weights = True forces every forward() through the
    forward_comfy_cast_weights path, for models stored in a dtype the compute
    device cannot run natively.
    """
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def fp8_linear(self, input):
    """Attempt an fp8 matrix multiply for a Linear layer via torch._scaled_mm.

    Returns the output tensor (in the input's original dtype) on success, or
    None when the fp8 fast path does not apply (non-fp8 weight dtype, or an
    input rank other than 2/3 after normalization).  Callers fall back to a
    regular linear when None is returned.
    """
    dtype = self.weight.dtype
    # Only e4m3fn fp8 weights are supported by this path.
    if dtype not in [torch.float8_e4m3fn]:
        return None

    # Normalize a 2D input to 3D (batch, 1, features) so one reshape path
    # handles both cases; remember to squeeze back at the end.
    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    input_shape = input.shape
    input_dtype = input.dtype
    if len(input.shape) == 3:
        # Weight cast to fp8, bias kept in the activation dtype.
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
        w = w.t()

        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            # No stored weight scale: use identity scale.
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            # 448 is the maximum finite value representable in float8_e4m3fn;
            # clamp in place before the fp8 cast to avoid overflow to NaN/inf.
            input = torch.clamp(input, min=-448, max=448, out=input)
            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
            # Pre-divide by the input scale so _scaled_mm can rescale back.
            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

        if bias is not None:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

        # Older torch versions return (output, amax) from _scaled_mm.
        if isinstance(o, tuple):
            o = o[0]

        if tensor_2d:
            return o.reshape(input_shape[0], -1)

        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    # Input rank not supported by the fp8 path.
    return None
|
| 330 |
+
|
| 331 |
+
class fp8_ops(manual_cast):
    """Operation set that tries fp8 (e4m3fn) matrix multiplication for Linear.

    Falls back to the regular cast-weight linear path when the fp8 fast path
    is unavailable or raises.
    """
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            # Scale factors read lazily by fp8_linear(); None means "use 1.0".
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            try:
                out = fp8_linear(self, input)
                if out is not None:
                    return out
            except Exception as e:
                # Best-effort: log and fall through to the standard path.
                logging.info("Exception during fp8 op: {}".format(e))

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)
|
| 348 |
+
|
| 349 |
+
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    """Build an operation set for checkpoints stored with scaled fp8 weights.

    Args:
        fp8_matrix_mult: if True, Linear layers first try the fp8_linear()
            fast path before falling back to a scaled regular linear.
        scale_input: if True, layers also keep a per-layer input scale
            parameter (otherwise scale_input stays None).
        override_dtype: when set, forces the Linear weight dtype at
            construction time (the checkpoint's fp8 storage dtype).

    Returns:
        A manual_cast subclass whose Linear applies self.scale_weight when
        computing, converting weights, and setting weights.
    """
    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                # Create the scale parameters only once; reset_parameters can
                # be called multiple times and must not clobber loaded scales.
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)

                # Apply the weight scale to whichever operand is smaller to
                # minimize the elementwise multiply cost.
                if weight.numel() < input.numel(): #TODO: optimize
                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                else:
                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                """Return the weight de-scaled to its full-precision value."""
                if inplace:
                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                    return weight
                else:
                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                """Store a full-precision weight back in scaled fp8 form.

                Divides by scale_weight and stochastically rounds to the fp8
                storage dtype to reduce quantization bias.
                """
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op
|
| 397 |
+
|
| 398 |
+
# Optional dependency: the cublas_ops package provides a faster fp16 Linear.
CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        """Operation set backed by CublasLinear (only defined when available)."""
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                # Keep checkpoint weights; never re-initialize.
                return None

            def forward_comfy_cast_weights(self, input):
                # CublasLinear handles its own casting; delegate directly.
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
|
| 416 |
+
|
| 417 |
+
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    """Select the operation set (ops class) for a model's layers.

    Priority order: scaled fp8 checkpoint handling, fp8 fast matmul (when the
    device supports fp8 compute and it is enabled), the optional cublas fp16
    path, plain disable_weight_init when no cast is needed, else manual_cast.

    Args:
        weight_dtype: dtype the weights are stored in.
        compute_dtype: dtype the model computes in (None = same as weights).
        load_device: device the model will run on (used for fp8 support check).
        disable_fast_fp8: hard-disable the fp8 matmul path.
        fp8_optimizations: enable fp8 matmul / input scaling optimizations.
        scaled_fp8: storage dtype of a scaled-fp8 checkpoint, or None.
    """
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

    if (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    ):
        return fp8_ops

    # Cublas path requires the optional package and a pure fp16 pipeline.
    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    # No dtype mismatch: no runtime casting needed.
    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
|
ComfyUI/comfy/patcher_extension.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Callable
|
| 3 |
+
|
| 4 |
+
class CallbacksMP:
    """Namespace of callback type identifiers for ModelPatcher lifecycle events,
    plus a factory for an empty callback registry."""
    ON_CLONE = "on_clone"
    ON_LOAD = "on_load_after"
    ON_DETACH = "on_detach_after"
    ON_CLEANUP = "on_cleanup"
    ON_PRE_RUN = "on_pre_run"
    ON_PREPARE_STATE = "on_prepare_state"
    ON_APPLY_HOOKS = "on_apply_hooks"
    ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
    ON_INJECT_MODEL = "on_inject_model"
    ON_EJECT_MODEL = "on_eject_model"

    # callbacks dict is in the format:
    # {"call_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a fresh, empty callback registry."""
        return {}
|
| 21 |
+
|
| 22 |
+
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Register a callback under call_type with no specific key (key=None)."""
    add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
|
| 24 |
+
|
| 25 |
+
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Register a callback under (call_type, key), creating registry levels as needed.

    When is_model_options is True, transformer_options is treated as a
    model_options dict and the registry lives under its
    "transformer_options" sub-dict (created if missing).
    """
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    registry: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
    registry.setdefault(call_type, {}).setdefault(key, []).append(callback)
|
| 31 |
+
|
| 32 |
+
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a new list of the callbacks registered under (call_type, key).

    Missing registry levels yield an empty list; the registry itself is
    never modified.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry: dict[str, list] = transformer_options.get("callbacks", {})
    return list(registry.get(call_type, {}).get(key, []))
|
| 39 |
+
|
| 40 |
+
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
    """Collect every callback registered under call_type, across all keys.

    Callbacks are returned in key-insertion order; missing registry levels
    yield an empty list.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry: dict[str, list] = transformer_options.get("callbacks", {})
    return [cb for group in registry.get(call_type, {}).values() for cb in group]
|
| 48 |
+
|
| 49 |
+
class WrappersMP:
    """Namespace of wrapper type identifiers for ModelPatcher sampling stages,
    plus a factory for an empty wrapper registry."""
    OUTER_SAMPLE = "outer_sample"
    PREPARE_SAMPLING = "prepare_sampling"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"

    # wrappers dict is in the format:
    # {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
    @classmethod
    def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a fresh, empty wrapper registry."""
        return {}
|
| 62 |
+
|
| 63 |
+
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Register a wrapper under wrapper_type with no specific key (key=None)."""
    add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
|
| 65 |
+
|
| 66 |
+
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Register a wrapper under (wrapper_type, key), creating registry levels as needed.

    When is_model_options is True, the registry lives under the
    "transformer_options" sub-dict of the provided options (created if
    missing).
    """
    if is_model_options:
        transformer_options = transformer_options.setdefault("transformer_options", {})
    registry: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
    registry.setdefault(wrapper_type, {}).setdefault(key, []).append(wrapper)
|
| 72 |
+
|
| 73 |
+
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a new list of the wrappers registered under (wrapper_type, key).

    Missing registry levels yield an empty list; the registry itself is
    never modified.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry: dict[str, list] = transformer_options.get("wrappers", {})
    return list(registry.get(wrapper_type, {}).get(key, []))
|
| 80 |
+
|
| 81 |
+
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
    """Collect every wrapper registered under wrapper_type, across all keys.

    Wrappers are returned in key-insertion order; missing registry levels
    yield an empty list.
    """
    if is_model_options:
        transformer_options = transformer_options.get("transformer_options", {})
    registry: dict[str, list] = transformer_options.get("wrappers", {})
    return [w for group in registry.get(wrapper_type, {}).values() for w in group]
|
| 89 |
+
|
| 90 |
+
class WrapperExecutor:
    """Runs an ordered chain of wrapper callables around an original function.

    Each wrapper receives the executor for the next link in the chain as its
    first argument; invoking that executor advances the chain until the
    original function finally runs.
    """
    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
        # NOTE: class_obj exists so that wrappers surrounding a class method
        # can access the class instance at runtime via executor.class_obj
        self.original = original
        self.class_obj = class_obj
        self.wrappers = wrappers.copy()
        self.idx = idx
        self.is_last = idx == len(wrappers)

    def __call__(self, *args, **kwargs):
        """Advance to the next wrapper (or the original function) and run it."""
        return self._create_next_executor().execute(*args, **kwargs)

    def execute(self, *args, **kwargs):
        """Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
        args = list(args)
        kwargs = dict(kwargs)
        if self.is_last:
            return self.original(*args, **kwargs)
        return self.wrappers[self.idx](self, *args, **kwargs)

    def _create_next_executor(self) -> 'WrapperExecutor':
        next_idx = self.idx + 1
        if next_idx > len(self.wrappers):
            raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
        if self.class_obj is None:
            return WrapperExecutor.new_executor(self.original, self.wrappers, next_idx)
        return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, next_idx)

    @classmethod
    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
        """Build an executor for a free function (no bound instance)."""
        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)

    @classmethod
    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
        """Build an executor for a method, keeping its instance reachable."""
        return cls(original, class_obj, wrappers, idx=idx)
|
| 129 |
+
|
| 130 |
+
class PatcherInjection:
    """Pair of callables used to inject a patch into a model and later eject it."""
    def __init__(self, inject: Callable, eject: Callable):
        self.inject = inject  # applies the injection
        self.eject = eject    # reverses the injection
|
| 134 |
+
|
| 135 |
+
def copy_nested_dicts(input_dict: dict):
    """Copy a dict, deep-copying nested dicts and shallow-copying lists.

    Non-dict, non-list values are shared with the source (shallow).
    """
    duplicated = {}
    for key, value in input_dict.items():
        if isinstance(value, dict):
            duplicated[key] = copy_nested_dicts(value)
        elif isinstance(value, list):
            duplicated[key] = value.copy()
        else:
            duplicated[key] = value
    return duplicated

def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    """Merge dict2 into dict1, recursing into dicts and extending lists.

    Scalar conflicts at the top level are won by dict2; inside nested dicts
    the recursion swaps argument roles, so there dict1's scalars win.  When
    copy_dict1 is True the merge happens on a copy and dict1 is untouched;
    otherwise dict1 is mutated and returned.
    """
    merged = copy_nested_dicts(dict1) if copy_dict1 else dict1
    for key, value in dict2.items():
        if isinstance(value, dict):
            existing = merged.setdefault(key, {})
            merged[key] = merge_nested_dicts(value, existing)
        elif isinstance(value, list):
            merged.setdefault(key, []).extend(value)
        else:
            merged[key] = value
    return merged
|
ComfyUI/comfy/sample.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.model_management
|
| 3 |
+
import comfy.samplers
|
| 4 |
+
import comfy.utils
|
| 5 |
+
import numpy as np
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
def prepare_noise(latent_image, seed, noise_inds=None):
    """
    creates random noise given a latent image and a seed.
    optional arg skip can be used to skip and discard x number of noise generations for a given seed
    """
    generator = torch.manual_seed(seed)
    if noise_inds is None:
        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")

    # noise_inds selects (possibly repeated) batch indices; generate each
    # unique index's noise exactly once, in index order, so the RNG stream
    # matches what a sequential generation would produce.
    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    slice_shape = [1] + list(latent_image.size())[1:]
    generated = {}
    for idx in range(unique_inds[-1] + 1):
        candidate = torch.randn(slice_shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
        if idx in unique_inds:
            generated[idx] = candidate
    # Re-expand to the requested (possibly repeated) order.
    ordered = [generated[unique_inds[pos]] for pos in inverse]
    return torch.cat(ordered, axis=0)
|
| 26 |
+
|
| 27 |
+
def fix_empty_latent_channels(model, latent_image):
    """Adjust an all-zero latent so its channel count and rank match the model's latent format."""
    latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
    # Only adjust when the latent is empty (all zeros) and channels disagree.
    if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
        latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
    if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
        # 3-dimensional latent formats expect a 5D tensor; insert the extra axis.
        latent_image = latent_image.unsqueeze(2)
    return latent_image
|
| 34 |
+
|
| 35 |
+
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
    """Deprecated no-op kept for backward compatibility; returns inputs unchanged plus an empty model list."""
    logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
    return model, positive, negative, noise_mask, []
|
| 38 |
+
|
| 39 |
+
def cleanup_additional_models(models):
    """Deprecated no-op kept for backward compatibility; only logs a warning."""
    logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
|
| 41 |
+
|
| 42 |
+
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
    """Run a full KSampler denoising pass and return samples on the intermediate device.

    Builds a comfy.samplers.KSampler from the model's own load_device and
    model_options, then delegates all sampling arguments to KSampler.sample().

    NOTE(review): disable_noise is accepted but not used in this function —
    presumably consumed by callers; confirm before relying on it here.
    """
    sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

    samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
    # Move the result off the compute device for downstream (e.g. VAE) work.
    samples = samples.to(comfy.model_management.intermediate_device())
    return samples
|
| 48 |
+
|
| 49 |
+
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sample with a caller-provided sampler object and explicit sigma schedule.

    Unlike sample(), the sampler and sigmas are supplied directly instead of
    being constructed from a sampler/scheduler name and step count.
    """
    samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
    # Move the result off the compute device for downstream work.
    samples = samples.to(comfy.model_management.intermediate_device())
    return samples
|
ComfyUI/comfy/samplers.py
ADDED
|
@@ -0,0 +1,1143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from .k_diffusion import sampling as k_diffusion_sampling
|
| 3 |
+
from .extra_samplers import uni_pc
|
| 4 |
+
from typing import TYPE_CHECKING, Callable, NamedTuple
|
| 5 |
+
if TYPE_CHECKING:
|
| 6 |
+
from comfy.model_patcher import ModelPatcher
|
| 7 |
+
from comfy.model_base import BaseModel
|
| 8 |
+
from comfy.controlnet import ControlBase
|
| 9 |
+
import torch
|
| 10 |
+
from functools import partial
|
| 11 |
+
import collections
|
| 12 |
+
from comfy import model_management
|
| 13 |
+
import math
|
| 14 |
+
import logging
|
| 15 |
+
import comfy.sampler_helpers
|
| 16 |
+
import comfy.model_patcher
|
| 17 |
+
import comfy.patcher_extension
|
| 18 |
+
import comfy.hooks
|
| 19 |
+
import scipy.stats
|
| 20 |
+
import numpy
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def add_area_dims(area, num_dims):
    """Pad an area spec [sizes..., offsets...] to cover num_dims dimensions.

    New leading dimensions get a huge size (2147483648, i.e. 2**31 — later
    clamped to the tensor's actual extent) and a zero offset.  If the spec
    already covers num_dims dimensions it is returned unchanged.
    """
    while (len(area) // 2) < num_dims:
        half = len(area) // 2
        # Prepend size to the first half, offset to the second half.
        area = [2147483648] + area[:half] + [0] + area[half:]
    return area
|
| 27 |
+
|
| 28 |
+
def get_area_and_mult(conds, x_in, timestep_in):
    """Crop the input to a cond's active area and compute its blend multiplier.

    Returns None when the cond is outside its timestep range, otherwise a
    namedtuple (input_x, mult, conditioning, area, control, patches, uuid,
    hooks) describing the cropped input, the per-pixel strength multiplier,
    and the processed model conditioning.
    """
    dims = tuple(x_in.shape[2:])  # spatial (and possibly temporal) dims
    area = None
    strength = 1.0

    # Timesteps count down during sampling: start is the larger value.
    if 'timestep_start' in conds:
        timestep_start = conds['timestep_start']
        if timestep_in[0] > timestep_start:
            return None
    if 'timestep_end' in conds:
        timestep_end = conds['timestep_end']
        if timestep_in[0] < timestep_end:
            return None
    if 'area' in conds:
        # area layout: [sizes..., offsets...]; pad or trim to match dims.
        area = list(conds['area'])
        area = add_area_dims(area, len(dims))
        if (len(area) // 2) > len(dims):
            area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]

    if 'strength' in conds:
        strength = conds['strength']

    input_x = x_in
    if area is not None:
        for i in range(len(dims)):
            # Clamp the area size so offset+size stays inside the tensor.
            area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
            input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])

    if 'mask' in conds:
        # Scale the mask to the size of the input
        # The mask should have been resized as we began the sampling process
        mask_strength = 1.0
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
        assert (mask.shape[1:] == x_in.shape[2:])

        mask = mask[:input_x.shape[0]]
        if area is not None:
            # Crop the mask to the same area as the input.
            for i in range(len(dims)):
                mask = mask.narrow(i + 1, area[len(dims) + i], area[i])

        mask = mask * mask_strength
        # Broadcast over batch repeats and channels.
        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
    else:
        mask = torch.ones_like(input_x)
    mult = mask * strength

    if 'mask' not in conds and area is not None:
        # Feather the area edges (up to 8 px, capped at a quarter of the
        # extent) so adjacent areas blend instead of forming hard seams.
        fuzz = 8
        for i in range(len(dims)):
            rr = min(fuzz, mult.shape[2 + i] // 4)
            if area[len(dims) + i] != 0:
                # Ramp up from the leading edge (area doesn't start at 0).
                for t in range(rr):
                    m = mult.narrow(i + 2, t, 1)
                    m *= ((1.0 / rr) * (t + 1))
            if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
                # Ramp down toward the trailing edge (area doesn't reach the end).
                for t in range(rr):
                    m = mult.narrow(i + 2, area[i] - 1 - t, 1)
                    m *= ((1.0 / rr) * (t + 1))

    # Let each model cond adapt itself to the batch size/device/area.
    conditioning = {}
    model_conds = conds["model_conds"]
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)

    hooks = conds.get('hooks', None)
    control = conds.get('control', None)

    patches = None
    if 'gligen' in conds:
        gligen = conds['gligen']
        patches = {}
        gligen_type = gligen[0]
        gligen_model = gligen[1]
        if gligen_type == "position":
            gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
        else:
            gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)

        patches['middle_patch'] = [gligen_patch]

    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
    return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
|
| 112 |
+
|
| 113 |
+
def cond_equal_size(c1, c2):
    """Return True when two conditioning dicts have matching keys and every
    value pair reports it can be concatenated (via can_concat)."""
    if c1 is c2:
        return True
    if c1.keys() != c2.keys():
        return False
    return all(c1[k].can_concat(c2[k]) for k in c1)
|
| 122 |
+
|
| 123 |
+
def can_concat_cond(c1, c2):
    """Return True if two cond objects can be batched into one model call.

    Requires identical input shapes, identical (or both-absent) control and
    patch objects, and concatenation-compatible conditioning dicts.
    """
    if c1.input_x.shape != c2.input_x.shape:
        return False

    def objects_concatable(obj1, obj2):
        # Compatible only when both are None or both are the *same* object.
        if (obj1 is None) != (obj2 is None):
            return False
        if obj1 is not None:
            if obj1 is not obj2:
                return False
        return True

    if not objects_concatable(c1.control, c2.control):
        return False

    if not objects_concatable(c1.patches, c2.patches):
        return False

    return cond_equal_size(c1.conditioning, c2.conditioning)
|
| 142 |
+
|
| 143 |
+
def cond_cat(c_list):
    """Merge a list of conditioning dicts into one, concatenating per key.

    For each key present in any dict of ``c_list``, gathers the values in
    order and calls ``first.concat(rest)`` on them.  Uses
    ``collections.defaultdict`` instead of the get/append/assign dance.

    Args:
        c_list: list of dicts mapping key -> cond object with a ``concat`` method.
    Returns:
        dict with one concatenated cond object per key.
    """
    grouped = collections.defaultdict(list)
    for cond in c_list:
        for key, value in cond.items():
            grouped[key].append(value)

    return {key: values[0].concat(values[1:]) for key, values in grouped.items()}
|
| 157 |
+
|
| 158 |
+
def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
    """Append 'default' conds to hooked_to_run to cover latent regions the
    regular conds left unmasked.

    For each cond type, starts from a full-ones mult and subtracts every
    regular cond's mult (restricted to its area when one is set); whatever
    remains after a ReLU is used as the mult for that type's default conds.

    Fix: the inner dimension loop previously reused ``i`` as its loop
    variable, shadowing the cond-index ``i`` from the enclosing loop; it is
    renamed to ``d`` so the index can no longer be silently clobbered.
    """
    # need to figure out remaining unmasked area for conds
    default_mults = []
    for _ in default_conds:
        default_mults.append(torch.ones_like(x_in))
    # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
    for lora_hooks, to_run in hooked_to_run.items():
        for cond_obj, i in to_run:
            # if no default_cond for cond_type, do nothing
            if len(default_conds[i]) == 0:
                continue
            area: list[int] = cond_obj.area
            if area is not None:
                curr_default_mult: torch.Tensor = default_mults[i]
                dims = len(area) // 2
                # narrow the view down to this cond's area before subtracting;
                # loop var renamed from 'i' to 'd' (was shadowing the cond index)
                for d in range(dims):
                    curr_default_mult = curr_default_mult.narrow(d + 2, area[d + dims], area[d])
                curr_default_mult -= cond_obj.mult
            else:
                default_mults[i] -= cond_obj.mult
    # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
    for i, mult in enumerate(default_mults):
        # if no default_cond for cond type, do nothing
        if len(default_conds[i]) == 0:
            continue
        torch.nn.functional.relu(mult, inplace=True)
        # if mult is all zeros, then don't add default_cond
        if torch.max(mult) == 0.0:
            continue

        cond = default_conds[i]
        for x in cond:
            # do get_area_and_mult to get all the expected values
            p = get_area_and_mult(x, x_in, timestep)
            if p is None:
                continue
            # replace p's mult with calculated leftover mult
            p = p._replace(mult=mult)
            if p.hooks is not None:
                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
            hooked_to_run.setdefault(p.hooks, list())
            hooked_to_run[p.hooks] += [(p, i)]
|
| 200 |
+
|
| 201 |
+
def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    """Public entry point for batched cond evaluation.

    Wraps _calc_cond_batch in a WrapperExecutor so that any CALC_COND_BATCH
    wrappers registered in model_options can intercept or augment the call.
    """
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(
        _calc_cond_batch,
        comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
    )
    return executor.execute(model, conds, x_in, timestep, model_options)
|
| 207 |
+
|
| 208 |
+
def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
    """Evaluate every cond list against the model, batching compatible conds.

    Conds are grouped by their hook group, greedily batched under the free
    memory budget, run through the model (optionally through a
    'model_function_wrapper'), and their weighted outputs accumulated per
    cond type.  Returns one averaged output tensor per entry of ``conds``.
    """
    out_conds = []
    out_counts = []
    # separate conds by matching hooks
    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
    default_conds = []
    has_default_conds = False

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        # tiny epsilon so the final division never hits exactly zero coverage
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        default_c = []
        if cond is not None:
            for x in cond:
                if 'default' in x:
                    # deferred: default conds are resolved after regular conds
                    default_c.append(x)
                    has_default_conds = True
                    continue
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue
                if p.hooks is not None:
                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
                hooked_to_run.setdefault(p.hooks, list())
                hooked_to_run[p.hooks] += [(p, i)]
        default_conds.append(default_c)

    if has_default_conds:
        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

    model.current_patcher.prepare_state(timestep)

    # run every hooked_to_run separately
    for hooks, to_run in hooked_to_run.items():
        while len(to_run) > 0:
            first = to_run[0]
            first_shape = first[0][0].shape
            to_batch_temp = []
            for x in range(len(to_run)):
                if can_concat_cond(to_run[x][0], first[0]):
                    to_batch_temp += [x]

            to_batch_temp.reverse()
            to_batch = to_batch_temp[:1]

            # grow the batch as large as ~2/3 of free memory allows
            free_memory = model_management.get_free_memory(x_in.device)
            for i in range(1, len(to_batch_temp) + 1):
                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
                cond_shapes = collections.defaultdict(list)
                for tt in batch_amount:
                    # NOTE(review): this dict is built and immediately discarded
                    # (the loop below does the actual work) — appears to be dead code
                    cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
                    for k, v in to_run[tt][0].conditioning.items():
                        cond_shapes[k].append(v.size())

                if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                    to_batch = batch_amount
                    break

            input_x = []
            mult = []
            c = []
            cond_or_uncond = []
            uuids = []
            area = []
            control = None
            patches = None
            for x in to_batch:
                o = to_run.pop(x)
                p = o[0]
                input_x.append(p.input_x)
                mult.append(p.mult)
                c.append(p.conditioning)
                area.append(p.area)
                cond_or_uncond.append(o[1])
                uuids.append(p.uuid)
                # batched conds share control/patches (enforced by can_concat_cond)
                control = p.control
                patches = p.patches

            batch_chunks = len(cond_or_uncond)
            input_x = torch.cat(input_x)
            c = cond_cat(c)
            timestep_ = torch.cat([timestep] * batch_chunks)

            transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
            if 'transformer_options' in model_options:
                transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
                                                                                 model_options['transformer_options'],
                                                                                 copy_dict1=False)

            if patches is not None:
                # TODO: replace with merge_nested_dicts function
                if "patches" in transformer_options:
                    cur_patches = transformer_options["patches"].copy()
                    for p in patches:
                        if p in cur_patches:
                            cur_patches[p] = cur_patches[p] + patches[p]
                        else:
                            cur_patches[p] = patches[p]
                    transformer_options["patches"] = cur_patches
                else:
                    transformer_options["patches"] = patches

            transformer_options["cond_or_uncond"] = cond_or_uncond[:]
            transformer_options["uuids"] = uuids[:]
            transformer_options["sigmas"] = timestep

            c['transformer_options'] = transformer_options

            if control is not None:
                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

            if 'model_function_wrapper' in model_options:
                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
            else:
                output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

            # scatter each chunk's weighted output back into its cond-type accumulator
            for o in range(batch_chunks):
                cond_index = cond_or_uncond[o]
                a = area[o]
                if a is None:
                    out_conds[cond_index] += output[o] * mult[o]
                    out_counts[cond_index] += mult[o]
                else:
                    out_c = out_conds[cond_index]
                    out_cts = out_counts[cond_index]
                    dims = len(a) // 2
                    for i in range(dims):
                        out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                        out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                    out_c += output[o] * mult[o]
                    out_cts += mult[o]

    # normalize by accumulated coverage
    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds
|
| 347 |
+
|
| 348 |
+
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
    """Deprecated: thin compatibility shim; use calc_cond_batch instead."""
    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
| 351 |
+
|
| 352 |
+
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options=None, cond=None, uncond=None):
    """Combine cond/uncond predictions into the final CFG result.

    Uses a custom 'sampler_cfg_function' from model_options when present,
    otherwise the standard formula ``uncond + (cond - uncond) * cond_scale``,
    then runs every 'sampler_post_cfg_function' callback in order.

    Fix: ``model_options`` previously defaulted to a shared mutable ``{}``;
    it now defaults to None and is normalized here (backward compatible).
    """
    if model_options is None:
        model_options = {}
    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        # custom function works in noise-prediction space, hence the x - ... round trip
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    return cfg_result
|
| 366 |
+
|
| 367 |
+
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
    """Run cond (and optionally uncond) through the model and apply CFG.

    When cond_scale is ~1.0 the uncond pass is skipped entirely (unless
    'disable_cfg1_optimization' is set).  Supports the
    'sampler_calc_cond_batch_function' and 'sampler_pre_cfg_function' hooks.

    Fix: ``model_options`` previously defaulted to a shared mutable ``{}``;
    it now defaults to None and is normalized here (backward compatible).
    """
    if model_options is None:
        model_options = {}
    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond

    conds = [cond, uncond_]
    if "sampler_calc_cond_batch_function" in model_options:
        args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options}
        out = model_options["sampler_calc_cond_batch_function"](args)
    else:
        out = calc_cond_batch(model, conds, x, timestep, model_options)

    for fn in model_options.get("sampler_pre_cfg_function", []):
        args = {"conds": conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
                "input": x, "sigma": timestep, "model": model, "model_options": model_options}
        out = fn(args)

    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class KSamplerX0Inpaint:
    """Wraps an inner denoiser to support inpainting via a denoise mask.

    NOTE: ``self.noise`` and ``self.latent_image`` are assigned externally
    (see KSAMPLER.sample) before this object is called.
    """
    def __init__(self, model, sigmas):
        self.inner_model = model
        self.sigmas = sigmas
    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
        # NOTE(review): mutable default for model_options — only read here,
        # but callers should pass their own dict.
        if denoise_mask is not None:
            if "denoise_mask_function" in model_options:
                denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
            latent_mask = 1. - denoise_mask
            # denoise only inside the mask; outside, re-noise the original latent
            x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        if denoise_mask is not None:
            # paste the untouched latent back over the unmasked region
            out = out * denoise_mask + self.latent_image * latent_mask
        return out
|
| 404 |
+
|
| 405 |
+
def simple_scheduler(model_sampling, steps):
    """Stride evenly through the model's sigma table from high to low,
    appending a final 0.0."""
    sampling = model_sampling
    stride = len(sampling.sigmas) / steps
    sigs = [float(sampling.sigmas[-(1 + int(step * stride))]) for step in range(steps)]
    sigs.append(0.0)
    return torch.FloatTensor(sigs)
|
| 413 |
+
|
| 414 |
+
def ddim_scheduler(model_sampling, steps):
    """DDIM-style schedule: uniform integer stride over the sigma table,
    collected ascending and returned reversed (high sigma first)."""
    s = model_sampling
    x = 1
    if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
        # first usable sigma is already ~0: take one extra step and
        # don't append an explicit trailing zero
        steps += 1
        sigs = []
    else:
        sigs = [0.0]

    ss = max(len(s.sigmas) // steps, 1)
    while x < len(s.sigmas):
        sigs.append(float(s.sigmas[x]))
        x += ss
    return torch.FloatTensor(sigs[::-1])
|
| 430 |
+
|
| 431 |
+
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
    """Linear schedule in timestep space, mapped back to sigmas.

    With sgm=True uses the SGM convention (steps+1 points, last dropped,
    no appended zero handling).  Otherwise appends a trailing 0.0 unless the
    final sigma is already ~0.

    NOTE(review): the 'floor' parameter is unused in this body — kept for
    interface compatibility.
    """
    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)

    append_zero = True
    if sgm:
        timesteps = torch.linspace(start, end, steps + 1)[:-1]
    else:
        if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
            # sigma_min is effectively zero: include it as the last step
            # instead of appending an explicit 0.0
            steps += 1
            append_zero = False
        timesteps = torch.linspace(start, end, steps)

    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(float(s.sigma(ts)))

    if append_zero:
        sigs += [0.0]

    return torch.FloatTensor(sigs)
|
| 454 |
+
|
| 455 |
+
# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
    """Sample timesteps via the inverse CDF (ppf) of a Beta(alpha, beta)
    distribution, then map them to sigmas; consecutive duplicates are dropped."""
    total_timesteps = (len(model_sampling.sigmas) - 1)
    ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
    ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps)

    sigs = []
    last_t = -1
    for t in ts:
        # skip repeated timesteps produced by rounding
        if t != last_t:
            sigs += [float(model_sampling.sigmas[int(t)])]
        last_t = t
    sigs += [0.0]
    return torch.FloatTensor(sigs)
|
| 469 |
+
|
| 470 |
+
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
    """Mochi-style schedule: linear ramp up to threshold_noise for the first
    linear_steps, then a quadratic segment to 1.0; the whole curve is flipped
    (1 - x) and scaled by the model's sigma_max."""
    if steps == 1:
        sigma_schedule = [1.0, 0.0]
    else:
        if linear_steps is None:
            linear_steps = steps // 2
        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
        threshold_noise_step_diff = linear_steps - threshold_noise * steps
        quadratic_steps = steps - linear_steps
        # quadratic segment chosen to join the linear part continuously and end at 1.0
        quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps ** 2)
        linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps ** 2)
        const = quadratic_coef * (linear_steps ** 2)
        quadratic_sigma_schedule = [
            quadratic_coef * (i ** 2) + linear_coef * i + const
            for i in range(linear_steps, steps)
        ]
        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
        sigma_schedule = [1.0 - x for x in sigma_schedule]
    return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
|
| 490 |
+
|
| 491 |
+
# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
    """KL-optimal sigma schedule: linear interpolation between sigma_max and
    sigma_min in atan space, with a trailing 0.0."""
    fractions = torch.arange(n, dtype=torch.float) / (n - 1)
    sigmas = fractions.new_zeros(n + 1)
    blended = fractions * math.atan(sigma_min) + (1.0 - fractions) * math.atan(sigma_max)
    sigmas[:-1] = torch.tan(blended)
    return sigmas
|
| 497 |
+
|
| 498 |
+
def get_mask_aabb(masks):
    """Compute per-mask axis-aligned bounding boxes.

    Args:
        masks: tensor of shape (b, H, W).
    Returns:
        Tuple of (bounding_boxes, is_empty): a (b, 4) int tensor of
        (min_x, min_y, max_x, max_y) per mask, and a (b,) bool tensor
        marking all-zero masks.

    Fix: the empty-input early return previously yielded a lone tensor,
    breaking callers that unpack ``boxes, is_empty``; it now returns a
    consistent 2-tuple.
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.int), torch.zeros((0,), device=masks.device, dtype=torch.bool)

    b = masks.shape[0]

    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
    is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
    for i in range(b):
        mask = masks[i]
        if mask.numel() == 0:
            continue
        if torch.max(mask != 0) == False:
            # mask has no nonzero pixel: flag it and leave its box at zeros
            is_empty[i] = True
            continue
        y, x = torch.where(mask)
        bounding_boxes[i, 0] = torch.min(x)
        bounding_boxes[i, 1] = torch.min(y)
        bounding_boxes[i, 2] = torch.max(x)
        bounding_boxes[i, 3] = torch.max(y)

    return bounding_boxes, is_empty
|
| 520 |
+
|
| 521 |
+
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    """Normalize 'area' and 'mask' entries on conds in place.

    Percentage areas are converted to absolute latent coordinates; masks are
    moved to `device`, rescaled to `dims`, and optionally converted into a
    bounding-box area when 'set_area_to_bounds' is set.
    """
    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
    for i in range(len(conditions)):
        c = conditions[i]
        if 'area' in c:
            area = c['area']
            if area[0] == "percentage":
                # convert ("percentage", sizes..., offsets...) to absolute units
                modified = c.copy()
                a = area[1:]
                a_len = len(a) // 2
                area = ()
                for d in range(len(dims)):
                    area += (max(1, round(a[d] * dims[d])),)
                for d in range(len(dims)):
                    area += (round(a[d + a_len] * dims[d]),)

                modified['area'] = area
                c = modified
                conditions[i] = c

        if 'mask' in c:
            mask = c['mask']
            mask = mask.to(device=device)
            modified = c.copy()
            if len(mask.shape) == len(dims):
                # add a batch dimension
                mask = mask.unsqueeze(0)
            if mask.shape[1:] != dims:
                # rescale mask to the latent resolution
                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)

            if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
                bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
                boxes, is_empty = get_mask_aabb(bounds)
                if is_empty[0]:
                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
                    modified['area'] = (8, 8, 0, 0)
                else:
                    box = boxes[0]
                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
                    # clamp to a minimum 8x8 area
                    H = max(8, H)
                    W = max(8, W)
                    area = (int(H), int(W), int(Y), int(X))
                    modified['area'] = area

            modified['mask'] = mask
            conditions[i] = modified
|
| 567 |
+
|
| 568 |
+
def resolve_areas_and_cond_masks(conditions, h, w, device):
    """Deprecated 2D wrapper around resolve_areas_and_cond_masks_multidim."""
    logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
    return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
|
| 571 |
+
|
| 572 |
+
def create_cond_with_same_area_if_none(conds, c):
    """Ensure `conds` has an entry covering the same area as cond `c`.

    Finds the smallest cond in `conds` whose area encloses c's area (or an
    area-less cond as fallback) and, if none matches exactly, appends a copy
    of `c` carrying that cond's model_conds.
    """
    if 'area' not in c:
        return

    def area_inside(a, area_cmp):
        # pad both areas to a common dimensionality before comparing
        a = add_area_dims(a, len(area_cmp) // 2)
        area_cmp = add_area_dims(area_cmp, len(a) // 2)

        a_l = len(a) // 2
        area_cmp_l = len(area_cmp) // 2
        # offsets of `a` must start at or after area_cmp's offsets...
        for i in range(min(a_l, area_cmp_l)):
            if a[a_l + i] < area_cmp[area_cmp_l + i]:
                return False
        # ...and its extents must end at or before area_cmp's extents
        for i in range(min(a_l, area_cmp_l)):
            if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
                return False
        return True

    c_area = c['area']
    smallest = None
    for x in conds:
        if 'area' in x:
            a = x['area']
            if area_inside(c_area, a):
                if smallest is None:
                    smallest = x
                elif 'area' not in smallest:
                    smallest = x
                else:
                    # prefer the enclosing cond with the smallest volume
                    if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
                        smallest = x
        else:
            if smallest is None:
                smallest = x
    if smallest is None:
        return
    if 'area' in smallest:
        if smallest['area'] == c_area:
            # an exact-area cond already exists; nothing to add
            return

    out = c.copy()
    out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
    conds += [out]
|
| 615 |
+
|
| 616 |
+
def calculate_start_end_timesteps(model, conds):
    """Convert percent-based scheduling keys on conds into sigma values.

    Writes 'timestep_start'/'timestep_end' (sigmas) onto copies of conds that
    declare 'start_percent'/'end_percent' or clip hook schedule percents.
    """
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        timestep_start = None
        timestep_end = None
        # handle clip hook schedule, if needed
        if 'clip_start_percent' in x:
            # intersect the clip schedule window with the regular one
            timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
            timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
        else:
            if 'start_percent' in x:
                timestep_start = s.percent_to_sigma(x['start_percent'])
            if 'end_percent' in x:
                timestep_end = s.percent_to_sigma(x['end_percent'])

        if (timestep_start is not None) or (timestep_end is not None):
            # copy-on-write so the original cond dict stays untouched
            n = x.copy()
            if (timestep_start is not None):
                n['timestep_start'] = timestep_start
            if (timestep_end is not None):
                n['timestep_end'] = timestep_end
            conds[t] = n
|
| 640 |
+
|
| 641 |
+
def pre_run_control(model, conds):
    """Invoke pre_run on each ControlNet attached to the given conds,
    handing it a percent -> sigma conversion callback."""
    sampling = model.model_sampling
    to_sigma = lambda percent: sampling.percent_to_sigma(percent)
    for cond in conds:
        if 'control' in cond:
            cond['control'].pre_run(model, to_sigma)
|
| 649 |
+
|
| 650 |
+
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    """Mirror `name` entries (e.g. 'control'/'gligen') from conds onto uncond.

    Only acts when uncond has no `name` entries of its own; area-bearing
    conds are skipped.  Fills uncond entries round-robin via uncond_fill_func.
    """
    cond_cnets = []
    cond_other = []
    uncond_cnets = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))

    # uncond already carries its own entries: nothing to mirror
    if len(uncond_cnets) > 0:
        return

    for x in range(len(cond_cnets)):
        # pick uncond targets round-robin
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
        if name in o and o[name] is not None:
            # target already has the key set: append a modified copy instead
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond += [n]
        else:
            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
            uncond[temp[1]] = n
|
| 684 |
+
|
| 685 |
+
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
    """Run the model's extra_conds function over each cond and merge the
    resulting entries into its 'model_conds' dict.

    Supplies device/noise/width/height/prompt_type defaults plus any kwargs
    not already present on the cond.  Returns the (mutated) conds list.
    """
    for t in range(len(conds)):
        x = conds[t]
        params = x.copy()
        params["device"] = device
        params["noise"] = noise
        default_width = None
        if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
            default_width = noise.shape[3] * 8
        params["width"] = params.get("width", default_width)
        params["height"] = params.get("height", noise.shape[2] * 8)
        params["prompt_type"] = params.get("prompt_type", prompt_type)
        for k in kwargs:
            if k not in params:
                params[k] = kwargs[k]

        out = model_function(**params)
        # copy-on-write: keep the original cond dict intact
        x = x.copy()
        model_conds = x['model_conds'].copy()
        for k in out:
            model_conds[k] = out[k]
        x['model_conds'] = model_conds
        conds[t] = x
    return conds
|
| 709 |
+
|
| 710 |
+
class Sampler:
    """Base class for samplers; subclasses override sample()."""
    def sample(self):
        pass

    def max_denoise(self, model_wrap, sigmas):
        # True when the starting sigma is at (or above) the model's maximum,
        # i.e. sampling effectively starts from pure noise
        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
        sigma = float(sigmas[0])
        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
| 718 |
+
|
| 719 |
+
# Sampler names selectable by the front-end; each maps to a
# k_diffusion_sampling.sample_<name> function (resolved in ksampler()).
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
| 724 |
+
|
| 725 |
+
class KSAMPLER(Sampler):
    """Sampler backed by a k-diffusion sampling function."""
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
        self.sampler_function = sampler_function
        self.extra_options = extra_options
        self.inpaint_options = inpaint_options

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        """Scale noise, run the k-diffusion sampler through a
        KSamplerX0Inpaint wrapper, and un-scale the result."""
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        if self.inpaint_options.get("random", False): #TODO: Should this be the default?
            # deterministic alternative noise derived from seed+1
            generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
            model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
        else:
            model_k.noise = noise

        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))

        k_callback = None
        total_steps = len(sigmas) - 1
        if callback is not None:
            # adapt k-diffusion's dict callback to ComfyUI's positional one
            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
        return samples
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
    """Build a KSAMPLER for the named k-diffusion sampler.

    dpm_fast and dpm_adaptive need custom adapters (they take sigma bounds
    instead of a sigma list); everything else resolves to
    k_diffusion_sampling.sample_<name> directly.
    """
    if sampler_name == "dpm_fast":
        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                # dpm_fast can't start from sigma 0; use the last nonzero sigma
                sigma_min = sigmas[-2]
            total_steps = len(sigmas) - 1
            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
        sampler_function = dpm_fast_function
    elif sampler_name == "dpm_adaptive":
        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
        sampler_function = dpm_adaptive_function
    else:
        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))

    return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
    """Prepare all cond lists for sampling.

    Resolves areas/masks, converts percent schedules to sigmas, encodes
    model-specific conds, mirrors areas/control/gligen across cond types,
    initializes hook timesteps and pre-runs ControlNets.  Mutates and
    returns `conds`.
    """
    for k in conds:
        # shallow-copy each list so appends below don't affect the caller's lists
        conds[k] = conds[k][:]
        resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)

    for k in conds:
        calculate_start_end_timesteps(model, conds[k])

    if hasattr(model, 'extra_conds'):
        for k in conds:
            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

    #make sure each cond area has an opposite one with the same area
    for k in conds:
        for c in conds[k]:
            for kk in conds:
                if k != kk:
                    create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        for c in conds[k]:
            if 'hooks' in c:
                for hook in c['hooks'].hooks:
                    hook.initialize_timesteps(model)

    for k in conds:
        pre_run_control(model, conds[k])

    if "positive" in conds:
        positive = conds["positive"]
        for k in conds:
            if k != "positive":
                # mirror control (when flagged) and gligen entries from positive onto the other cond types
                apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
                apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])

    return conds
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
    """Merge ControlNet extra_hooks into each cond's hooks, in place.

    Conds sharing the same (control, hooks) pair receive the identical
    combined HookGroup object so they can be batched together later.
    """
    # determine which ControlNets have extra_hooks that should be combined with normal hooks
    hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {}
    for k in conds:
        for kk in conds[k]:
            if 'control' in kk:
                control: 'ControlBase' = kk['control']
                extra_hooks = control.get_extra_hooks()
                if len(extra_hooks) > 0:
                    hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
                    to_replace = hook_replacement.setdefault((control, hooks), [])
                    to_replace.append(kk)
    # if nothing to replace, do nothing
    if len(hook_replacement) == 0:
        return

    # for optimal sampling performance, common ControlNets + hook combos should have identical hooks
    # on the cond dicts
    for key, conds_to_modify in hook_replacement.items():
        control = key[0]
        hooks = key[1]
        hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks])
        # if combined hooks are not None, set as new hooks for all relevant conds
        if hooks is not None:
            for cond in conds_to_modify:
                cond['hooks'] = hooks
|
| 845 |
+
|
| 846 |
+
def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]):
    '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for
    HookGroups that have the same reference.'''
    registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None)
    # if None were registered, make sure all hooks are cleaned from conds
    if registered is None:
        for k in conds:
            for kk in conds[k]:
                kk.pop('hooks', None)
        return
    # find conds that contain hooks to be replaced - group by common HookGroup refs
    hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {}
    for k in conds:
        for kk in conds[k]:
            hooks: comfy.hooks.HookGroup = kk.get('hooks', None)
            if hooks is not None:
                if not hooks.is_subset_of(registered):
                    # same HookGroup object may appear on many conds; replace it once
                    to_replace = hook_replacement.setdefault(hooks, [])
                    to_replace.append(kk)
    # for each hook to replace, create a new proper HookGroup and assign to all common conds
    for hooks, conds_to_modify in hook_replacement.items():
        new_hooks = hooks.new_with_common_hooks(registered)
        if len(new_hooks) == 0:
            # nothing registered survived the intersection -> drop hooks entirely
            new_hooks = None
        for kk in conds_to_modify:
            kk['hooks'] = new_hooks
|
| 873 |
+
|
| 874 |
+
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
    """Count the distinct 'hooks' values present across all cond dicts.

    Conds without a 'hooks' entry all share a single None bucket, so any
    non-empty conds structure reports at least 1; an empty conds dict reports 0.
    """
    distinct_hooks = {cond.get('hooks', None)
                      for cond_list in conds.values()
                      for cond in cond_list}
    return len(distinct_hooks)
|
| 881 |
+
|
| 882 |
+
def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
    '''
    If any patches from hooks, wrappers, or callbacks have .to to be called, call it.
    '''
    if model_options is None:
        return
    to_load_options = model_options.get("to_load_options", None)
    if to_load_options is None:
        return

    # Collect the .to(...) arguments to apply, keeping device-before-dtype order.
    casts = [c for c in (device, dtype) if c is not None]
    if not casts:
        # if nothing to apply, do nothing
        return

    def _move(obj):
        # Run one object through every requested cast in sequence.
        for cast in casts:
            obj = obj.to(cast)
        return obj

    # Patches are stored as name -> list; rewrite list entries in place.
    if "patches" in to_load_options:
        for patch_list in to_load_options["patches"].values():
            for i, patch in enumerate(patch_list):
                if hasattr(patch, "to"):
                    patch_list[i] = _move(patch)

    # Replacement patches are stored as name -> dict; rewrite dict values in place.
    if "patches_replace" in to_load_options:
        for patch_map in to_load_options["patches_replace"].values():
            for key in patch_map:
                if hasattr(patch_map[key], "to"):
                    patch_map[key] = _move(patch_map[key])

    # Wrappers and callbacks share the same name -> dict -> list layout.
    for group_name in ("wrappers", "callbacks"):
        if group_name in to_load_options:
            wc: dict[str, list] = to_load_options[group_name]
            for wc_dict in wc.values():
                for wc_list in wc_dict.values():
                    for i, item in enumerate(wc_list):
                        if hasattr(item, "to"):
                            wc_list[i] = _move(item)
|
| 930 |
+
|
| 931 |
+
class CFGGuider:
    """Classifier-free-guidance noise predictor driving the sampling loop.

    Holds a ModelPatcher plus positive/negative conds and a cfg scale; sample()
    prepares the model/hooks, runs the sampler through the wrapper-executor
    chain, and restores patcher state afterwards.
    """
    def __init__(self, model_patcher: ModelPatcher):
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0

    def set_conds(self, positive, negative):
        """Store positive/negative conditioning under the standard keys."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the classifier-free guidance scale."""
        self.cfg = cfg

    def inner_set_conds(self, conds):
        # Converted conds are kept pristine; sample() copies them per run.
        for k in conds:
            self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        # Calling the guider is equivalent to predict_noise (used by samplers).
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run one CFG-combined model evaluation at the given timestep."""
        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
        """Process conds/latents and execute the sampler (innermost stage)."""
        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        # Clone model_options so sample_sigmas doesn't leak into the stored options.
        extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
        extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
        extra_args = {"model_options": extra_model_options, "seed": seed}

        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            sampler.sample,
            sampler,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
        )
        samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Load models, move tensors to the load device, run inner_sample, then clean up."""
        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
        device = self.model_patcher.load_device

        if denoise_mask is not None:
            denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)
        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())

        try:
            self.model_patcher.pre_run()
            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # cleanup() must run even if sampling raised.
            self.model_patcher.cleanup()

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.loaded_models
        return output

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        """Top-level entry: copy conds, set up hooks/wrappers, sample, restore state."""
        if sigmas.shape[-1] == 0:
            # No steps requested: return the input latent untouched.
            return latent_image

        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
        preprocess_conds_hooks(self.conds)

        try:
            orig_model_options = self.model_options
            self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
            # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
            orig_hook_mode = self.model_patcher.hook_mode
            if get_total_hook_groups_in_conds(self.conds) <= 1:
                self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
            comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
            filter_registered_hooks_on_conds(self.conds, self.model_options)
            executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
                self.outer_sample,
                self,
                comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
            )
            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            # Move any casted patches back to the offload device and restore patcher state.
            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
            self.model_options = orig_model_options
            self.model_patcher.hook_mode = orig_hook_mode
            self.model_patcher.restore_hook_patches()

        del self.conds
        return output
|
| 1028 |
+
|
| 1029 |
+
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Convenience wrapper: build a CFGGuider from the model/conds/cfg and sample.

    NOTE(review): the `device` and `model_options` parameters are not used in
    this body (the guider takes its options from the model patcher); presumably
    kept for backward compatibility with existing callers — confirm.
    """
    cfg_guider = CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
| 1035 |
+
|
| 1036 |
+
# All selectable sampler names: every k-diffusion sampler plus the
# non-KSAMPLER_NAMES entries handled specially in sampler_object().
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
| 1037 |
+
|
| 1038 |
+
class SchedulerHandler(NamedTuple):
    """Pairs a sigma-schedule function with the calling convention it expects."""
    handler: Callable[..., torch.Tensor]
    # Boolean indicates whether to call the handler like:
    #   scheduler_function(model_sampling, steps) or
    #   scheduler_function(n, sigma_min: float, sigma_max: float)
    use_ms: bool = True
|
| 1045 |
+
# Registry of scheduler name -> SchedulerHandler. Entries with use_ms=False
# are called as handler(n, sigma_min, sigma_max); see calculate_sigmas().
SCHEDULER_HANDLERS = {
    "simple": SchedulerHandler(simple_scheduler),
    "sgm_uniform": SchedulerHandler(partial(normal_scheduler, sgm=True)),
    "karras": SchedulerHandler(k_diffusion_sampling.get_sigmas_karras, use_ms=False),
    "exponential": SchedulerHandler(k_diffusion_sampling.get_sigmas_exponential, use_ms=False),
    "ddim_uniform": SchedulerHandler(ddim_scheduler),
    "beta": SchedulerHandler(beta_scheduler),
    "normal": SchedulerHandler(normal_scheduler),
    "linear_quadratic": SchedulerHandler(linear_quadratic_schedule),
    "kl_optimal": SchedulerHandler(kl_optimal_scheduler, use_ms=False),
}
# Ordered list of scheduler names; index 0 ("simple") is the fallback default
# used by KSampler.__init__ when an unknown scheduler is requested.
SCHEDULER_NAMES = list(SCHEDULER_HANDLERS)
|
| 1058 |
+
def calculate_sigmas(model_sampling: object, scheduler_name: str, steps: int) -> torch.Tensor:
    """Compute the sigma schedule for `steps` steps with the named scheduler.

    Dispatches through SCHEDULER_HANDLERS: use_ms handlers get
    (model_sampling, steps); the rest get (n, sigma_min, sigma_max) pulled
    from model_sampling.

    Raises:
        ValueError: if scheduler_name is not registered.
    """
    handler = SCHEDULER_HANDLERS.get(scheduler_name)
    if handler is None:
        err = f"error invalid scheduler {scheduler_name}"
        logging.error(err)
        raise ValueError(err)
    if handler.use_ms:
        return handler.handler(model_sampling, steps)
    return handler.handler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
| 1068 |
+
def sampler_object(name):
    """Instantiate the sampler object for a sampler name.

    The uni_pc variants get dedicated KSAMPLER wrappers, "ddim" is implemented
    as euler with inpaint_options={"random": True}, and every other name goes
    through the generic ksampler() factory.
    """
    if name == "uni_pc":
        return KSAMPLER(uni_pc.sample_unipc)
    if name == "uni_pc_bh2":
        return KSAMPLER(uni_pc.sample_unipc_bh2)
    if name == "ddim":
        return ksampler("euler", inpaint_options={"random": True})
    return ksampler(name)
|
| 1079 |
+
class KSampler:
    """High-level sampler: resolves sampler/scheduler names, precomputes sigmas
    (honoring denoise strength), and delegates the actual run to sample().
    """
    SCHEDULERS = SCHEDULER_NAMES
    SAMPLERS = SAMPLER_NAMES
    # Samplers that need one extra step scheduled so the penultimate sigma can
    # be dropped afterwards (see calculate_sigmas below).
    DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))

    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
        self.model = model
        self.device = device
        # Unknown names silently fall back to the first registered entry.
        if scheduler not in self.SCHEDULERS:
            scheduler = self.SCHEDULERS[0]
        if sampler not in self.SAMPLERS:
            sampler = self.SAMPLERS[0]
        self.scheduler = scheduler
        self.sampler = sampler
        self.set_steps(steps, denoise)
        self.denoise = denoise
        self.model_options = model_options

    def calculate_sigmas(self, steps):
        """Build the sigma schedule, discarding the penultimate sigma for
        samplers that require it."""
        sigmas = None

        discard_penultimate_sigma = False
        if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
            steps += 1
            discard_penultimate_sigma = True

        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

        if discard_penultimate_sigma:
            # Keep everything but the second-to-last sigma.
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
        return sigmas

    def set_steps(self, steps, denoise=None):
        """Recompute self.sigmas for the given step count and denoise strength.

        denoise > ~1.0 (or None) uses the full schedule; otherwise a longer
        schedule is computed and only its last (steps + 1) sigmas are kept,
        i.e. sampling starts partway down the noise schedule.
        """
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
            if denoise <= 0.0:
                # Nothing to denoise: empty schedule.
                self.sigmas = torch.FloatTensor([])
            else:
                new_steps = int(steps/denoise)
                sigmas = self.calculate_sigmas(new_steps).to(self.device)
                self.sigmas = sigmas[-(steps + 1):]

    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
        """Slice the sigma schedule to [start_step, last_step] and run sampling."""
        if sigmas is None:
            sigmas = self.sigmas

        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[:last_step + 1]
            if force_full_denoise:
                # Force the truncated schedule to end at sigma 0.
                sigmas[-1] = 0

        if start_step is not None:
            if start_step < (len(sigmas) - 1):
                sigmas = sigmas[start_step:]
            else:
                # start_step is at/past the end of the schedule: nothing to do.
                if latent_image is not None:
                    return latent_image
                else:
                    return torch.zeros_like(noise)

        sampler = sampler_object(self.sampler)

        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
ComfyUI/comfy/sd1_clip.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from transformers import CLIPTokenizer
|
| 4 |
+
import comfy.ops
|
| 5 |
+
import torch
|
| 6 |
+
import traceback
|
| 7 |
+
import zipfile
|
| 8 |
+
from . import model_management
|
| 9 |
+
import comfy.clip_model
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
import numbers
|
| 13 |
+
import re
|
| 14 |
+
|
| 15 |
+
def gen_empty_tokens(special_tokens, length):
    """Build a placeholder token sequence of the given length.

    Starts with the start and end tokens (when present in special_tokens, in
    that order, adjacent to each other) and fills the remainder with the pad
    token.
    """
    pad_token = special_tokens.get("pad")
    prefix = [special_tokens[key]
              for key in ("start", "end")
              if special_tokens.get(key) is not None]
    return prefix + [pad_token] * (length - len(prefix))
+
|
| 27 |
+
class ClipTokenWeightEncoder:
    """Mixin that encodes (token, weight) sequences and applies prompt weights.

    Expects the host class to provide encode(), special_tokens, and optionally
    gen_empty_tokens (falls back to the module-level helper).
    """
    def encode_token_weights(self, token_weight_pairs):
        """Encode batched (token, weight) pairs into embeddings.

        Returns (embeddings, first_pooled) — plus an extra dict when encode()
        yields one (e.g. attention_mask). Weights != 1.0 are applied by
        interpolating each token's embedding toward the empty-prompt embedding.
        """
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        # An empty-prompt row is appended as the interpolation baseline for
        # weighting (also ensures encode() gets at least one row).
        if has_weights or sections == 0:
            if hasattr(self, "gen_empty_tokens"):
                to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
            else:
                to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]

        if pooled is not None:
            # Only the first section's pooled output is returned.
            first_pooled = pooled[0:1].to(model_management.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k:k+1]
            if has_weights:
                # Blend each weighted token toward the empty-prompt row (out[-1]).
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if (len(output) == 0):
            # No sections at all: return the empty-prompt embedding.
            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
        else:
            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)

        if len(o) > 2:
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    # Drop the appended empty row and flatten to one batch row.
                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
                extra[k] = v

            r = r + (extra,)
        return r
+
|
| 81 |
+
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """CLIP-style text encoder wrapper used for SD1.x-style text conditioning.

    Wraps a transformer text model (default comfy.clip_model.CLIPTextModel),
    handles tokenization output with inline embeddings, layer selection, and
    optional attention masking.
    """
    # Which hidden representation forward() returns.
    LAYERS = [
        "last",
        "pooled",
        "hidden",
        "all"
    ]
    def __init__(self, device="cpu", max_length=77,
                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                 return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS

        if textmodel_json_config is None:
            # Default to the SD1 CLIP-L config shipped next to this module.
            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

        if "model_name" not in model_options:
            model_options = {**model_options, "model_name": "clip_l"}

        if isinstance(textmodel_json_config, dict):
            config = textmodel_json_config
        else:
            with open(textmodel_json_config) as f:
                config = json.load(f)

        # Per-text-encoder config overrides, keyed "<model_name>_model_config".
        te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
        for k, v in te_model_options.items():
            config[k] = v

        operations = model_options.get("custom_operations", None)
        scaled_fp8 = None

        if operations is None:
            scaled_fp8 = model_options.get("scaled_fp8", None)
            if scaled_fp8 is not None:
                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
            else:
                operations = comfy.ops.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        if scaled_fp8 is not None:
            # Marker parameter recording the scaled-fp8 dtype on the transformer.
            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        # Snapshot for reset_clip_options().
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    def freeze(self):
        """Put the transformer in eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options):
        """Select which layer to return and whether to use the projected pooled output."""
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
        if self.layer == "all":
            pass
        elif layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self):
        """Restore the layer/pooled options captured at construction time."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]

    def process_tokens(self, tokens, device):
        """Turn mixed token/embedding sequences into input embeddings.

        Each sequence may contain integer token ids and inline embedding
        entries (tensors or dicts). Returns (embeds, attention_masks,
        num_tokens); positions after the end/pad token are masked out.
        """
        end_token = self.special_tokens.get("end", None)
        if end_token is None:
            # No end token: detect end-of-sequence via the pad token instead.
            cmp_token = self.special_tokens.get("pad", -1)
        else:
            cmp_token = end_token

        embeds_out = []
        attention_masks = []
        num_tokens = []

        for x in tokens:
            attention_mask = []
            tokens_temp = []
            other_embeds = []
            eos = False
            index = 0
            for y in x:
                if isinstance(y, numbers.Integral):
                    if eos:
                        attention_mask.append(0)
                    else:
                        attention_mask.append(1)
                    token = int(y)
                    tokens_temp += [token]
                    if not eos and token == cmp_token:
                        if end_token is None:
                            # Pad-as-end: the pad token itself is not attended.
                            attention_mask[-1] = 0
                        eos = True
                else:
                    # Non-integer entry: inline embedding, spliced in below.
                    other_embeds.append((index, y))
                index += 1

            tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
            tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
            # `index` now tracks the position shift caused by spliced embeddings.
            index = 0
            pad_extra = 0
            for o in other_embeds:
                emb = o[1]
                if torch.is_tensor(emb):
                    emb = {"type": "embedding", "data": emb}

                emb_type = emb.get("type", None)
                if emb_type == "embedding":
                    emb = emb.get("data", None)
                else:
                    if hasattr(self.transformer, "preprocess_embed"):
                        emb = self.transformer.preprocess_embed(emb, device=device)
                    else:
                        emb = None

                if emb is None:
                    # Unusable entry: its original slot is gone, shift left.
                    index += -1
                    continue

                ind = index + o[0]
                emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
                emb_shape = emb.shape[1]
                if emb.shape[-1] == tokens_embed.shape[-1]:
                    tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
                    attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
                    index += emb_shape - 1
                else:
                    index += -1
                    pad_extra += emb_shape
                    logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))

            if pad_extra > 0:
                # Keep sequence lengths consistent by padding for skipped embeddings.
                padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
                tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
                attention_mask = attention_mask + [0] * pad_extra

            embeds_out.append(tokens_embed)
            attention_masks.append(attention_mask)
            num_tokens.append(sum(attention_mask))

        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens

    def forward(self, tokens):
        """Encode tokens; returns (z, pooled_output[, extra])."""
        device = self.transformer.get_input_embeddings().weight.device
        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        if self.layer == "all":
            intermediate_output = "all"
        else:
            intermediate_output = self.layer_idx

        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            # Zero embeddings at masked (non-attended) positions.
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            # outputs[2] is the projected pooled output, outputs[3] (when
            # present) the unprojected one.
            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens):
        """Alias for forward(), used by ClipTokenWeightEncoder."""
        return self(tokens)

    def load_sd(self, sd):
        """Load a state dict into the wrapped transformer (non-strict)."""
        return self.transformer.load_state_dict(sd, strict=False)
|
| 293 |
+
def parse_parentheses(string):
    """Split a prompt string into top-level segments.

    Each balanced top-level "(...)" group (including any nested parens inside
    it) becomes its own segment; text between groups becomes plain segments.
    Example: "a(b)c" -> ["a", "(b)", "c"].
    """
    segments = []
    buffer = ""
    depth = 0
    for ch in string:
        if ch == "(":
            if depth == 0:
                # New top-level group: flush any accumulated plain text first.
                if buffer:
                    segments.append(buffer)
                buffer = "("
            else:
                buffer += ch
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                segments.append(buffer + ")")
                buffer = ""
            else:
                buffer += ch
        else:
            buffer += ch
    if buffer:
        segments.append(buffer)
    return segments
+
|
| 321 |
+
def token_weights(string, current_weight):
|
| 322 |
+
a = parse_parentheses(string)
|
| 323 |
+
out = []
|
| 324 |
+
for x in a:
|
| 325 |
+
weight = current_weight
|
| 326 |
+
if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
|
| 327 |
+
x = x[1:-1]
|
| 328 |
+
xx = x.rfind(":")
|
| 329 |
+
weight *= 1.1
|
| 330 |
+
if xx > 0:
|
| 331 |
+
try:
|
| 332 |
+
weight = float(x[xx+1:])
|
| 333 |
+
x = x[:xx]
|
| 334 |
+
except:
|
| 335 |
+
pass
|
| 336 |
+
out += token_weights(x, weight)
|
| 337 |
+
else:
|
| 338 |
+
out += [(x, current_weight)]
|
| 339 |
+
return out
|
| 340 |
+
|
| 341 |
+
def escape_important(text):
|
| 342 |
+
text = text.replace("\\)", "\0\1")
|
| 343 |
+
text = text.replace("\\(", "\0\2")
|
| 344 |
+
return text
|
| 345 |
+
|
| 346 |
+
def unescape_important(text):
|
| 347 |
+
text = text.replace("\0\1", ")")
|
| 348 |
+
text = text.replace("\0\2", "(")
|
| 349 |
+
return text
|
| 350 |
+
|
| 351 |
+
def safe_load_embed_zip(embed_path):
|
| 352 |
+
with zipfile.ZipFile(embed_path) as myzip:
|
| 353 |
+
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
|
| 354 |
+
names.reverse()
|
| 355 |
+
for n in names:
|
| 356 |
+
with myzip.open(n) as myfile:
|
| 357 |
+
data = myfile.read()
|
| 358 |
+
number = len(data) // 4
|
| 359 |
+
length_embed = 1024 #sd2.x
|
| 360 |
+
if number < 768:
|
| 361 |
+
continue
|
| 362 |
+
if number % 768 == 0:
|
| 363 |
+
length_embed = 768 #sd1.x
|
| 364 |
+
num_embeds = number // length_embed
|
| 365 |
+
embed = torch.frombuffer(data, dtype=torch.float)
|
| 366 |
+
out = embed.reshape((num_embeds, length_embed)).clone()
|
| 367 |
+
del embed
|
| 368 |
+
return out
|
| 369 |
+
|
| 370 |
+
def expand_directory_list(directories):
|
| 371 |
+
dirs = set()
|
| 372 |
+
for x in directories:
|
| 373 |
+
dirs.add(x)
|
| 374 |
+
for root, subdir, file in os.walk(x, followlinks=True):
|
| 375 |
+
dirs.add(root)
|
| 376 |
+
return list(dirs)
|
| 377 |
+
|
| 378 |
+
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
|
| 379 |
+
out_list = []
|
| 380 |
+
for k in embed:
|
| 381 |
+
if k.startswith(prefix) and k.endswith(suffix):
|
| 382 |
+
out_list.append(embed[k])
|
| 383 |
+
if len(out_list) == 0:
|
| 384 |
+
return None
|
| 385 |
+
|
| 386 |
+
return torch.cat(out_list, dim=0)
|
| 387 |
+
|
| 388 |
+
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
| 389 |
+
if isinstance(embedding_directory, str):
|
| 390 |
+
embedding_directory = [embedding_directory]
|
| 391 |
+
|
| 392 |
+
embedding_directory = expand_directory_list(embedding_directory)
|
| 393 |
+
|
| 394 |
+
valid_file = None
|
| 395 |
+
for embed_dir in embedding_directory:
|
| 396 |
+
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
|
| 397 |
+
embed_dir = os.path.abspath(embed_dir)
|
| 398 |
+
try:
|
| 399 |
+
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
|
| 400 |
+
continue
|
| 401 |
+
except:
|
| 402 |
+
continue
|
| 403 |
+
if not os.path.isfile(embed_path):
|
| 404 |
+
extensions = ['.safetensors', '.pt', '.bin']
|
| 405 |
+
for x in extensions:
|
| 406 |
+
t = embed_path + x
|
| 407 |
+
if os.path.isfile(t):
|
| 408 |
+
valid_file = t
|
| 409 |
+
break
|
| 410 |
+
else:
|
| 411 |
+
valid_file = embed_path
|
| 412 |
+
if valid_file is not None:
|
| 413 |
+
break
|
| 414 |
+
|
| 415 |
+
if valid_file is None:
|
| 416 |
+
return None
|
| 417 |
+
|
| 418 |
+
embed_path = valid_file
|
| 419 |
+
|
| 420 |
+
embed_out = None
|
| 421 |
+
|
| 422 |
+
try:
|
| 423 |
+
if embed_path.lower().endswith(".safetensors"):
|
| 424 |
+
import safetensors.torch
|
| 425 |
+
embed = safetensors.torch.load_file(embed_path, device="cpu")
|
| 426 |
+
else:
|
| 427 |
+
try:
|
| 428 |
+
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
|
| 429 |
+
except:
|
| 430 |
+
embed_out = safe_load_embed_zip(embed_path)
|
| 431 |
+
except Exception:
|
| 432 |
+
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
| 433 |
+
return None
|
| 434 |
+
|
| 435 |
+
if embed_out is None:
|
| 436 |
+
if 'string_to_param' in embed:
|
| 437 |
+
values = embed['string_to_param'].values()
|
| 438 |
+
embed_out = next(iter(values))
|
| 439 |
+
elif isinstance(embed, list):
|
| 440 |
+
out_list = []
|
| 441 |
+
for x in range(len(embed)):
|
| 442 |
+
for k in embed[x]:
|
| 443 |
+
t = embed[x][k]
|
| 444 |
+
if t.shape[-1] != embedding_size:
|
| 445 |
+
continue
|
| 446 |
+
out_list.append(t.reshape(-1, t.shape[-1]))
|
| 447 |
+
embed_out = torch.cat(out_list, dim=0)
|
| 448 |
+
elif embed_key is not None and embed_key in embed:
|
| 449 |
+
embed_out = embed[embed_key]
|
| 450 |
+
else:
|
| 451 |
+
embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
|
| 452 |
+
if embed_out is None:
|
| 453 |
+
embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
|
| 454 |
+
if embed_out is None:
|
| 455 |
+
values = embed.values()
|
| 456 |
+
embed_out = next(iter(values))
|
| 457 |
+
return embed_out
|
| 458 |
+
|
| 459 |
+
class SDTokenizer:
|
| 460 |
+
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
| 461 |
+
if tokenizer_path is None:
|
| 462 |
+
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
| 463 |
+
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
| 464 |
+
self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
|
| 465 |
+
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
| 466 |
+
self.end_token = None
|
| 467 |
+
self.min_padding = min_padding
|
| 468 |
+
|
| 469 |
+
empty = self.tokenizer('')["input_ids"]
|
| 470 |
+
self.tokenizer_adds_end_token = has_end_token
|
| 471 |
+
if has_start_token:
|
| 472 |
+
self.tokens_start = 1
|
| 473 |
+
self.start_token = empty[0]
|
| 474 |
+
if end_token is not None:
|
| 475 |
+
self.end_token = end_token
|
| 476 |
+
else:
|
| 477 |
+
if has_end_token:
|
| 478 |
+
self.end_token = empty[1]
|
| 479 |
+
else:
|
| 480 |
+
self.tokens_start = 0
|
| 481 |
+
self.start_token = None
|
| 482 |
+
if end_token is not None:
|
| 483 |
+
self.end_token = end_token
|
| 484 |
+
else:
|
| 485 |
+
if has_end_token:
|
| 486 |
+
self.end_token = empty[0]
|
| 487 |
+
|
| 488 |
+
if pad_token is not None:
|
| 489 |
+
self.pad_token = pad_token
|
| 490 |
+
elif pad_with_end:
|
| 491 |
+
self.pad_token = self.end_token
|
| 492 |
+
else:
|
| 493 |
+
self.pad_token = 0
|
| 494 |
+
|
| 495 |
+
self.pad_with_end = pad_with_end
|
| 496 |
+
self.pad_to_max_length = pad_to_max_length
|
| 497 |
+
|
| 498 |
+
vocab = self.tokenizer.get_vocab()
|
| 499 |
+
self.inv_vocab = {v: k for k, v in vocab.items()}
|
| 500 |
+
self.embedding_directory = embedding_directory
|
| 501 |
+
self.max_word_length = 8
|
| 502 |
+
self.embedding_identifier = "embedding:"
|
| 503 |
+
self.embedding_size = embedding_size
|
| 504 |
+
self.embedding_key = embedding_key
|
| 505 |
+
|
| 506 |
+
def _try_get_embedding(self, embedding_name:str):
|
| 507 |
+
'''
|
| 508 |
+
Takes a potential embedding name and tries to retrieve it.
|
| 509 |
+
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
| 510 |
+
'''
|
| 511 |
+
split_embed = embedding_name.split()
|
| 512 |
+
embedding_name = split_embed[0]
|
| 513 |
+
leftover = ' '.join(split_embed[1:])
|
| 514 |
+
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
| 515 |
+
if embed is None:
|
| 516 |
+
stripped = embedding_name.strip(',')
|
| 517 |
+
if len(stripped) < len(embedding_name):
|
| 518 |
+
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
| 519 |
+
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
| 520 |
+
return (embed, leftover)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
| 524 |
+
'''
|
| 525 |
+
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
| 526 |
+
Tokens can both be integer tokens and pre computed CLIP tensors.
|
| 527 |
+
Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
|
| 528 |
+
Returned list has the dimensions NxM where M is the input size of CLIP
|
| 529 |
+
'''
|
| 530 |
+
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
|
| 531 |
+
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
| 532 |
+
|
| 533 |
+
text = escape_important(text)
|
| 534 |
+
parsed_weights = token_weights(text, 1.0)
|
| 535 |
+
|
| 536 |
+
# tokenize words
|
| 537 |
+
tokens = []
|
| 538 |
+
for weighted_segment, weight in parsed_weights:
|
| 539 |
+
to_tokenize = unescape_important(weighted_segment)
|
| 540 |
+
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
| 541 |
+
to_tokenize = [split[0]]
|
| 542 |
+
for i in range(1, len(split)):
|
| 543 |
+
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
| 544 |
+
|
| 545 |
+
to_tokenize = [x for x in to_tokenize if x != ""]
|
| 546 |
+
for word in to_tokenize:
|
| 547 |
+
# if we find an embedding, deal with the embedding
|
| 548 |
+
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
| 549 |
+
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
| 550 |
+
embed, leftover = self._try_get_embedding(embedding_name)
|
| 551 |
+
if embed is None:
|
| 552 |
+
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
| 553 |
+
else:
|
| 554 |
+
if len(embed.shape) == 1:
|
| 555 |
+
tokens.append([(embed, weight)])
|
| 556 |
+
else:
|
| 557 |
+
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
|
| 558 |
+
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
|
| 559 |
+
if leftover != "":
|
| 560 |
+
word = leftover
|
| 561 |
+
else:
|
| 562 |
+
continue
|
| 563 |
+
end = 999999999999
|
| 564 |
+
if self.tokenizer_adds_end_token:
|
| 565 |
+
end = -1
|
| 566 |
+
#parse word
|
| 567 |
+
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:end]])
|
| 568 |
+
|
| 569 |
+
#reshape token array to CLIP input size
|
| 570 |
+
batched_tokens = []
|
| 571 |
+
batch = []
|
| 572 |
+
if self.start_token is not None:
|
| 573 |
+
batch.append((self.start_token, 1.0, 0))
|
| 574 |
+
batched_tokens.append(batch)
|
| 575 |
+
for i, t_group in enumerate(tokens):
|
| 576 |
+
#determine if we're going to try and keep the tokens in a single batch
|
| 577 |
+
is_large = len(t_group) >= self.max_word_length
|
| 578 |
+
if self.end_token is not None:
|
| 579 |
+
has_end_token = 1
|
| 580 |
+
else:
|
| 581 |
+
has_end_token = 0
|
| 582 |
+
|
| 583 |
+
while len(t_group) > 0:
|
| 584 |
+
if len(t_group) + len(batch) > self.max_length - has_end_token:
|
| 585 |
+
remaining_length = self.max_length - len(batch) - has_end_token
|
| 586 |
+
#break word in two and add end token
|
| 587 |
+
if is_large:
|
| 588 |
+
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
| 589 |
+
if self.end_token is not None:
|
| 590 |
+
batch.append((self.end_token, 1.0, 0))
|
| 591 |
+
t_group = t_group[remaining_length:]
|
| 592 |
+
#add end token and pad
|
| 593 |
+
else:
|
| 594 |
+
if self.end_token is not None:
|
| 595 |
+
batch.append((self.end_token, 1.0, 0))
|
| 596 |
+
if self.pad_to_max_length:
|
| 597 |
+
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
| 598 |
+
#start new batch
|
| 599 |
+
batch = []
|
| 600 |
+
if self.start_token is not None:
|
| 601 |
+
batch.append((self.start_token, 1.0, 0))
|
| 602 |
+
batched_tokens.append(batch)
|
| 603 |
+
else:
|
| 604 |
+
batch.extend([(t,w,i+1) for t,w in t_group])
|
| 605 |
+
t_group = []
|
| 606 |
+
|
| 607 |
+
#fill last batch
|
| 608 |
+
if self.end_token is not None:
|
| 609 |
+
batch.append((self.end_token, 1.0, 0))
|
| 610 |
+
if min_padding is not None:
|
| 611 |
+
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
| 612 |
+
if self.pad_to_max_length and len(batch) < self.max_length:
|
| 613 |
+
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
| 614 |
+
if min_length is not None and len(batch) < min_length:
|
| 615 |
+
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
| 616 |
+
|
| 617 |
+
if not return_word_ids:
|
| 618 |
+
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
| 619 |
+
|
| 620 |
+
return batched_tokens
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def untokenize(self, token_weight_pair):
|
| 624 |
+
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
| 625 |
+
|
| 626 |
+
def state_dict(self):
|
| 627 |
+
return {}
|
| 628 |
+
|
| 629 |
+
class SD1Tokenizer:
|
| 630 |
+
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
| 631 |
+
if name is not None:
|
| 632 |
+
self.clip_name = name
|
| 633 |
+
self.clip = "{}".format(self.clip_name)
|
| 634 |
+
else:
|
| 635 |
+
self.clip_name = clip_name
|
| 636 |
+
self.clip = "clip_{}".format(self.clip_name)
|
| 637 |
+
|
| 638 |
+
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
|
| 639 |
+
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
|
| 640 |
+
|
| 641 |
+
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
| 642 |
+
out = {}
|
| 643 |
+
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
|
| 644 |
+
return out
|
| 645 |
+
|
| 646 |
+
def untokenize(self, token_weight_pair):
|
| 647 |
+
return getattr(self, self.clip).untokenize(token_weight_pair)
|
| 648 |
+
|
| 649 |
+
def state_dict(self):
|
| 650 |
+
return getattr(self, self.clip).state_dict()
|
| 651 |
+
|
| 652 |
+
class SD1CheckpointClipModel(SDClipModel):
|
| 653 |
+
def __init__(self, device="cpu", dtype=None, model_options={}):
|
| 654 |
+
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
| 655 |
+
|
| 656 |
+
class SD1ClipModel(torch.nn.Module):
|
| 657 |
+
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, name=None, **kwargs):
|
| 658 |
+
super().__init__()
|
| 659 |
+
|
| 660 |
+
if name is not None:
|
| 661 |
+
self.clip_name = name
|
| 662 |
+
self.clip = "{}".format(self.clip_name)
|
| 663 |
+
else:
|
| 664 |
+
self.clip_name = clip_name
|
| 665 |
+
self.clip = "clip_{}".format(self.clip_name)
|
| 666 |
+
|
| 667 |
+
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
|
| 668 |
+
model_options = {**model_options, "model_name": self.clip}
|
| 669 |
+
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
|
| 670 |
+
|
| 671 |
+
self.dtypes = set()
|
| 672 |
+
if dtype is not None:
|
| 673 |
+
self.dtypes.add(dtype)
|
| 674 |
+
|
| 675 |
+
def set_clip_options(self, options):
|
| 676 |
+
getattr(self, self.clip).set_clip_options(options)
|
| 677 |
+
|
| 678 |
+
def reset_clip_options(self):
|
| 679 |
+
getattr(self, self.clip).reset_clip_options()
|
| 680 |
+
|
| 681 |
+
def encode_token_weights(self, token_weight_pairs):
|
| 682 |
+
token_weight_pairs = token_weight_pairs[self.clip_name]
|
| 683 |
+
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
| 684 |
+
return out
|
| 685 |
+
|
| 686 |
+
def load_sd(self, sd):
|
| 687 |
+
return getattr(self, self.clip).load_sd(sd)
|
ComfyUI/comfy/sd1_clip_config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPTextModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"eos_token_id": 49407,
|
| 10 |
+
"hidden_act": "quick_gelu",
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"initializer_factor": 1.0,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 3072,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 77,
|
| 17 |
+
"model_type": "clip_text_model",
|
| 18 |
+
"num_attention_heads": 12,
|
| 19 |
+
"num_hidden_layers": 12,
|
| 20 |
+
"pad_token_id": 1,
|
| 21 |
+
"projection_dim": 768,
|
| 22 |
+
"torch_dtype": "float32",
|
| 23 |
+
"transformers_version": "4.24.0",
|
| 24 |
+
"vocab_size": 49408
|
| 25 |
+
}
|
ComfyUI/comfy/sd1_tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ComfyUI/comfy/sd1_tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"__type": "AddedToken",
|
| 5 |
+
"content": "<|startoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false
|
| 10 |
+
},
|
| 11 |
+
"do_lower_case": true,
|
| 12 |
+
"eos_token": {
|
| 13 |
+
"__type": "AddedToken",
|
| 14 |
+
"content": "<|endoftext|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": true,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false
|
| 19 |
+
},
|
| 20 |
+
"errors": "replace",
|
| 21 |
+
"model_max_length": 8192,
|
| 22 |
+
"name_or_path": "openai/clip-vit-large-patch14",
|
| 23 |
+
"pad_token": "<|endoftext|>",
|
| 24 |
+
"special_tokens_map_file": "./special_tokens_map.json",
|
| 25 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"__type": "AddedToken",
|
| 28 |
+
"content": "<|endoftext|>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false
|
| 33 |
+
}
|
| 34 |
+
}
|
ComfyUI/comfy/sd1_tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ComfyUI/comfy/supported_models.py
ADDED
|
@@ -0,0 +1,1235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import model_base
|
| 3 |
+
from . import utils
|
| 4 |
+
|
| 5 |
+
from . import sd1_clip
|
| 6 |
+
from . import sdxl_clip
|
| 7 |
+
import comfy.text_encoders.sd2_clip
|
| 8 |
+
import comfy.text_encoders.sd3_clip
|
| 9 |
+
import comfy.text_encoders.sa_t5
|
| 10 |
+
import comfy.text_encoders.aura_t5
|
| 11 |
+
import comfy.text_encoders.pixart_t5
|
| 12 |
+
import comfy.text_encoders.hydit
|
| 13 |
+
import comfy.text_encoders.flux
|
| 14 |
+
import comfy.text_encoders.genmo
|
| 15 |
+
import comfy.text_encoders.lt
|
| 16 |
+
import comfy.text_encoders.hunyuan_video
|
| 17 |
+
import comfy.text_encoders.cosmos
|
| 18 |
+
import comfy.text_encoders.lumina2
|
| 19 |
+
import comfy.text_encoders.wan
|
| 20 |
+
import comfy.text_encoders.ace
|
| 21 |
+
import comfy.text_encoders.omnigen2
|
| 22 |
+
|
| 23 |
+
from . import supported_models_base
|
| 24 |
+
from . import latent_formats
|
| 25 |
+
|
| 26 |
+
from . import diffusers_convert
|
| 27 |
+
|
| 28 |
+
class SD15(supported_models_base.BASE):
|
| 29 |
+
unet_config = {
|
| 30 |
+
"context_dim": 768,
|
| 31 |
+
"model_channels": 320,
|
| 32 |
+
"use_linear_in_transformer": False,
|
| 33 |
+
"adm_in_channels": None,
|
| 34 |
+
"use_temporal_attention": False,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
unet_extra_config = {
|
| 38 |
+
"num_heads": 8,
|
| 39 |
+
"num_head_channels": -1,
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
latent_format = latent_formats.SD15
|
| 43 |
+
memory_usage_factor = 1.0
|
| 44 |
+
|
| 45 |
+
def process_clip_state_dict(self, state_dict):
|
| 46 |
+
k = list(state_dict.keys())
|
| 47 |
+
for x in k:
|
| 48 |
+
if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
|
| 49 |
+
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
|
| 50 |
+
state_dict[y] = state_dict.pop(x)
|
| 51 |
+
|
| 52 |
+
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
|
| 53 |
+
ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
|
| 54 |
+
if ids.dtype == torch.float32:
|
| 55 |
+
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
| 56 |
+
|
| 57 |
+
replace_prefix = {}
|
| 58 |
+
replace_prefix["cond_stage_model."] = "clip_l."
|
| 59 |
+
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
| 60 |
+
return state_dict
|
| 61 |
+
|
| 62 |
+
def process_clip_state_dict_for_saving(self, state_dict):
|
| 63 |
+
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
| 64 |
+
for p in pop_keys:
|
| 65 |
+
if p in state_dict:
|
| 66 |
+
state_dict.pop(p)
|
| 67 |
+
|
| 68 |
+
replace_prefix = {"clip_l.": "cond_stage_model."}
|
| 69 |
+
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
| 70 |
+
|
| 71 |
+
def clip_target(self, state_dict={}):
|
| 72 |
+
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
| 73 |
+
|
| 74 |
+
class SD20(supported_models_base.BASE):
|
| 75 |
+
unet_config = {
|
| 76 |
+
"context_dim": 1024,
|
| 77 |
+
"model_channels": 320,
|
| 78 |
+
"use_linear_in_transformer": True,
|
| 79 |
+
"adm_in_channels": None,
|
| 80 |
+
"use_temporal_attention": False,
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
unet_extra_config = {
|
| 84 |
+
"num_heads": -1,
|
| 85 |
+
"num_head_channels": 64,
|
| 86 |
+
"attn_precision": torch.float32,
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
latent_format = latent_formats.SD15
|
| 90 |
+
memory_usage_factor = 1.0
|
| 91 |
+
|
| 92 |
+
def model_type(self, state_dict, prefix=""):
|
| 93 |
+
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
| 94 |
+
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
| 95 |
+
out = state_dict.get(k, None)
|
| 96 |
+
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
| 97 |
+
return model_base.ModelType.V_PREDICTION
|
| 98 |
+
return model_base.ModelType.EPS
|
| 99 |
+
|
| 100 |
+
def process_clip_state_dict(self, state_dict):
|
| 101 |
+
replace_prefix = {}
|
| 102 |
+
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
| 103 |
+
replace_prefix["cond_stage_model.model."] = "clip_h."
|
| 104 |
+
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
| 105 |
+
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
| 106 |
+
return state_dict
|
| 107 |
+
|
| 108 |
+
def process_clip_state_dict_for_saving(self, state_dict):
|
| 109 |
+
replace_prefix = {}
|
| 110 |
+
replace_prefix["clip_h"] = "cond_stage_model.model"
|
| 111 |
+
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
| 112 |
+
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
| 113 |
+
return state_dict
|
| 114 |
+
|
| 115 |
+
def clip_target(self, state_dict={}):
|
| 116 |
+
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
|
| 117 |
+
|
| 118 |
+
class SD21UnclipL(SD20):
|
| 119 |
+
unet_config = {
|
| 120 |
+
"context_dim": 1024,
|
| 121 |
+
"model_channels": 320,
|
| 122 |
+
"use_linear_in_transformer": True,
|
| 123 |
+
"adm_in_channels": 1536,
|
| 124 |
+
"use_temporal_attention": False,
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
clip_vision_prefix = "embedder.model.visual."
|
| 128 |
+
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class SD21UnclipH(SD20):
    """SD 2.1 unCLIP variant conditioned on CLIP-H image embeddings (adm 2048)."""

    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    # Same noise-augmentation scheme as SD21UnclipL but with a 1024-dim embedding.
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
|
| 142 |
+
|
| 143 |
+
class SDXLRefiner(supported_models_base.BASE):
    """SDXL refiner checkpoint (uses only the OpenCLIP-G text encoder)."""

    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL
    memory_usage_factor = 1.0

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the refiner diffusion model."""
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        """Keep only the clip_g weights and convert them to the transformers layout."""
        state_dict = utils.state_dict_prefix_replace(
            state_dict, {"conditioner.embedders.0.model.": "clip_g."}, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Convert the internal clip_g layout back to the sgm checkpoint format."""
        sd_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        # position_ids is a runtime buffer that the checkpoint format does not carry.
        sd_g.pop("clip_g.transformer.text_model.embeddings.position_ids", None)
        return utils.state_dict_prefix_replace(sd_g, {"clip_g": "conditioner.embedders.0.model"})

    def clip_target(self, state_dict={}):
        """Return the SDXL tokenizer paired with the refiner's clip_g-only model."""
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
| 180 |
+
|
| 181 |
+
class SDXL(supported_models_base.BASE):
    """SDXL base model family; detects EDM / v-prediction variants from the state dict."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    memory_usage_factor = 0.8

    def model_type(self, state_dict, prefix=""):
        """Pick the prediction type from marker keys stored in the checkpoint."""
        if 'edm_mean' in state_dict and 'edm_std' in state_dict:  # Playground V2.5
            self.latent_format = latent_formats.SDXL_Playground_2_5()
            self.sampling_settings["sigma_data"] = 0.5
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
        if "edm_vpred.sigma_max" in state_dict:
            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
            if "edm_vpred.sigma_min" in state_dict:
                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        if "v_pred" in state_dict:
            if "ztsnr" in state_dict:  # Some zsnr anime checkpoints
                self.sampling_settings["zsnr"] = True
            return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SDXL diffusion model, enabling inpaint mode if detected."""
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Remap both SDXL text encoders (clip_l + clip_g) to the internal layout."""
        prefixes = {
            "conditioner.embedders.0.transformer.text_model": "clip_l.transformer.text_model",
            "conditioner.embedders.1.model.": "clip_g.",
        }
        state_dict = utils.state_dict_prefix_replace(state_dict, prefixes, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Convert both text encoders back to the sgm checkpoint layout."""
        sd_out = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        # clip_l weights are carried over unchanged.
        sd_out.update({k: v for k, v in state_dict.items() if k.startswith("clip_l")})

        # The saved format expects an explicit position_ids buffer for clip_l.
        sd_out["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
        for key in ("clip_l.transformer.text_projection.weight", "clip_l.logit_scale"):
            sd_out.pop(key, None)

        replace_prefix = {
            "clip_g": "conditioner.embedders.1.model",
            "clip_l": "conditioner.embedders.0",
        }
        return utils.state_dict_prefix_replace(sd_out, replace_prefix)

    def clip_target(self, state_dict={}):
        """Return the SDXL tokenizer / dual text-encoder pair."""
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
| 252 |
+
|
| 253 |
+
class SSD1B(SDXL):
    """Segmind SSD-1B: a distilled SDXL with shallower transformer depth."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
|
| 262 |
+
|
| 263 |
+
class Segmind_Vega(SDXL):
    """Segmind Vega: an even smaller distilled SDXL variant."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 1, 1, 2, 2],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
|
| 272 |
+
|
| 273 |
+
class KOALA_700M(SDXL):
    """KOALA 700M: compressed SDXL with a 3-stage transformer depth layout."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 5],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
|
| 282 |
+
|
| 283 |
+
class KOALA_1B(SDXL):
    """KOALA 1B: compressed SDXL, slightly deeper than KOALA 700M."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 6],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }
|
| 292 |
+
|
| 293 |
+
class SVD_img2vid(supported_models_base.BASE):
    """Stable Video Diffusion image-to-video model (no text encoder)."""

    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
        # Attention is run in fp32 for numerical stability.
        "attn_precision": torch.float32,
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SVD video diffusion model."""
        return model_base.SVD_img2vid(self, device=device)

    def clip_target(self, state_dict={}):
        """SVD is conditioned on CLIP vision embeddings only — no text encoder."""
        return None
|
| 323 |
+
|
| 324 |
+
class SV3D_u(SVD_img2vid):
    """SV3D orbital-video model, unconditioned variant (adm 256)."""

    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 256,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    vae_key_prefix = ["conditioner.embedders.1.encoder."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SV3D_u diffusion model."""
        return model_base.SV3D_u(self, device=device)
|
| 341 |
+
|
| 342 |
+
class SV3D_p(SV3D_u):
    """SV3D pose-conditioned variant (adm 1280)."""

    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 1280,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SV3D_p diffusion model."""
        return model_base.SV3D_p(self, device=device)
|
| 358 |
+
|
| 359 |
+
class Stable_Zero123(supported_models_base.BASE):
    """Stable Zero123 novel-view model; identified by its cc_projection layer."""

    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    # Presence of these keys distinguishes Zero123 from other SD1.x-shaped unets.
    required_keys = {
        "cc_projection.weight": None,
        "cc_projection.bias": None,
    }

    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the model, passing the camera-conditioning projection weights through."""
        return model_base.Stable_Zero123(
            self,
            device=device,
            cc_projection_weight=state_dict["cc_projection.weight"],
            cc_projection_bias=state_dict["cc_projection.bias"],
        )

    def clip_target(self, state_dict={}):
        """Zero123 is image-conditioned only — no text encoder."""
        return None
|
| 389 |
+
|
| 390 |
+
class SD_X4Upscaler(SD20):
    """SD 2.x 4x latent upscaler (7 input channels, class-conditioned)."""

    unet_config = {
        "context_dim": 1024,
        "model_channels": 256,
        'in_channels': 7,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "disable_self_attentions": [True, True, True, False],
        "num_classes": 1000,
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD_X4

    sampling_settings = {
        "linear_start": 0.0001,
        "linear_end": 0.02,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the upscaler diffusion model."""
        return model_base.SD_X4Upscaler(self, device=device)
|
| 417 |
+
|
| 418 |
+
class Stable_Cascade_C(supported_models_base.BASE):
    """Stable Cascade stage C (prior) model."""

    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 2.0,
    }

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoder."]
    clip_vision_prefix = "clip_l_vision."

    def process_unet_state_dict(self, state_dict):
        """Split fused attention in_proj weight/bias tensors into separate q/k/v entries."""
        key_list = list(state_dict.keys())
        for part in ("weight", "bias"):
            suffix = "in_proj_{}".format(part)
            for k_from in (k for k in key_list if k.endswith(suffix)):
                fused = state_dict.pop(k_from)
                prefix = k_from[:-(len(suffix) + 1)]
                chunk = fused.shape[0] // 3
                for idx, proj in enumerate(("to_q", "to_k", "to_v")):
                    state_dict["{}.{}.{}".format(prefix, proj, part)] = fused[chunk * idx:chunk * (idx + 1)]
        return state_dict

    def process_clip_state_dict(self, state_dict):
        """Strip the text-encoder prefix and reshape the clip_g text projection."""
        state_dict = utils.state_dict_prefix_replace(
            state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        if "clip_g.text_projection" in state_dict:
            # The transformers layout stores the projection transposed.
            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the stage C diffusion model."""
        return model_base.StableCascade_C(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Stable Cascade tokenizer / text-encoder pair."""
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
| 463 |
+
|
| 464 |
+
class Stable_Cascade_B(Stable_Cascade_C):
    """Stable Cascade stage B (decoder) model."""

    unet_config = {
        "stable_cascade_stage": 'b',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_B
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 1.0,
    }

    # Stage B takes no CLIP-vision conditioning.
    clip_vision_prefix = None

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the stage B diffusion model."""
        return model_base.StableCascade_B(self, device=device)
|
| 483 |
+
|
| 484 |
+
class SD15_instructpix2pix(SD15):
    """InstructPix2Pix on SD 1.5: 8 input channels (noise + conditioning image)."""

    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the instruct-pix2pix diffusion model."""
        return model_base.SD15_instructpix2pix(self, device=device)
|
| 496 |
+
|
| 497 |
+
class SDXL_instructpix2pix(SDXL):
    """InstructPix2Pix on SDXL: 8 input channels (noise + conditioning image)."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the model, reusing SDXL's prediction-type detection."""
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
| 510 |
+
|
| 511 |
+
class LotusD(SD20):
    """Lotus depth/normal prediction model built on the SD2 unet."""

    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "use_temporal_attention": False,
        "adm_in_channels": 4,
        "in_channels": 4,
    }

    unet_extra_config = {
        "num_classes": 'sequential'
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Lotus model."""
        return model_base.Lotus(self, device=device)
|
| 526 |
+
|
| 527 |
+
class SD3(supported_models_base.BASE):
    """Stable Diffusion 3 (MMDiT); text encoders are detected from the checkpoint."""

    unet_config = {
        "in_channels": 16,
        "pos_embed_scaling_factor": None,
    }

    sampling_settings = {
        "shift": 3.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SD3

    memory_usage_factor = 1.2

    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SD3 diffusion model."""
        return model_base.SD3(self, device=device)

    def clip_target(self, state_dict={}):
        """Build a ClipTarget with whichever of clip_l / clip_g / t5xxl are present."""
        pref = self.text_encoder_key_prefix[0]
        clip_l = "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict
        clip_g = "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        t5 = "dtype_t5" in t5_detect

        return supported_models_base.ClipTarget(
            comfy.text_encoders.sd3_clip.SD3Tokenizer,
            comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
| 562 |
+
|
| 563 |
+
class StableAudio(supported_models_base.BASE):
    """Stable Audio 1.0 DiT model."""

    unet_config = {
        "audio_model": "dit1.0",
    }

    sampling_settings = {"sigma_max": 500.0, "sigma_min": 0.03}

    unet_extra_config = {}
    latent_format = latent_formats.StableAudio1

    text_encoder_key_prefix = ["text_encoders."]
    vae_key_prefix = ["pretransform.model."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the model, extracting the seconds-start/total embedder weights."""
        seconds_start_sd = utils.state_dict_prefix_replace(
            state_dict, {"conditioner.conditioners.seconds_start.": ""}, filter_keys=True)
        seconds_total_sd = utils.state_dict_prefix_replace(
            state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True)
        return model_base.StableAudio1(
            self,
            seconds_start_embedder_weights=seconds_start_sd,
            seconds_total_embedder_weights=seconds_total_sd,
            device=device)

    def process_unet_state_dict(self, state_dict):
        """Drop all-zero norm beta weights that the runtime model does not use."""
        zero_suffixes = (".cross_attend_norm.beta", ".ff_norm.beta", ".pre_norm.beta")
        for k in list(state_dict.keys()):
            if k.endswith(zero_suffixes):  # These weights are all zero
                state_dict.pop(k)
        return state_dict

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-add the checkpoint's model.model. prefix when saving."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.model."})

    def clip_target(self, state_dict={}):
        """Return the Stable Audio T5 tokenizer / encoder pair."""
        return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model)
|
| 593 |
+
|
| 594 |
+
class AuraFlow(supported_models_base.BASE):
    """AuraFlow flow-matching model."""

    unet_config = {
        "cond_seq_dim": 2048,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.73,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SDXL

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the AuraFlow diffusion model."""
        return model_base.AuraFlow(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the AuraFlow T5 tokenizer / encoder pair."""
        return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
| 616 |
+
|
| 617 |
+
class PixArtAlpha(supported_models_base.BASE):
    """PixArt-alpha DiT model."""

    unet_config = {
        "image_model": "pixart_alpha",
    }

    sampling_settings = {
        "beta_schedule" : "sqrt_linear",
        "linear_start" : 0.0001,
        "linear_end" : 0.02,
        "timesteps" : 1000,
    }

    unet_extra_config = {}
    latent_format = latent_formats.SD15

    memory_usage_factor = 0.5

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the PixArt model in eval mode."""
        model = model_base.PixArt(self, device=device)
        return model.eval()

    def clip_target(self, state_dict={}):
        """Return the PixArt T5-XXL tokenizer / encoder pair."""
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
|
| 643 |
+
|
| 644 |
+
class PixArtSigma(PixArtAlpha):
    """PixArt-sigma: same setup as alpha but with the SDXL latent space."""

    unet_config = {
        "image_model": "pixart_sigma",
    }
    latent_format = latent_formats.SDXL
|
| 649 |
+
|
| 650 |
+
class HunyuanDiT(supported_models_base.BASE):
    """Hunyuan-DiT image model."""

    unet_config = {
        "image_model": "hydit",
    }

    unet_extra_config = {
        # Attention is run in fp32 for numerical stability.
        "attn_precision": torch.float32,
    }

    sampling_settings = {
        "linear_start": 0.00085,
        "linear_end": 0.018,
    }

    latent_format = latent_formats.SDXL

    memory_usage_factor = 1.3

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Hunyuan-DiT model."""
        return model_base.HunyuanDiT(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Hunyuan-DiT tokenizer / text-encoder pair."""
        return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
|
| 677 |
+
|
| 678 |
+
class HunyuanDiT1(HunyuanDiT):
    """Hunyuan-DiT v1 variant with its own beta schedule."""

    unet_config = {
        "image_model": "hydit1",
    }

    unet_extra_config = {}

    sampling_settings = {
        "linear_start" : 0.00085,
        "linear_end" : 0.03,
    }
|
| 689 |
+
|
| 690 |
+
class Flux(supported_models_base.BASE):
    """Flux dev model (guidance-embedded variant)."""

    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
    }

    sampling_settings = {
    }

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    memory_usage_factor = 2.8

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Flux diffusion model."""
        return model_base.Flux(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Flux tokenizer / encoder pair, detecting the T5 dtype from the checkpoint."""
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
| 717 |
+
|
| 718 |
+
class FluxInpaint(Flux):
    """Flux inpainting variant (96 input channels for latent + mask conditioning)."""

    unet_config = {
        "image_model": "flux",
        "guidance_embed": True,
        "in_channels": 96,
    }

    supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
| 726 |
+
|
| 727 |
+
class FluxSchnell(Flux):
    """Flux schnell: distilled variant without guidance embedding."""

    unet_config = {
        "image_model": "flux",
        "guidance_embed": False,
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate Flux forced to flow-matching prediction."""
        return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
| 741 |
+
|
| 742 |
+
class GenmoMochi(supported_models_base.BASE):
    """Genmo Mochi preview video model."""

    unet_config = {
        "image_model": "mochi_preview",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Mochi

    memory_usage_factor = 2.0 #TODO

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Mochi diffusion model."""
        return model_base.GenmoMochi(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Mochi T5 tokenizer / encoder pair, detecting T5 dtype from the checkpoint."""
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
|
| 770 |
+
|
| 771 |
+
class LTXV(supported_models_base.BASE):
    """LTX-Video model; memory estimate scales with the cross-attention width."""

    unet_config = {
        "image_model": "ltxv",
    }

    sampling_settings = {
        "shift": 2.37,
    }

    unet_extra_config = {}
    latent_format = latent_formats.LTXV

    memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale the memory estimate by the model's cross-attention width
        # relative to the 2048-dim baseline.
        self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the LTX-Video diffusion model."""
        return model_base.LTXV(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the LTXV T5 tokenizer / encoder pair, detecting T5 dtype from the checkpoint."""
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
|
| 802 |
+
|
| 803 |
+
class HunyuanVideo(supported_models_base.BASE):
    """HunyuanVideo model; remaps upstream checkpoint key names to the internal layout."""

    unet_config = {
        "image_model": "hunyuan_video",
    }

    sampling_settings = {
        "shift": 7.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.HunyuanVideo

    memory_usage_factor = 1.8 #TODO

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    # Substring renames applied in order to every unet key (upstream -> internal).
    _KEY_RENAMES = (
        ("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer."),
        ("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer."),
        ("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer."),
        ("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer."),
        ("_mod.linear.", "_mod.lin."),
        ("_attn_qkv.", "_attn.qkv."),
        ("mlp.fc1.", "mlp.0."),
        ("mlp.fc2.", "mlp.2."),
        ("_attn_q_norm.weight", "_attn.norm.query_norm.scale"),
        ("_attn_k_norm.weight", "_attn.norm.key_norm.scale"),
        (".q_norm.weight", ".norm.query_norm.scale"),
        (".k_norm.weight", ".norm.key_norm.scale"),
        ("_attn_proj.", "_attn.proj."),
        (".modulation.linear.", ".modulation.lin."),
        ("_in.mlp.2.", "_in.out_layer."),
        ("_in.mlp.0.", "_in.in_layer."),
    )

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the HunyuanVideo diffusion model."""
        return model_base.HunyuanVideo(self, device=device)

    def process_unet_state_dict(self, state_dict):
        """Rename upstream checkpoint keys to the internal module naming scheme."""
        out_sd = {}
        for k in list(state_dict.keys()):
            key_out = k
            for old, new in self._KEY_RENAMES:
                key_out = key_out.replace(old, new)
            out_sd[key_out] = state_dict[k]
        return out_sd

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-add the checkpoint's model.model. prefix when saving."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model.model."})

    def clip_target(self, state_dict={}):
        """Return the HunyuanVideo tokenizer / encoder pair, detecting the LLaMA encoder dtype."""
        pref = self.text_encoder_key_prefix[0]
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
| 850 |
+
|
| 851 |
+
class HunyuanVideoI2V(HunyuanVideo):
    """HunyuanVideo image-to-video variant (33 input channels)."""

    unet_config = {
        "image_model": "hunyuan_video",
        "in_channels": 33,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the image-to-video model."""
        return model_base.HunyuanVideoI2V(self, device=device)
|
| 860 |
+
|
| 861 |
+
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
    """SkyReels image-to-video finetune of HunyuanVideo (32 input channels)."""

    unet_config = {
        "image_model": "hunyuan_video",
        "in_channels": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SkyReels image-to-video model."""
        return model_base.HunyuanVideoSkyreelsI2V(self, device=device)
|
| 870 |
+
|
| 871 |
+
class CosmosT2V(supported_models_base.BASE):
    """NVIDIA Cosmos text-to-video model."""

    unet_config = {
        "image_model": "cosmos",
        "in_channels": 16,
    }

    sampling_settings = {
        "sigma_data": 0.5,
        "sigma_max": 80.0,
        "sigma_min": 0.002,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Cosmos1CV8x8x8

    memory_usage_factor = 1.6 #TODO

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Cosmos video model."""
        return model_base.CosmosVideo(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Cosmos T5 tokenizer / encoder pair, detecting T5 dtype from the checkpoint."""
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
| 901 |
+
|
| 902 |
+
class CosmosI2V(CosmosT2V):
    """Cosmos image-to-video variant (17 input channels: latent + condition mask)."""

    unet_config = {
        "image_model": "cosmos",
        "in_channels": 17,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Cosmos video model in image-to-video mode."""
        return model_base.CosmosVideo(self, image_to_video=True, device=device)
|
| 911 |
+
|
| 912 |
+
class CosmosT2IPredict2(supported_models_base.BASE):
    """Cosmos Predict2 text-to-image model; memory estimate scales with model width."""

    unet_config = {
        "image_model": "cosmos_predict2",
        "in_channels": 16,
    }

    sampling_settings = {
        "sigma_data": 1.0,
        "sigma_max": 80.0,
        "sigma_min": 0.002,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Wan21

    memory_usage_factor = 1.0

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale the memory estimate by the model channel width relative to the
        # 2048-channel baseline.
        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Cosmos Predict2 model."""
        return model_base.CosmosPredict2(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the Cosmos T5 tokenizer / encoder pair, detecting T5 dtype from the checkpoint."""
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
|
| 943 |
+
|
| 944 |
+
class CosmosI2VPredict2(CosmosT2IPredict2):
    """Cosmos Predict2 image-to-video variant (17-channel input)."""

    unet_config = {
        "image_model": "cosmos_predict2",
        "in_channels": 17,
    }

    def get_model(self, state_dict, prefix="", device=None):
        # Same Predict2 backbone, with image conditioning enabled.
        return model_base.CosmosPredict2(self, image_to_video=True, device=device)
|
| 953 |
+
|
| 954 |
+
class Lumina2(supported_models_base.BASE):
    """Lumina 2 image model; uses a Gemma-2 2B text encoder."""

    unet_config = {
        "image_model": "lumina2",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 6.0,
    }

    memory_usage_factor = 1.2

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the Lumina2 diffusion model wrapper."""
        return model_base.Lumina2(self, device=device)

    def clip_target(self, state_dict={}):
        """Gemma-2 2B text encoder; llama_detect handles the llama-style weight layout."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, prefix + "gemma2_2b.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**detected))
|
| 982 |
+
|
| 983 |
+
class WAN21_T2V(supported_models_base.BASE):
    """WAN 2.1 text-to-video base config; the other WAN variants subclass this."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
    }

    sampling_settings = {
        "shift": 8.0,
    }

    unet_extra_config = {}
    latent_format = latent_formats.Wan21

    memory_usage_factor = 1.0

    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Scale the memory estimate with transformer width (2000 is the reference dim).
        self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the WAN 2.1 diffusion model wrapper."""
        return model_base.WAN21(self, device=device)

    def clip_target(self, state_dict={}):
        """umT5-XXL text encoder, config detected from the checkpoint."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "umt5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**detected))
|
| 1015 |
+
|
| 1016 |
+
class WAN21_I2V(WAN21_T2V):
    """WAN 2.1 image-to-video variant, detected via its 36-channel input dim."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 36,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.WAN21(self, image_to_video=True, device=device)
|
| 1026 |
+
|
| 1027 |
+
class WAN21_FunControl2V(WAN21_T2V):
    """WAN 2.1 Fun-Control variant: i2v weight layout but 48 input channels."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "i2v",
        "in_dim": 48,
    }

    def get_model(self, state_dict, prefix="", device=None):
        # Control variant uses the i2v architecture without image-to-video conditioning.
        return model_base.WAN21(self, image_to_video=False, device=device)
|
| 1037 |
+
|
| 1038 |
+
class WAN21_Camera(WAN21_T2V):
    """WAN 2.1 camera-control variant (32 input channels)."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "camera",
        "in_dim": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
| 1048 |
+
class WAN21_Vace(WAN21_T2V):
    """WAN 2.1 VACE (video editing) variant."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "vace",
    }

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # VACE adds extra conditioning blocks, so bump the estimate by 20%.
        self.memory_usage_factor = 1.2 * self.memory_usage_factor

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.WAN21_Vace(self, image_to_video=False, device=device)
|
| 1061 |
+
|
| 1062 |
+
class WAN22_T2V(WAN21_T2V):
    # WAN 2.2 text-to-video: same "wan2.1" image_model tag but distinguished
    # by its 48-channel output dim and the newer Wan22 latent format.
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "t2v",
        "out_dim": 48,
    }

    latent_format = latent_formats.Wan22

    def get_model(self, state_dict, prefix="", device=None):
        # NOTE(review): image_to_video=True in a T2V config looks intentional
        # (WAN22 wrapper may use it for shared conditioning) — confirm against
        # model_base.WAN22 before changing.
        out = model_base.WAN22(self, image_to_video=True, device=device)
        return out
|
| 1074 |
+
|
| 1075 |
+
class Hunyuan3Dv2(supported_models_base.BASE):
    """Hunyuan3D v2 shape-generation model; image-conditioned, no text encoder."""

    unet_config = {
        "image_model": "hunyuan3d2",
    }

    unet_extra_config = {}

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 1.0,
    }

    memory_usage_factor = 3.5

    clip_vision_prefix = "conditioner.main_image_encoder.model."
    vae_key_prefix = ["vae."]

    latent_format = latent_formats.Hunyuan3Dv2

    def process_unet_state_dict_for_saving(self, state_dict):
        """Re-add the 'model.' prefix expected by the original checkpoint layout."""
        return utils.state_dict_prefix_replace(state_dict, {"": "model."})

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.Hunyuan3Dv2(self, device=device)

    def clip_target(self, state_dict={}):
        # Conditioning is a CLIP vision encoder, not a text encoder.
        return None
|
| 1104 |
+
|
| 1105 |
+
class Hunyuan3Dv2mini(Hunyuan3Dv2):
    """Smaller Hunyuan3D v2 variant, identified by its 8-layer depth."""

    unet_config = {
        "image_model": "hunyuan3d2",
        "depth": 8,
    }

    latent_format = latent_formats.Hunyuan3Dv2mini
|
| 1112 |
+
|
| 1113 |
+
class HiDream(supported_models_base.BASE):
    """HiDream image model config.

    Bug fixed: ``sampling_settings`` was assigned twice; the second, empty
    dict silently clobbered ``{"shift": 3.0}``, leaving the model with
    default sampling parameters.  The duplicate assignment is removed and
    the intended shift kept.
    """

    unet_config = {
        "image_model": "hidream",
    }

    sampling_settings = {
        "shift": 3.0,
    }

    # memory_usage_factor = 1.2 # TODO

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the HiDream diffusion model wrapper."""
        out = model_base.HiDream(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        # Text encoder wiring not implemented yet for this model.
        return None # TODO
|
| 1141 |
+
|
| 1142 |
+
class Chroma(supported_models_base.BASE):
    """Chroma (Flux-derived) image model using a T5-XXL text encoder."""

    unet_config = {
        "image_model": "chroma",
    }

    unet_extra_config = {
    }

    sampling_settings = {
        "multiplier": 1.0,
    }

    latent_format = comfy.latent_formats.Flux

    memory_usage_factor = 3.2

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]


    def get_model(self, state_dict, prefix="", device=None):
        return model_base.Chroma(self, device=device)

    def clip_target(self, state_dict={}):
        """T5-XXL via the PixArt tokenizer/encoder wrappers."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, prefix + "t5xxl.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**detected))
|
| 1169 |
+
|
| 1170 |
+
class ACEStep(supported_models_base.BASE):
    """ACE-Step audio generation model config."""

    unet_config = {
        "audio_model": "ace",
    }

    unet_extra_config = {
    }

    sampling_settings = {
        "shift": 3.0,
    }

    latent_format = comfy.latent_formats.ACEAudio

    memory_usage_factor = 0.5

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.ACEStep(self, device=device)

    def clip_target(self, state_dict={}):
        # Fixed encoder; nothing to auto-detect from the state dict.
        return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
|
| 1197 |
+
|
| 1198 |
+
class Omnigen2(supported_models_base.BASE):
    """OmniGen 2 image model; text encoding via Qwen2.5 3B."""

    unet_config = {
        "image_model": "omnigen2",
    }

    sampling_settings = {
        "multiplier": 1.0,
        "shift": 2.6,
    }

    memory_usage_factor = 1.65 #TODO

    unet_extra_config = {}
    latent_format = latent_formats.Flux

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)
        # Prefer fp16 on hardware where extended fp16 support is available.
        if comfy.model_management.extended_fp16_support():
            self.supported_inference_dtypes = [torch.float16] + self.supported_inference_dtypes

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.Omnigen2(self, device=device)

    def clip_target(self, state_dict={}):
        """Qwen2.5 3B text encoder (llama-style weight layout detection)."""
        prefix = self.text_encoder_key_prefix[0]
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, prefix + "qwen25_3b.transformer.")
        return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**detected))
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
# Detection list: matching is first-hit, so more specific configs must come
# before the generic ones they would otherwise shadow (e.g. FluxInpaint before
# Flux, WAN22_T2V before WAN21_T2V, Hunyuan3Dv2mini before Hunyuan3Dv2).
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

# SVD goes last so it is only matched when nothing above claims the config.
models += [SVD_img2vid]
|
ComfyUI/comfy/supported_models_base.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Comfy
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from . import model_base
|
| 21 |
+
from . import utils
|
| 22 |
+
from . import latent_formats
|
| 23 |
+
|
| 24 |
+
class ClipTarget:
    """Pairs a tokenizer class with its matching text-encoder class.

    `params` holds extra keyword arguments forwarded when the encoder is
    eventually constructed; it starts empty.
    """

    def __init__(self, tokenizer, clip):
        self.tokenizer = tokenizer
        self.clip = clip
        self.params = {}
|
| 29 |
+
|
| 30 |
+
class BASE:
    """Base class for every supported-model config.

    Subclasses describe how to detect a model from its UNet config /
    state-dict keys and how to build the runtime model, plus how to
    rename state-dict prefixes when loading and saving.
    """

    # Keys that must match the detected UNet config for `matches` to succeed.
    unet_config = {}
    # Extra keys merged into the UNet config after detection.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # State-dict keys that must be present (checked only when a state dict is given).
    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
    sampling_settings = {}
    latent_format = latent_formats.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    # Rough multiplier used by memory management to estimate inference memory.
    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    optimizations = {"fp8": False}

    @classmethod
    def matches(s, unet_config, state_dict=None):
        """Return True when `unet_config` (and optionally `state_dict`) fits this class.

        Every key in the class-level `unet_config` must be present and equal;
        every key in `required_keys` must exist in the state dict when given.
        """
        for k in s.unet_config:
            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in s.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict, prefix=""):
        # Default prediction type; subclasses override for v-prediction etc.
        return model_base.ModelType.EPS

    def inpaint_model(self):
        # More than 4 input channels implies the model takes a mask + masked image.
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        # Copy mutable class-level defaults so instances never share state.
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()
        # unet_extra_config entries override/extend the detected config.
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the runtime diffusion model for this config."""
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
        else:
            out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        # Strip the text-encoder prefix, keeping only matching keys.
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        return state_dict

    def process_unet_state_dict(self, state_dict):
        return state_dict

    def process_vae_state_dict(self, state_dict):
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        # Inverse of process_clip_state_dict: re-add the canonical prefix.
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        # Record the dtype the UNet should run in; manual_cast_dtype is the
        # dtype to cast to when ops don't support `dtype` natively.
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype
|
ComfyUI/comfy/t2i_adapter/adapter.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#taken from https://github.com/TencentARC/T2I-Adapter
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    factory = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
    if factory is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factory(*args, **kwargs)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    factory = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}.get(dims)
    if factory is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factory(*args, **kwargs)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep the leading (temporal) axis; only the inner two are halved.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if not self.use_conv:
            # Pad odd-sized spatial dims so pooling covers every element.
            self.op.padding = [x.shape[2] % 2, x.shape[3] % 2]

        return self.op(x)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ResnetBlock(nn.Module):
|
| 68 |
+
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
| 69 |
+
super().__init__()
|
| 70 |
+
ps = ksize // 2
|
| 71 |
+
if in_c != out_c or sk == False:
|
| 72 |
+
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
| 73 |
+
else:
|
| 74 |
+
# print('n_in')
|
| 75 |
+
self.in_conv = None
|
| 76 |
+
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
| 77 |
+
self.act = nn.ReLU()
|
| 78 |
+
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
| 79 |
+
if sk == False:
|
| 80 |
+
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
| 81 |
+
else:
|
| 82 |
+
self.skep = None
|
| 83 |
+
|
| 84 |
+
self.down = down
|
| 85 |
+
if self.down == True:
|
| 86 |
+
self.down_opt = Downsample(in_c, use_conv=use_conv)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
if self.down == True:
|
| 90 |
+
x = self.down_opt(x)
|
| 91 |
+
if self.in_conv is not None: # edit
|
| 92 |
+
x = self.in_conv(x)
|
| 93 |
+
|
| 94 |
+
h = self.block1(x)
|
| 95 |
+
h = self.act(h)
|
| 96 |
+
h = self.block2(h)
|
| 97 |
+
if self.skep is not None:
|
| 98 |
+
return h + self.skep(x)
|
| 99 |
+
else:
|
| 100 |
+
return h + x
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Adapter(nn.Module):
    """T2I-Adapter backbone: pixel-unshuffles the control image and runs it
    through a stack of ResnetBlocks, emitting one feature map per resolution.

    The forward return is a list padded with None placeholders; downstream
    code appears to consume it positionally (one slot per UNet injection
    point), so the exact append order below is load-bearing.
    """

    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
        super(Adapter, self).__init__()
        self.unshuffle_amount = 8
        resblock_no_downsample = []
        # SD1.5 layout: downsample at stages 3, 2, 1.
        resblock_downsample = [3, 2, 1]
        self.xl = xl
        if self.xl:
            # SDXL layout: coarser unshuffle, downsample only at stage 2.
            self.unshuffle_amount = 16
            resblock_no_downsample = [1]
            resblock_downsample = [2]

        # Channels of the raw control image before unshuffling.
        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
        self.channels = channels
        self.nums_rb = nums_rb
        self.body = []
        for i in range(len(channels)):
            for j in range(nums_rb):
                # Only the first block of a stage changes resolution/channels.
                if (i in resblock_downsample) and (j == 0):
                    self.body.append(
                        ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
                elif (i in resblock_no_downsample) and (j == 0):
                    self.body.append(
                        ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
                else:
                    self.body.append(
                        ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
        self.body = nn.ModuleList(self.body)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

    def forward(self, x):
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        x = self.conv_in(x)
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                idx = i * self.nums_rb + j
                x = self.body[idx](x)
            # None entries pad slots where no adapter feature is injected;
            # the per-stage counts differ between the XL and SD1.5 layouts.
            if self.xl:
                features.append(None)
                if i == 0:
                    features.append(None)
                    features.append(None)
                if i == 2:
                    features.append(None)
            else:
                features.append(None)
                features.append(None)
            features.append(x)

        # Reverse: consumers index from the deepest feature outward.
        features = features[::-1]

        if self.xl:
            # Deepest feature goes to the UNet middle block on XL.
            return {"input": features[1:], "middle": features[:1]}
        else:
            return {"input": features}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # Normalize in float32 for numerical stability, then restore the
        # caller's dtype.
        result = super().forward(x.type(torch.float32))
        return result.type(x.dtype)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(1.702 * x).mul(x)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a 4x-GELU MLP, each
    with a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Keep the mask on the same device/dtype as the activations.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class StyleAdapter(nn.Module):
    """Maps CLIP image features to `num_token` style tokens in the text
    cross-attention space via a small transformer."""

    def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
        super().__init__()

        scale = width ** -0.5
        self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
        self.num_token = num_token
        self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
        self.ln_post = LayerNorm(width)
        self.ln_pre = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, context_dim))

    def forward(self, x):
        # x shape [N, HW+1, C]
        # Broadcast the learned style tokens across the batch.
        batch_tokens = self.style_embedding + torch.zeros(
            (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
        h = torch.cat([x, batch_tokens], dim=1)
        h = self.ln_pre(h)
        h = h.permute(1, 0, 2)  # NLD -> LND
        h = self.transformer_layes(h)
        h = h.permute(1, 0, 2)  # LND -> NLD

        # Keep only the trailing style-token positions and project to context_dim.
        h = self.ln_post(h[:, -self.num_token:, :])
        return h @ self.proj
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class ResnetBlock_light(nn.Module):
    """Minimal residual block: conv -> ReLU -> conv with identity skip."""

    def __init__(self, in_c):
        super().__init__()
        self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)

    def forward(self, x):
        residual = self.block2(self.act(self.block1(x)))
        return residual + x
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class extractor(nn.Module):
|
| 248 |
+
def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
|
| 251 |
+
self.body = []
|
| 252 |
+
for _ in range(nums_rb):
|
| 253 |
+
self.body.append(ResnetBlock_light(inter_c))
|
| 254 |
+
self.body = nn.Sequential(*self.body)
|
| 255 |
+
self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
|
| 256 |
+
self.down = down
|
| 257 |
+
if self.down == True:
|
| 258 |
+
self.down_opt = Downsample(in_c, use_conv=False)
|
| 259 |
+
|
| 260 |
+
def forward(self, x):
|
| 261 |
+
if self.down == True:
|
| 262 |
+
x = self.down_opt(x)
|
| 263 |
+
x = self.in_conv(x)
|
| 264 |
+
x = self.body(x)
|
| 265 |
+
x = self.out_conv(x)
|
| 266 |
+
|
| 267 |
+
return x
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class Adapter_light(nn.Module):
    """Lightweight T2I-Adapter: one extractor stage per resolution.

    Like Adapter, the output list is padded with None placeholders and
    consumed positionally downstream, so the append order is load-bearing.
    """

    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
        super(Adapter_light, self).__init__()
        self.unshuffle_amount = 8
        self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
        # Channels of the raw control image before unshuffling.
        self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
        self.channels = channels
        self.nums_rb = nums_rb
        self.body = []
        self.xl = False

        for i in range(len(channels)):
            # First stage keeps resolution; later stages downsample on entry.
            if i == 0:
                self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
            else:
                self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
        self.body = nn.ModuleList(self.body)

    def forward(self, x):
        # unshuffle
        x = self.unshuffle(x)
        # extract features
        features = []
        for i in range(len(self.channels)):
            x = self.body[i](x)
            # Two None placeholders per stage: slots with no injected feature.
            features.append(None)
            features.append(None)
            features.append(x)

        # Reversed: consumers index from the deepest feature outward.
        return {"input": features[::-1]}
|
ComfyUI/comfy/taesd/taesd.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Tiny AutoEncoder for Stable Diffusion
|
| 4 |
+
(DNN for encoding / decoding SD's latent space)
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
import comfy.utils
|
| 10 |
+
import comfy.ops
|
| 11 |
+
|
| 12 |
+
def conv(n_in, n_out, **kwargs):
    """3x3 same-padding Conv2d using comfy's weight-init-disabled ops."""
    return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
| 14 |
+
|
| 15 |
+
class Clamp(nn.Module):
    """Soft clamp: squashes inputs smoothly into (-3, 3) with a scaled tanh."""

    def forward(self, x):
        # tanh saturates at +/-1; dividing in and multiplying out by 3
        # widens the linear-ish region and the saturation band to +/-3.
        scaled = x / 3
        return torch.tanh(scaled) * 3
|
| 18 |
+
|
| 19 |
+
class Block(nn.Module):
    """Residual block: three 3x3 convs with ReLUs, plus a skip path
    (1x1 conv when the channel count changes, identity otherwise)."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        # ReLU applied after summing the residual and skip paths.
        return self.fuse(self.conv(x) + self.skip(x))
|
| 27 |
+
|
| 28 |
+
def Encoder(latent_channels=4):
    """Build the TAESD encoder: image -> latents.

    Three stride-2 convs give an 8x spatial downscale. Layer order/count must
    stay exactly as-is so pretrained checkpoints load by state-dict position.
    """
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, latent_channels),
    )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def Decoder(latent_channels=4):
    """Build the TAESD decoder: latents -> image.

    Mirrors the encoder with three 2x nearest-neighbor upsamples (8x total).
    Layer order/count must stay exactly as-is for checkpoint compatibility.
    The leading Clamp soft-limits incoming latents to roughly (-3, 3).
    """
    return nn.Sequential(
        Clamp(), conv(latent_channels, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )
|
| 46 |
+
|
| 47 |
+
class TAESD(nn.Module):
    """Tiny AutoEncoder for Stable Diffusion: fast approximate encode/decode
    between image space and the SD latent space."""

    # Constants for mapping raw latents into/out of a [0, 1] preview range.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.taesd_encoder = Encoder(latent_channels=latent_channels)
        self.taesd_decoder = Decoder(latent_channels=latent_channels)
        # Affine correction between the VAE latent space and TAESD's.
        # Defaults to identity here; presumably overwritten by external
        # loading code for models that need it — TODO confirm.
        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
        self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
        if encoder_path is not None:
            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
        if decoder_path is not None:
            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

    def decode(self, x):
        """Decode latents to an image tensor; output is rescaled to [-1, 1]."""
        x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
        # Decoder emits values around [0, 1]; shift/scale to the [-1, 1] convention.
        x_sample = x_sample.sub(0.5).mul(2)
        return x_sample

    def encode(self, x):
        """Encode an image tensor in [-1, 1] to latents (inverse of decode's mapping)."""
        return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
ComfyUI/comfy/text_encoders/ace.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
from .spiece_tokenizer import SPieceTokenizer
|
| 3 |
+
import comfy.text_encoders.t5
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import torch
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from tokenizers import Tokenizer
|
| 10 |
+
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
|
| 11 |
+
|
| 12 |
+
SUPPORT_LANGUAGES = {
|
| 13 |
+
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
|
| 14 |
+
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
|
| 15 |
+
"nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
|
| 16 |
+
"ko": 6152, "hi": 6680
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
structure_pattern = re.compile(r"\[.*?\]")
|
| 20 |
+
|
| 21 |
+
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class VoiceBpeTokenizer:
    """BPE tokenizer for ACE-Step lyrics, with per-line language detection.

    Wraps a HuggingFace `tokenizers.Tokenizer` loaded from a vocab file and
    applies the multilingual text cleaners before encoding.
    """

    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt, lang):
        # Normalize (lowercase, expand numbers/abbreviations/symbols).
        txt = multilingual_cleaners(txt, lang)
        return txt

    def encode(self, txt, lang='en'):
        """Encode one line: clean it, prefix a [lang] tag, and return token ids."""
        # lang = lang.split("-")[0] # remove the region
        # self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        # Spaces are encoded as an explicit [SPACE] token in this vocab.
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids

    def get_lang(self, line):
        """Parse a leading two-letter language tag like "[en]".

        Returns (lang, rest_of_line); defaults to ("en", line) when no
        supported tag is present.
        """
        if line.startswith("[") and line[3:4] == ']':
            lang = line[1:3].lower()
            if lang in SUPPORT_LANGUAGES:
                return lang, line[4:]
        return "en", line

    def __call__(self, string):
        """Tokenize full lyrics text line by line.

        Starts with token 261 and appends token 2 as a line separator
        (blank lines contribute a separator only).
        """
        lines = string.split("\n")
        lyric_token_idx = [261]
        for line in lines:
            line = line.strip()
            if not line:
                lyric_token_idx += [2]
                continue

            lang, line = self.get_lang(line)

            if lang not in SUPPORT_LANGUAGES:
                lang = "en"
            if "zh" in lang:
                lang = "zh"
            if "spa" in lang:
                lang = "es"

            # Heuristic: if romaji conversion changed the line it contained
            # kana, so treat it as Japanese. Best-effort by design.
            try:
                line_out = japanese_to_romaji(line)
                if line_out != line:
                    lang = "ja"
                    line = line_out
            except:
                pass

            try:
                # Structure markers like "[chorus]" are always encoded as English.
                if structure_pattern.match(line):
                    token_idx = self.encode(line, "en")
                else:
                    token_idx = self.encode(line, lang)
                lyric_token_idx = lyric_token_idx + token_idx + [2]
            except Exception as e:
                logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
        return {"input_ids": lyric_token_idx}

    @staticmethod
    def from_pretrained(path, **kwargs):
        """Factory matching the HF tokenizer interface expected by SDTokenizer."""
        return VoiceBpeTokenizer(path, **kwargs)

    def get_vocab(self):
        # Empty vocab: embedding lookup by token string is not supported here.
        return {}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class UMT5BaseModel(sd1_clip.SDClipModel):
    """UMT5-base text encoder wrapped in comfy's SDClipModel interface."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
        # end=1 (</s>), pad=0; attention masks on, masked positions kept.
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
|
| 98 |
+
|
| 99 |
+
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for UMT5-base; model blob comes from tokenizer_data."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The spiece model is carried inside the checkpoint's tokenizer_data.
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)

    def state_dict(self):
        # Persist the spiece model so it round-trips through save/load.
        return {"spiece_model": self.tokenizer.serialize_model()}
|
| 106 |
+
|
| 107 |
+
class LyricsTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer wrapper around VoiceBpeTokenizer for the lyrics stream."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
        # pad_token=2 matches the line-separator token used by VoiceBpeTokenizer.
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
|
| 111 |
+
|
| 112 |
+
class AceT5Tokenizer:
    """Combined tokenizer for ACE-Step: UMT5 for the prompt, BPE for lyrics."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        """Tokenize prompt text and the optional kwargs["lyrics"] string."""
        out = {}
        out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
        out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        # Only the prompt stream is reversible.
        return self.umt5base.untokenize(token_weight_pair)

    def state_dict(self):
        return self.umt5base.state_dict()
|
| 128 |
+
|
| 129 |
+
class AceT5Model(torch.nn.Module):
    """Text-encoder wrapper for ACE-Step: runs UMT5 on the prompt and passes
    raw lyric token ids through as extra conditioning."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__()
        self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def set_clip_options(self, options):
        self.umt5base.set_clip_options(options)

    def reset_clip_options(self):
        self.umt5base.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Return (t5_embeddings, None pooled, {"conditioning_lyrics": ids})."""
        token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

        t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)

        # Lyrics are not embedded here: strip the weights and keep the raw
        # token ids as a (1, n) tensor for the model to consume downstream.
        lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
        return t5_out, None, {"conditioning_lyrics": lyrics_embeds}

    def load_sd(self, sd):
        return self.umt5base.load_sd(sd)
|
ComfyUI/comfy/text_encoders/ace_text_cleaners.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# basic text cleaners for the ACE step model
|
| 2 |
+
# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies
|
| 3 |
+
# TODO: more languages than english?
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
def japanese_to_romaji(japanese_text):
    """
    Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).

    Characters with no mapping (kanji, latin letters, unmapped punctuation)
    are passed through unchanged.

    Args:
        japanese_text (str): Text containing hiragana and/or katakana characters

    Returns:
        str: The romaji (Latin alphabet) equivalent
    """
    # Dictionary mapping kana characters (and digraphs) to their romaji equivalents
    kana_map = {
        # Katakana characters
        'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
        'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
        'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
        'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
        'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
        'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
        'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
        'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
        'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
        'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',

        # Katakana voiced consonants
        'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
        'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
        'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
        'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
        'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',

        # Katakana combinations
        'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
        'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
        'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
        'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
        'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
        'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
        'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
        'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
        'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
        'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
        'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',

        # Katakana small characters and special cases
        'ッ': '',  # Small tsu (doubles the following consonant)
        'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',

        # Katakana extras
        'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
        'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',

        # Hiragana characters
        'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
        'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
        'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
        'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
        'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
        'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
        'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
        'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
        'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
        'わ': 'wa', 'を': 'wo', 'ん': 'n',

        # Hiragana voiced consonants
        'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
        'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
        'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
        'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
        'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',

        # Hiragana combinations
        'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
        'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
        'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
        'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
        'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
        'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
        'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
        'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
        'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
        'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
        'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',

        # Hiragana small characters and special cases
        'っ': '',  # Small tsu (doubles the following consonant)
        'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',

        # Common punctuation and spaces
        ' ': ' ',  # Japanese space
        '、': ', ', '。': '. ',
    }

    result = []
    i = 0

    while i < len(japanese_text):
        char = japanese_text[i]

        # Small tsu doubles the following consonant.
        # FIX: the katakana form 'ッ' was previously a mojibake character
        # ('\ufffd'), so katakana gemination (e.g. ニッポン) never triggered.
        if char in ('っ', 'ッ') and i < len(japanese_text) - 1:
            next_char = japanese_text[i + 1]
            if next_char in kana_map:
                next_romaji = kana_map[next_char]
                if next_romaji and next_romaji[0] not in 'aiueon':
                    result.append(next_romaji[0])  # Double the consonant
            i += 1
            continue

        # Two-character digraphs (e.g. きゃ -> "kya").
        # FIX: match any mapped digraph, not only small-ya/yu/yo pairs, so
        # katakana extras like ファ -> "fa" and ウィ -> "wi" are handled too.
        if i < len(japanese_text) - 1:
            combo = japanese_text[i:i + 2]
            if combo in kana_map:
                result.append(kana_map[combo])
                i += 2
                continue

        # Single character; unmapped characters (kanji, romaji, etc.) pass through.
        result.append(kana_map.get(char, char))
        i += 1

    return ''.join(result)
|
| 131 |
+
|
| 132 |
+
def number_to_text(num, ordinal=False):
|
| 133 |
+
"""
|
| 134 |
+
Convert a number (int or float) to its text representation.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
num: The number to convert
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
str: Text representation of the number
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
if not isinstance(num, (int, float)):
|
| 144 |
+
return "Input must be a number"
|
| 145 |
+
|
| 146 |
+
# Handle special case of zero
|
| 147 |
+
if num == 0:
|
| 148 |
+
return "zero"
|
| 149 |
+
|
| 150 |
+
# Handle negative numbers
|
| 151 |
+
negative = num < 0
|
| 152 |
+
num = abs(num)
|
| 153 |
+
|
| 154 |
+
# Handle floats
|
| 155 |
+
if isinstance(num, float):
|
| 156 |
+
# Split into integer and decimal parts
|
| 157 |
+
int_part = int(num)
|
| 158 |
+
|
| 159 |
+
# Convert both parts
|
| 160 |
+
int_text = _int_to_text(int_part)
|
| 161 |
+
|
| 162 |
+
# Handle decimal part (convert to string and remove '0.')
|
| 163 |
+
decimal_str = str(num).split('.')[1]
|
| 164 |
+
decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
|
| 165 |
+
|
| 166 |
+
result = int_text + decimal_text
|
| 167 |
+
else:
|
| 168 |
+
# Handle integers
|
| 169 |
+
result = _int_to_text(num)
|
| 170 |
+
|
| 171 |
+
# Add 'negative' prefix for negative numbers
|
| 172 |
+
if negative:
|
| 173 |
+
result = "negative " + result
|
| 174 |
+
|
| 175 |
+
return result
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _int_to_text(num):
|
| 179 |
+
"""Helper function to convert an integer to text"""
|
| 180 |
+
|
| 181 |
+
ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
|
| 182 |
+
"ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
|
| 183 |
+
"seventeen", "eighteen", "nineteen"]
|
| 184 |
+
|
| 185 |
+
tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
|
| 186 |
+
|
| 187 |
+
if num < 20:
|
| 188 |
+
return ones[num]
|
| 189 |
+
|
| 190 |
+
if num < 100:
|
| 191 |
+
return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
|
| 192 |
+
|
| 193 |
+
if num < 1000:
|
| 194 |
+
return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
|
| 195 |
+
|
| 196 |
+
if num < 1000000:
|
| 197 |
+
return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
|
| 198 |
+
|
| 199 |
+
if num < 1000000000:
|
| 200 |
+
return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
|
| 201 |
+
|
| 202 |
+
return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _digit_to_text(digit):
|
| 206 |
+
"""Convert a single digit to text"""
|
| 207 |
+
digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
| 208 |
+
return digits[digit]
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
_whitespace_re = re.compile(r"\s+")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
| 215 |
+
_abbreviations = {
|
| 216 |
+
"en": [
|
| 217 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
| 218 |
+
for x in [
|
| 219 |
+
("mrs", "misess"),
|
| 220 |
+
("mr", "mister"),
|
| 221 |
+
("dr", "doctor"),
|
| 222 |
+
("st", "saint"),
|
| 223 |
+
("co", "company"),
|
| 224 |
+
("jr", "junior"),
|
| 225 |
+
("maj", "major"),
|
| 226 |
+
("gen", "general"),
|
| 227 |
+
("drs", "doctors"),
|
| 228 |
+
("rev", "reverend"),
|
| 229 |
+
("lt", "lieutenant"),
|
| 230 |
+
("hon", "honorable"),
|
| 231 |
+
("sgt", "sergeant"),
|
| 232 |
+
("capt", "captain"),
|
| 233 |
+
("esq", "esquire"),
|
| 234 |
+
("ltd", "limited"),
|
| 235 |
+
("col", "colonel"),
|
| 236 |
+
("ft", "fort"),
|
| 237 |
+
]
|
| 238 |
+
],
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def expand_abbreviations_multilingual(text, lang="en"):
    """Expand common abbreviations ("mr." -> "mister", ...) for the given
    language using the precompiled patterns in `_abbreviations`.

    Raises KeyError for languages without an abbreviation table; callers
    wrap this in a best-effort try/except.
    """
    result = text
    for pattern, expansion in _abbreviations[lang]:
        result = pattern.sub(expansion, result)
    return result
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
_symbols_multilingual = {
|
| 249 |
+
"en": [
|
| 250 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
| 251 |
+
for x in [
|
| 252 |
+
("&", " and "),
|
| 253 |
+
("@", " at "),
|
| 254 |
+
("%", " percent "),
|
| 255 |
+
("#", " hash "),
|
| 256 |
+
("$", " dollar "),
|
| 257 |
+
("£", " pound "),
|
| 258 |
+
("°", " degree "),
|
| 259 |
+
]
|
| 260 |
+
],
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def expand_symbols_multilingual(text, lang="en"):
    """Spell out symbols ("&" -> " and ", "%" -> " percent ", ...) for the
    given language using the precompiled patterns in `_symbols_multilingual`.

    Raises KeyError for languages without a symbol table; callers wrap this
    in a best-effort try/except.
    """
    for regex, replacement in _symbols_multilingual[lang]:
        text = re.sub(regex, replacement, text)
    # FIX: was text.replace(" ", " ") — a no-op. The padded replacements can
    # create double spaces; collapse them as the original comment intended.
    text = text.replace("  ", " ")  # Ensure there are no double spaces
    return text.strip()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
_ordinal_re = {
|
| 272 |
+
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
| 273 |
+
}
|
| 274 |
+
_number_re = re.compile(r"[0-9]+")
|
| 275 |
+
_currency_re = {
|
| 276 |
+
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
| 277 |
+
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
| 278 |
+
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
| 282 |
+
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
| 283 |
+
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _remove_commas(m):
|
| 287 |
+
text = m.group(0)
|
| 288 |
+
if "," in text:
|
| 289 |
+
text = text.replace(",", "")
|
| 290 |
+
return text
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _remove_dots(m):
|
| 294 |
+
text = m.group(0)
|
| 295 |
+
if "." in text:
|
| 296 |
+
text = text.replace(".", "")
|
| 297 |
+
return text
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _expand_decimal_point(m, lang="en"):
    """re.sub callback: spell out a decimal number like "3,5" or "3.5"."""
    # Normalize a comma decimal separator to a dot before parsing.
    value = m.group(1).replace(",", ".")
    return number_to_text(float(value))
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _expand_currency(m, lang="en", currency="USD"):
    """re.sub callback: spell out a matched currency amount.

    Note: the `currency` name itself is not spoken; only the numeric value is.
    """
    # Strip the currency symbol and thousands separators, treat ',' as decimal.
    amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
    full_amount = number_to_text(amount)

    and_equivalents = {
        "en": ", ",
        "es": " con ",
        "fr": " et ",
        "de": " und ",
        "pt": " e ",
        "it": " e ",
        "pl": ", ",
        "cs": ", ",
        "ru": ", ",
        "nl": ", ",
        "ar": ", ",
        "tr": ", ",
        "hu": ", ",
        "ko": ", ",
    }

    # NOTE(review): langs missing from and_equivalents (e.g. "ja", "zh") raise
    # KeyError here; the caller's try/except then skips currency expansion
    # entirely for those languages — confirm whether that is intended.
    if amount.is_integer():
        # Whole amounts: drop the trailing ", zero cents"-style fragment that
        # number_to_text produced for the decimal part.
        last_and = full_amount.rfind(and_equivalents[lang])
        if last_and != -1:
            full_amount = full_amount[:last_and]

    return full_amount
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def _expand_ordinal(m, lang="en"):
    """re.sub callback for "1st"/"2nd"-style matches.

    Note: number_to_text currently ignores `ordinal`, so the cardinal form
    is produced.
    """
    return number_to_text(int(m.group(1)), ordinal=True)


def _expand_number(m, lang="en"):
    """re.sub callback: spell out a plain integer match."""
    return number_to_text(int(m.group(0)))
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def expand_numbers_multilingual(text, lang="en"):
    """Spell out currency amounts, decimals, ordinals and plain numbers.

    Order matters: separators are stripped first, then currency, then
    decimals, ordinals, and finally bare integers, so earlier passes are not
    re-matched by later ones.

    Note: `_ordinal_re[lang]` raises KeyError for languages other than "en";
    the caller (multilingual_cleaners) catches this and keeps the text as-is.
    """
    if lang in ["en", "ru"]:
        # Comma is the thousands separator in these locales.
        text = re.sub(_comma_number_re, _remove_commas, text)
    else:
        text = re.sub(_dot_number_re, _remove_dots, text)
    # Currency expansion is best-effort (unknown langs raise inside the helper).
    try:
        text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
        text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
        text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
    except:
        pass

    text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
    text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
    text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
    return text
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def lowercase(text):
    """Lowercase the text (plain str.lower; no locale handling)."""
    return text.lower()


def collapse_whitespace(text):
    """Replace any run of whitespace with a single space."""
    return re.sub(_whitespace_re, " ", text)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def multilingual_cleaners(text, lang):
    """Full cleaning pipeline: strip quotes, lowercase, then best-effort
    expansion of numbers, abbreviations and symbols for `lang`.

    The expansion steps are wrapped in bare try/excepts on purpose: the
    expansion tables only cover some languages, and a failed expansion should
    leave the text unchanged rather than abort tokenization.
    """
    text = text.replace('"', "")
    if lang == "tr":
        # Turkish dotted/dotless-I casing is not handled by str.lower().
        text = text.replace("İ", "i")
        text = text.replace("Ö", "ö")
        text = text.replace("Ü", "ü")
    text = lowercase(text)
    try:
        text = expand_numbers_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_abbreviations_multilingual(text, lang)
    except:
        pass
    try:
        text = expand_symbols_multilingual(text, lang=lang)
    except:
        pass
    text = collapse_whitespace(text)
    return text
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
|
ComfyUI/comfy/text_encoders/aura_t5.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
from .spiece_tokenizer import SPieceTokenizer
|
| 3 |
+
import comfy.text_encoders.t5
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class PT5XlModel(sd1_clip.SDClipModel):
    """Pile-T5-XL text encoder (AuraFlow) wrapped in comfy's SDClipModel."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
        # end=2, pad=1; masked positions are zeroed for this model.
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
|
| 10 |
+
|
| 11 |
+
class PT5XlTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for Pile-T5-XL, padded to at least 256 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
|
| 15 |
+
|
| 16 |
+
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-stream tokenizer wrapper exposing PT5XlTokenizer as 'pile_t5xl'."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)
|
| 19 |
+
|
| 20 |
+
class AuraT5Model(sd1_clip.SD1ClipModel):
    """Single-stream model wrapper exposing PT5XlModel as 'pile_t5xl'."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
ComfyUI/comfy/text_encoders/bert.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from comfy.ldm.modules.attention import optimized_attention_for_device
|
| 3 |
+
import comfy.ops
|
| 4 |
+
|
| 5 |
+
class BertAttention(torch.nn.Module):
    """Self-attention q/k/v projections for a BERT layer.

    Only the three linear projections live here; the attention computation
    itself is delegated to the ``optimized_attention`` callable passed into
    ``forward`` (selected once per encoder pass by BertEncoder).
    """

    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Attention kernel is injected by the caller; mask may be None.
        out = optimized_attention(q, k, v, self.heads, mask)
        return out
|
| 22 |
+
|
| 23 |
+
class BertOutput(torch.nn.Module):
|
| 24 |
+
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
| 27 |
+
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
| 28 |
+
# self.dropout = nn.Dropout(0.0)
|
| 29 |
+
|
| 30 |
+
def forward(self, x, y):
|
| 31 |
+
x = self.dense(x)
|
| 32 |
+
# hidden_states = self.dropout(hidden_states)
|
| 33 |
+
x = self.LayerNorm(x + y)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
class BertAttentionBlock(torch.nn.Module):
    """Self-attention followed by the projection/residual/LayerNorm output step."""

    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        # Attribute names ("self", "output") match the HF BERT state-dict layout.
        self.self = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        y = self.self(x, mask, optimized_attention)
        # x is the residual branch added back onto the projected attention result.
        return self.output(y, x)
|
| 45 |
+
|
| 46 |
+
class BertIntermediate(torch.nn.Module):
|
| 47 |
+
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x = self.dense(x)
|
| 53 |
+
return torch.nn.functional.gelu(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class BertBlock(torch.nn.Module):
    """One full BERT encoder layer: attention block + feed-forward block."""

    def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
        self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
        # Projects the expanded representation back down to embed_dim.
        self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        x = self.attention(x, mask, optimized_attention)
        y = self.intermediate(x)
        # x (post-attention) is the residual branch for the FFN output step.
        return self.output(y, x)
|
| 67 |
+
|
| 68 |
+
class BertEncoder(torch.nn.Module):
    """Stack of BertBlock layers.

    ``forward`` returns ``(final_hidden, intermediate_hidden)`` where the second
    element is a clone of the output of layer ``intermediate_output`` (negative
    indices count from the end), or None when no intermediate was requested.
    """

    def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        # Pick the attention implementation once for the whole stack.
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                # Normalize negative layer indices to absolute positions.
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layer):
            x = l(x, mask, optimized_attention)
            if i == intermediate_output:
                # Clone so later layers don't mutate the captured tensor in place.
                intermediate = x.clone()
        return x, intermediate
|
| 86 |
+
|
| 87 |
+
class BertEmbeddings(torch.nn.Module):
    """Token + position + token-type embeddings followed by a LayerNorm."""

    def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
        self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
        self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)

        self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
        # Pre-computed embeddings (if provided) bypass the word-embedding lookup.
        if embeds is not None:
            x = embeds
        else:
            # out_dtype is a comfy operations.Embedding extension controlling
            # the output dtype of the lookup.
            x = self.word_embeddings(input_tokens, out_dtype=dtype)
        # Positions are implicit: the first seq_len rows of the position table
        # are added (cast_to_input matches dtype/device of x).
        x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
        if token_type_ids is not None:
            x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
        else:
            # Without explicit ids, every token receives type 0's embedding.
            x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
        x = self.LayerNorm(x)
        return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class BertModel_(torch.nn.Module):
    """Bare BERT: embeddings + encoder stack (no pooler or task head)."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        embed_dim = config_dict["hidden_size"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
        self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        # num_tokens and final_layer_norm_intermediate are accepted for
        # interface compatibility with other text encoders but unused here.
        x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Expand the (batch, seq) 0/1 mask to (batch, 1, seq, seq) and turn
            # masked positions into a large negative additive attention bias.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        x, i = self.encoder(x, mask, intermediate_output)
        return x, i
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BertModel(torch.nn.Module):
    """Wrapper that namespaces BertModel_ under the ``bert`` attribute (matching
    checkpoint state-dict prefixes) and exposes the embedding-table accessors."""

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.bert = BertModel_(config_dict, dtype, device, operations)
        self.num_layers = config_dict["num_hidden_layers"]

    def get_input_embeddings(self):
        return self.bert.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.bert.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        # Pure pass-through to the inner model.
        return self.bert(*args, **kwargs)
|
ComfyUI/comfy/text_encoders/cosmos.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import comfy.text_encoders.t5
|
| 3 |
+
import os
|
| 4 |
+
from transformers import T5TokenizerFast
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder (old T5 config) used by Cosmos.

    Attention masking is on by default; when enabled, masks are also returned
    and masked positions are zeroed in the output.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
        if t5xxl_scaled_fp8 is not None:
            # Copy so the caller's dict is not mutated when remapping the key.
            model_options = model_options.copy()
            model_options["scaled_fp8"] = t5xxl_scaled_fp8

        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
| 16 |
+
|
| 17 |
+
class CosmosT5XXL(sd1_clip.SD1ClipModel):
    """Cosmos text encoder: a single T5-XXL model registered as "t5xxl"."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(name="t5xxl", clip_model=T5XXLModel, device=device, dtype=dtype, model_options=model_options)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for Cosmos; pads prompts to at least 512 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(os.path.join(here, "t5_tokenizer"), embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer_class=T5TokenizerFast, embedding_size=1024, embedding_key='t5xxl', has_start_token=False, pad_with_end=False, pad_to_max_length=False, max_length=99999999, min_length=512)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer collection for Cosmos (only the "t5xxl" tokenizer)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(clip_name="t5xxl", tokenizer=T5XXLTokenizer, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a CosmosT5XXL subclass with dtype / scaled-fp8 defaults baked in.

    The returned class keeps the standard text-encoder constructor signature;
    the values captured here only act as fallbacks — explicit ``dtype`` or
    ``model_options`` entries passed at construction time still win.
    """
    class CosmosTEModel_(CosmosT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Write into a copy so the caller's dict stays untouched.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            super().__init__(device=device, dtype=dtype_t5 if dtype is None else dtype, model_options=model_options)
    return CosmosTEModel_
|
ComfyUI/comfy/text_encoders/flux.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import comfy.text_encoders.t5
|
| 3 |
+
import comfy.text_encoders.sd3_clip
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
from transformers import T5TokenizerFast
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for Flux; pads prompts to at least 256 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(os.path.join(here, "t5_tokenizer"), embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer_class=T5TokenizerFast, embedding_size=4096, embedding_key='t5xxl', has_start_token=False, pad_with_end=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FluxTokenizer:
    """Dual tokenizer for Flux: CLIP-L plus T5-XXL, output keyed "l"/"t5xxl"."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        # Same prompt goes through both tokenizers.
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        # Tokenizers carry no learned state.
        return {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FluxClipModel(torch.nn.Module):
    """Flux text-encoder pair: CLIP-L (pooled vector) + T5-XXL (sequence output)."""

    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        # T5 weight dtype may differ from the CLIP dtype.
        dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
        self.dtypes = set([dtype, dtype_t5])

    def set_clip_options(self, options):
        self.clip_l.set_clip_options(options)
        self.t5xxl.set_clip_options(options)

    def reset_clip_options(self):
        self.clip_l.reset_clip_options()
        self.t5xxl.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]

        t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        # Flux conditioning uses the T5 sequence output and the CLIP-L pooled
        # vector; l_out and t5_pooled are intentionally discarded.
        return t5_out, l_pooled

    def load_sd(self, sd):
        # Route the checkpoint to the sub-model it belongs to, identified by a
        # key present only in CLIP text-model state dicts.
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.t5xxl.load_sd(sd)
|
| 62 |
+
|
| 63 |
+
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Return a FluxClipModel subclass with the given T5 dtype / scaled-fp8
    defaults captured; explicit model_options entries still take precedence."""
    class FluxClipModel_(FluxClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Write into a copy so the caller's dict is left untouched.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
    return FluxClipModel_
|
ComfyUI/comfy/text_encoders/genmo.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import comfy.text_encoders.sd3_clip
|
| 3 |
+
import os
|
| 4 |
+
from transformers import T5TokenizerFast
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    """SD3 T5-XXL variant with attention masking always forced on."""

    def __init__(self, **kwargs):
        # Override whatever the caller passed for attention_mask.
        super().__init__(**{**kwargs, "attention_mask": True})
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MochiT5XXL(sd1_clip.SD1ClipModel):
    """Mochi text encoder: a single masked T5-XXL registered as "t5xxl"."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(name="t5xxl", clip_model=T5XXLModel, device=device, dtype=dtype, model_options=model_options)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for Mochi; pads prompts to at least 256 tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(os.path.join(here, "t5_tokenizer"), embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer_class=T5TokenizerFast, embedding_size=4096, embedding_key='t5xxl', has_start_token=False, pad_with_end=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer collection for Mochi (only the "t5xxl" tokenizer)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(clip_name="t5xxl", tokenizer=T5XXLTokenizer, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Build a MochiT5XXL subclass with dtype / scaled-fp8 defaults baked in.

    Values captured here are fallbacks only; explicit ``dtype`` or
    ``model_options`` entries passed at construction time still win.
    """
    class MochiTEModel_(MochiT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                # Write into a copy so the caller's dict stays untouched.
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            super().__init__(device=device, dtype=dtype_t5 if dtype is None else dtype, model_options=model_options)
    return MochiTEModel_
|
ComfyUI/comfy/text_encoders/hidream.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import hunyuan_video
|
| 2 |
+
from . import sd3_clip
|
| 3 |
+
from comfy import sd1_clip
|
| 4 |
+
from comfy import sdxl_clip
|
| 5 |
+
import comfy.model_management
|
| 6 |
+
import torch
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HiDreamTokenizer:
    """Four-way tokenizer for HiDream: CLIP-L, CLIP-G, T5-XXL and Llama-3."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        # T5 is pinned to exactly 128 tokens; llama pads to at least 128.
        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
        self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
        # Keep only the first batch entry — i.e. only the first 128 tokens.
        out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        # Tokenizers hold no learned state.
        return {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class HiDreamTEModel(torch.nn.Module):
    """HiDream text-encoder bundle: optional CLIP-L, CLIP-G, T5-XXL and Llama-3.

    Each sub-encoder can be toggled off; missing encoders are replaced by
    zero tensors of the expected shape at encode time so downstream code
    always receives a consistent (t5_out, pooled, extra) triple.
    """

    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        if llama:
            dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
            if "vocab_size" not in model_options:
                # Fix: write the default into a copy instead of mutating the
                # caller's dict (and the shared mutable default {}), which
                # leaked "vocab_size" into later, unrelated constructions.
                model_options = {**model_options, "vocab_size": 128256}
            self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
            self.dtypes.add(dtype_llama)
        else:
            self.llama = None

        logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))

    def set_clip_options(self, options):
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)
        if self.llama is not None:
            self.llama.set_clip_options(options)

    def reset_clip_options(self):
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()
        if self.llama is not None:
            self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode all token streams; returns (t5_out, pooled, extra).

        extra["conditioning_llama3"] carries the llama hidden states (first
        token stripped); absent encoders contribute zero tensors.
        """
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        token_weight_pairs_llama = token_weight_pairs["llama"]
        lg_out = None
        pooled = None
        extra = {}

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
            if self.clip_l is not None:
                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
            else:
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
            else:
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

            # Pooled conditioning is the concatenated CLIP-L + CLIP-G vectors.
            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]
        else:
            t5_out = None

        if self.llama is not None:
            ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
            ll_out, ll_pooled = ll_output[:2]
            # Drop the first token position from the llama output.
            ll_out = ll_out[:, 1:]
        else:
            ll_out = None

        # Fall back to zero tensors so callers always get consistent shapes.
        if t5_out is None:
            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())

        if ll_out is None:
            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        extra["conditioning_llama3"] = ll_out
        return t5_out, pooled, extra

    def load_sd(self, sd):
        # Route a checkpoint to the matching sub-model via marker keys that
        # exist only in that model's state dict.
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
            return self.t5xxl.load_sd(sd)
        else:
            return self.llama.load_sd(sd)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
    """Return a HiDreamTEModel subclass with the chosen sub-encoders and
    dtype / scaled-fp8 defaults baked in; explicit model_options still win."""
    class HiDreamTEModel_(HiDreamTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            overrides = {}
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                overrides["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                overrides["llama_scaled_fp8"] = llama_scaled_fp8
            if overrides:
                # Merge into a copy so the caller's dict stays untouched.
                model_options = {**model_options, **overrides}
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HiDreamTEModel_
|
ComfyUI/comfy/text_encoders/hunyuan_video.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import comfy.model_management
|
| 3 |
+
import comfy.text_encoders.llama
|
| 4 |
+
from transformers import LlamaTokenizerFast
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
import numbers
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def llama_detect(state_dict, prefix=""):
    """Probe a llama text-encoder state dict for dtype information.

    Returns a dict that may contain:
      - "dtype_llama": dtype of the final norm weight, if present.
      - "llama_scaled_fp8": dtype of the scaled-fp8 marker tensor, if present.
    """
    detected = {}

    norm_key = "{}model.norm.weight".format(prefix)
    if norm_key in state_dict:
        detected["dtype_llama"] = state_dict[norm_key].dtype

    fp8_key = "{}scaled_fp8".format(prefix)
    if fp8_key in state_dict:
        detected["llama_scaled_fp8"] = state_dict[fp8_key].dtype

    return detected
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
    """Llama-3 tokenizer wrapper: start token only, configurable pad token
    and minimum padded length."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
        here = os.path.dirname(os.path.realpath(__file__))
        super().__init__(os.path.join(here, "llama_tokenizer"), embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer_class=LlamaTokenizerFast, embedding_size=4096, embedding_key='llama', has_start_token=True, has_end_token=False, pad_with_end=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, pad_token=pad_token)
|
| 27 |
+
|
| 28 |
+
class LLAMAModel(sd1_clip.SDClipModel):
    """Llama-based text encoder.

    Defaults to the hidden states of the third-from-last layer
    (layer="hidden", layer_idx=-3); attention masks are both applied and
    returned alongside the output.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
        if llama_scaled_fp8 is not None:
            # Copy before remapping the key so the caller's dict is untouched.
            model_options = model_options.copy()
            model_options["scaled_fp8"] = llama_scaled_fp8

        # Config overrides passed through to the Llama2 model class.
        textmodel_json_config = {}
        vocab_size = model_options.get("vocab_size", None)
        if vocab_size is not None:
            textmodel_json_config["vocab_size"] = vocab_size

        model_options = {**model_options, "model_name": "llama"}
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class HunyuanVideoTokenizer:
    """CLIP-L + templated Llama-3 tokenizer for Hunyuan Video.

    The prompt is wrapped in a fixed llama chat template before tokenization.
    Token id 128257 in the llama stream acts as an image-embedding placeholder
    and is replaced in-place by the supplied image embeddings.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

        # Wrap the prompt in the (possibly caller-supplied) chat template.
        if llama_template is None:
            llama_text = self.llama_template.format(text)
        else:
            llama_text = llama_template.format(text)
        llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
        embed_count = 0
        for r in llama_text_tokens:
            for i in range(len(r)):
                # 128257 marks an image slot; swap in the next image embedding,
                # preserving the rest of the token tuple (weight, word id, ...).
                if r[i][0] == 128257:
                    if image_embeds is not None and embed_count < image_embeds.shape[0]:
                        r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
                        embed_count += 1
        out["llama"] = llama_text_tokens
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        # Tokenizers hold no learned state.
        return {}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class HunyuanVideoClipModel(torch.nn.Module):
    """Dual text encoder for Hunyuan Video: a llama model produces the token
    sequence conditioning, while CLIP-L contributes only its pooled output."""

    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        # Resolve the llama weight dtype from the explicit request, the CLIP dtype and the device.
        dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
        self.dtypes = set([dtype, dtype_llama])

    def set_clip_options(self, options):
        self.clip_l.set_clip_options(options)
        self.llama.set_clip_options(options)

    def reset_clip_options(self):
        self.clip_l.reset_clip_options()
        self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode both token streams and strip the chat template from the llama output.

        Returns (llama_sequence, clip_l_pooled, llama_extra_out).
        """
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_llama = token_weight_pairs["llama"]

        llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)

        # template_end: index just past the system/template prefix to drop.
        # user_end: end of the user text; the large sentinel means "slice to the end".
        # extra_sizes / extra_template_end: position offsets caused by multi-token
        # image embeddings that occupy a single placeholder token slot.
        template_end = 0
        extra_template_end = 0
        extra_sizes = 0
        user_end = 9999999999999
        images = []

        # Scan the first batch element's token/weight pairs. The integer ids
        # checked below look like llama3 special tokens (header start/end, "user",
        # eot) matching the template string used by the tokenizer — TODO confirm.
        tok_pairs = token_weight_pairs_llama[0]
        for i, v in enumerate(tok_pairs):
            elem = v[0]
            if not torch.is_tensor(elem):
                if isinstance(elem, numbers.Integral):
                    if elem == 128006:
                        if tok_pairs[i + 1][0] == 882:
                            if tok_pairs[i + 2][0] == 128007:
                                template_end = i + 2
                                user_end = -1
                    if elem == 128009 and user_end == -1:
                        user_end = i + 1
                else:
                    # Dict entries were injected by the tokenizer for image embeddings.
                    if elem.get("original_type") == "image":
                        elem_size = elem.get("data").shape[0]
                        if template_end > 0:
                            if user_end == -1:
                                # Image inside the user section: extend the kept slice.
                                extra_template_end += elem_size - 1
                        else:
                            # Image before the template end: remember its span so it
                            # can be re-attached (interleaved) in front of the output.
                            image_start = i + extra_sizes
                            image_end = i + elem_size + extra_sizes
                            images.append((image_start, image_end, elem.get("image_interleave", 1)))
                            extra_sizes += elem_size - 1

        # Also drop the token right after the template if it is id 271
        # (presumably the newline separator — verify against the tokenizer).
        if llama_out.shape[1] > (template_end + 2):
            if tok_pairs[template_end + 1][0] == 271:
                template_end += 2
        llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
        if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
            llama_extra_out.pop("attention_mask")  # attention mask is useless if no masked elements
        if len(images) > 0:
            # Prepend the (optionally strided, via image_interleave) image spans.
            out = []
            for i in images:
                out.append(llama_out[:, i[0]: i[1]: i[2]])
            llama_output = torch.cat(out + [llama_output], dim=1)

        # NOTE(review): l_out is computed but unused — only CLIP-L's pooled
        # output is returned; confirm this is intentional.
        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return llama_output, l_pooled, llama_extra_out

    def load_sd(self, sd):
        # Route the state dict by a key unique to the CLIP-L text model.
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.llama.load_sd(sd)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
    """Return a HunyuanVideoClipModel subclass with the given llama dtype
    and scaled-fp8 settings baked into its constructor."""
    class HunyuanVideoClipModel_(HunyuanVideoClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            needs_fp8_option = llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options
            if needs_fp8_option:
                # Copy before mutating so the caller's dict is untouched.
                model_options = {**model_options, "llama_scaled_fp8": llama_scaled_fp8}
            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HunyuanVideoClipModel_
|
ComfyUI/comfy/text_encoders/hydit.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
from transformers import BertTokenizer
|
| 3 |
+
from .spiece_tokenizer import SPieceTokenizer
|
| 4 |
+
from .bert import BertModel
|
| 5 |
+
import comfy.text_encoders.t5
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class HyditBertModel(sd1_clip.SDClipModel):
    """HunyuanDiT Chinese BERT text encoder, configured from hydit_clip.json."""
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        # Config file lives next to this module.
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
        model_options = {**model_options, "model_name": "hydit_clip"}
        # special_tokens are the BERT-style start/end/pad token ids; attention
        # masks are both used internally and returned to the caller.
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
| 14 |
+
|
| 15 |
+
class HyditBertTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the HunyuanDiT Chinese BERT (chinese_roberta) encoder."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
        # NOTE(review): embedding_directory is accepted but not forwarded to the
        # base class, so textual-inversion embeddings are not resolved here —
        # confirm whether that is intentional.
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MT5XLModel(sd1_clip.SDClipModel):
    """mT5-XL encoder used by HunyuanDiT, configured from mt5_config_xl.json."""
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
        model_options = {**model_options, "model_name": "mt5xl"}
        # T5-style special tokens: end-of-sequence 1, pad 0; no start token.
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
|
| 26 |
+
|
| 27 |
+
class MT5XLTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for mT5-XL; the spiece model bytes come from
    tokenizer_data instead of a bundled file."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
        tokenizer = tokenizer_data.get("spiece_model", None)
        # max_length is effectively unbounded; output is padded to at least 256 tokens.
        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)

    def state_dict(self):
        # Serialize the SentencePiece model so it can round-trip through a checkpoint.
        return {"spiece_model": self.tokenizer.serialize_model()}
|
| 35 |
+
|
| 36 |
+
class HyditTokenizer:
    """Composite tokenizer for HunyuanDiT: runs both the Chinese BERT
    tokenizer and the mT5-XL SentencePiece tokenizer on the same text."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
        self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
        self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize with both sub-tokenizers; keys match the model's expectations."""
        return {
            "hydit_clip": self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs),
            "mt5xl": self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Detokenize using the BERT-side tokenizer."""
        return self.hydit_clip.untokenize(token_weight_pair)

    def state_dict(self):
        """Persist only the mT5 SentencePiece model (the BERT tokenizer is file-backed)."""
        mt5_state = self.mt5xl.state_dict()
        return {"mt5xl.spiece_model": mt5_state["spiece_model"]}
|
| 53 |
+
|
| 54 |
+
class HyditModel(torch.nn.Module):
    """Dual text encoder for HunyuanDiT: Chinese BERT plus mT5-XL."""
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__()
        # NOTE(review): device is accepted but not forwarded to the sub-models
        # (they use their own default) — confirm whether that is intentional.
        self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
        self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)

        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def encode_token_weights(self, token_weight_pairs):
        """Encode both streams; BERT supplies (cond, pooled), mT5 output and its
        mask are carried in the extra dict under *_mt5xl keys."""
        hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
        mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
        return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}

    def load_sd(self, sd):
        # Route by a key unique to the BERT state dict.
        if "bert.encoder.layer.0.attention.self.query.weight" in sd:
            return self.hydit_clip.load_sd(sd)
        else:
            return self.mt5xl.load_sd(sd)

    def set_clip_options(self, options):
        self.hydit_clip.set_clip_options(options)
        self.mt5xl.set_clip_options(options)

    def reset_clip_options(self):
        self.hydit_clip.reset_clip_options()
        self.mt5xl.reset_clip_options()
|
ComfyUI/comfy/text_encoders/hydit_clip.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"classifier_dropout": null,
|
| 9 |
+
"directionality": "bidi",
|
| 10 |
+
"eos_token_id": 2,
|
| 11 |
+
"hidden_act": "gelu",
|
| 12 |
+
"hidden_dropout_prob": 0.1,
|
| 13 |
+
"hidden_size": 1024,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 4096,
|
| 16 |
+
"layer_norm_eps": 1e-12,
|
| 17 |
+
"max_position_embeddings": 512,
|
| 18 |
+
"model_type": "bert",
|
| 19 |
+
"num_attention_heads": 16,
|
| 20 |
+
"num_hidden_layers": 24,
|
| 21 |
+
"output_past": true,
|
| 22 |
+
"pad_token_id": 0,
|
| 23 |
+
"pooler_fc_size": 768,
|
| 24 |
+
"pooler_num_attention_heads": 12,
|
| 25 |
+
"pooler_num_fc_layers": 3,
|
| 26 |
+
"pooler_size_per_head": 128,
|
| 27 |
+
"pooler_type": "first_token_transform",
|
| 28 |
+
"position_embedding_type": "absolute",
|
| 29 |
+
"torch_dtype": "float32",
|
| 30 |
+
"transformers_version": "4.22.1",
|
| 31 |
+
"type_vocab_size": 2,
|
| 32 |
+
"use_cache": true,
|
| 33 |
+
"vocab_size": 47020
|
| 34 |
+
}
|
| 35 |
+
|
ComfyUI/comfy/text_encoders/llama.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional, Any
|
| 5 |
+
|
| 6 |
+
from comfy.ldm.modules.attention import optimized_attention_for_device
|
| 7 |
+
import comfy.model_management
|
| 8 |
+
import comfy.ldm.common_dit
|
| 9 |
+
|
| 10 |
+
import comfy.model_management
|
| 11 |
+
|
| 12 |
+
@dataclass
class Llama2Config:
    """Hyperparameters for the Llama-3-style 8B transformer.

    Only the annotated attributes are dataclass fields (settable via
    Llama2Config(**config_dict)); the bare assignments below are plain
    class attributes shared by all instances.
    """
    vocab_size: int = 128320
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    transformer_type: str = "llama"
    # Class attributes (no annotation -> not dataclass fields):
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
|
| 28 |
+
|
| 29 |
+
@dataclass
class Qwen25_3BConfig:
    """Hyperparameters for Qwen 2.5 3B (llama-style transformer with QKV bias).

    As with Llama2Config, the un-annotated attributes are class attributes,
    not dataclass fields.
    """
    vocab_size: int = 151936
    hidden_size: int = 2048
    intermediate_size: int = 11008
    num_hidden_layers: int = 36
    num_attention_heads: int = 16
    num_key_value_heads: int = 2
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = True
|
| 45 |
+
|
| 46 |
+
@dataclass
class Gemma2_2B_Config:
    """Hyperparameters for Gemma 2 2B.

    transformer_type "gemma2" selects TransformerBlockGemma2 and enables
    input-embedding scaling in Llama2_; rms_norm_add=True makes RMSNorm use
    (weight + 1). The un-annotated attributes are class attributes, not
    dataclass fields.
    """
    vocab_size: int = 256000
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    transformer_type: str = "gemma2"
    head_dim = 256
    rms_norm_add = True
    mlp_activation = "gelu_pytorch_tanh"
    qkv_bias = False
|
| 62 |
+
|
| 63 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    When ``add`` is True the stored weight is offset by 1.0 before use
    (the parameterization selected by the Gemma2 config in this file).
    """

    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        self.add = add

    def forward(self, x: torch.Tensor):
        scale = self.weight + 1.0 if self.add else self.weight
        return comfy.ldm.common_dit.rms_norm(x, scale, self.eps)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def rotate_half(x):
    """Rotate the last dimension: the (a, b) halves become (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
    """Precompute RoPE cos/sin tables, each of shape (1, seq_len, head_dim).

    Frequencies follow the standard rotary scheme: inv_freq[j] = theta^(-2j/head_dim),
    duplicated so both halves of the head dimension share the same angles.
    """
    even_indices = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (even_indices / head_dim))

    positions = torch.arange(0, seq_len, device=device).float()
    # Outer product position x frequency -> (seq_len, head_dim // 2), then add batch axis.
    angles = torch.outer(positions, inv_freq).unsqueeze(0)
    table = torch.cat((angles, angles), dim=-1)
    return (table.cos(), table.sin())
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def apply_rope(xq, xk, freqs_cis):
    """Apply rotary position embeddings to query and key tensors.

    freqs_cis is the (cos, sin) pair from precompute_freqs_cis; a heads axis
    is inserted so the tables broadcast over (batch, heads, seq, head_dim).
    """
    cos, sin = freqs_cis
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    rotated_q = xq * cos + rotate_half(xq) * sin
    rotated_k = xk * cos + rotate_half(xk) * sin
    return rotated_q, rotated_k
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and grouped-query KV heads."""

    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.head_dim = config.head_dim
        # Total projection width; may differ from hidden_size when head_dim is overridden.
        self.inner_size = self.num_heads * self.head_dim

        # ops supplies the Linear implementation (e.g. comfy's patched ops); default to torch.nn.
        ops = ops or nn
        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        """Run attention. optimized_attention is the backend-specific kernel
        chosen by the caller; freqs_cis is the (cos, sin) RoPE pair."""
        batch_size, seq_length, _ = hidden_states.shape
        xq = self.q_proj(hidden_states)
        xk = self.k_proj(hidden_states)
        xv = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        # Grouped-query attention: replicate each KV head to match the query head count.
        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(output)
|
| 148 |
+
|
| 149 |
+
class MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: "Llama2Config", device=None, dtype=None, ops: Any = None):
        super().__init__()
        ops = ops or nn
        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
        if config.mlp_activation == "silu":
            self.activation = torch.nn.functional.silu
        elif config.mlp_activation == "gelu_pytorch_tanh":
            self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
        else:
            # Previously an unknown activation left self.activation unset and
            # failed later with an opaque AttributeError in forward(); fail
            # loudly at construction time instead.
            raise ValueError("unsupported mlp_activation: {}".format(config.mlp_activation))

    def forward(self, x):
        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
|
| 163 |
+
|
| 164 |
+
class TransformerBlock(nn.Module):
    """Pre-norm llama-style block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Attention sub-layer with residual connection.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(x),
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = x + attn_out

        # Feed-forward sub-layer with residual connection.
        return x + self.mlp(self.post_attention_layernorm(x))
|
| 197 |
+
|
| 198 |
+
class TransformerBlockGemma2(nn.Module):
    """Gemma-2 style transformer block.

    Unlike TransformerBlock, each sub-layer's output is normalized *before*
    the residual add (post_attention_layernorm / post_feedforward_layernorm),
    and the MLP gets its own pre-norm (pre_feedforward_layernorm). All norms
    use the config's rms_norm_add parameterization.
    """
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Self Attention (pre-norm, then post-norm before the residual add)
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )

        x = self.post_attention_layernorm(x)
        x = residual + x

        # MLP (pre-norm, then post-norm before the residual add)
        residual = x
        x = self.pre_feedforward_layernorm(x)
        x = self.mlp(x)
        x = self.post_feedforward_layernorm(x)
        x = residual + x

        return x
|
| 236 |
+
|
| 237 |
+
class Llama2_(nn.Module):
    """Decoder-only transformer body shared by the Llama/Qwen/Gemma wrappers."""

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = ops.Embedding(
            config.vocab_size,
            config.hidden_size,
            device=device,
            dtype=dtype
        )
        # Gemma2 uses its own block variant and scales input embeddings by sqrt(hidden_size).
        if self.config.transformer_type == "gemma2":
            transformer = TransformerBlockGemma2
            self.normalize_in = True
        else:
            transformer = TransformerBlock
            self.normalize_in = False

        self.layers = nn.ModuleList([
            transformer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        """Run the transformer.

        Either token ids `x` or precomputed `embeds` may be supplied.
        intermediate_output: layer index whose output to also return (negative
        indexes from the end), or "all" to return every layer's input plus the
        final normed output stacked along dim 1.
        Returns (final_hidden, intermediate_or_None).
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        if self.normalize_in:
            # NOTE(review): in-place multiply — when `embeds` is passed this
            # mutates the caller's tensor; confirm callers don't reuse it.
            x *= self.config.hidden_size ** 0.5

        freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                         x.shape[1],
                                         self.config.rope_theta,
                                         device=x.device)

        # Build an additive attention mask: padding positions (0 in
        # attention_mask) become -inf, combined with a causal upper-triangular -inf mask.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        if intermediate_output is not None:
            if intermediate_output == "all":
                all_intermediate = []
                intermediate_output = None
            elif intermediate_output < 0:
                # Negative index counts back from the last layer.
                intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                # In "all" mode, record each layer's *input* hidden states.
                all_intermediate.append(x.unsqueeze(1).clone())
            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )
            if i == intermediate_output:
                intermediate = x.clone()

        x = self.norm(x)
        if all_intermediate is not None:
            # Append the final normed output as the last entry.
            all_intermediate.append(x.unsqueeze(1).clone())

        if all_intermediate is not None:
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.norm(intermediate)

        return x, intermediate
|
| 321 |
+
|
| 322 |
+
class BaseLlama:
    """Mixin that forwards embedding access and the forward pass to ``self.model``."""

    def get_input_embeddings(self):
        """Return the wrapped model's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, embeddings):
        """Replace the wrapped model's token-embedding module."""
        self.model.embed_tokens = embeddings

    def forward(self, input_ids, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        inner_model = self.model
        return inner_model(input_ids, *args, **kwargs)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class Llama2(BaseLlama, torch.nn.Module):
    """Llama-3-style text encoder wrapper around the shared Llama2_ body."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        # config_dict may only override the annotated dataclass fields of Llama2Config.
        config = Llama2Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
|
| 341 |
+
|
| 342 |
+
class Qwen25_3B(BaseLlama, torch.nn.Module):
    """Qwen 2.5 3B text encoder wrapper around the shared Llama2_ body."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Qwen25_3BConfig(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
|
| 350 |
+
|
| 351 |
+
class Gemma2_2B(BaseLlama, torch.nn.Module):
    """Gemma 2 2B text encoder wrapper around the shared Llama2_ body."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Gemma2_2B_Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
|
ComfyUI/comfy/text_encoders/long_clipl.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
def model_options_long_clip(sd, tokenizer_data, model_options):
    """Detect a long-CLIP text encoder in a state dict and propagate its
    extended context length.

    Looks for a position-embedding weight under a "clip_l."/"clip_g." prefix,
    or unprefixed (detecting l vs g by encoder depth), and records its length
    in both model_options ("<name>_model_config") and tokenizer_data
    ("<name>_max_length"). Returns the (possibly copied and updated)
    (tokenizer_data, model_options) pair; inputs are returned unchanged when
    no position embedding is found.
    """
    # Bug fix: previously a clip_l-prefixed weight set model_name to "clip_g"
    # (writing the config under the wrong keys), and a clip_g-prefixed weight
    # left model_name unbound, raising NameError at the format() calls below.
    model_name = "clip_l"
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
        if w is not None:
            model_name = "clip_g"

    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
        if w is not None:
            # Unprefixed checkpoint: a layer-30 key means the deeper clip_g
            # encoder; otherwise assume clip_l.
            if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
                model_name = "clip_g"
            else:
                model_name = "clip_l"

    if w is not None:
        tokenizer_data = tokenizer_data.copy()
        model_options = model_options.copy()
        model_config = model_options.get("model_config", {})
        model_config["max_position_embeddings"] = w.shape[0]
        model_options["{}_model_config".format(model_name)] = model_config
        tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
    return tokenizer_data, model_options
|
ComfyUI/comfy/text_encoders/lt.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import os
|
| 3 |
+
from transformers import T5TokenizerFast
|
| 4 |
+
import comfy.text_encoders.genmo
|
| 5 |
+
|
| 6 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer: no start token, effectively unbounded max length,
    output padded to at least 128 tokens."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
    """LTX-Video tokenizer: a single T5-XXL tokenizer registered as "t5xxl"."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ltxv_te(*args, **kwargs):
    """LTX-Video reuses the same text-encoder builder as Mochi (genmo)."""
    mochi_builder = comfy.text_encoders.genmo.mochi_te
    return mochi_builder(*args, **kwargs)
|
ComfyUI/comfy/text_encoders/lumina2.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
from .spiece_tokenizer import SPieceTokenizer
|
| 3 |
+
import comfy.text_encoders.llama
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
    """SentencePiece tokenizer for Gemma 2 2B; adds BOS but no EOS token."""
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The SentencePiece model bytes are supplied via tokenizer_data rather
        # than a bundled file path.
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

    def state_dict(self):
        # Serialize the SentencePiece model so it can round-trip through a checkpoint.
        return {"spiece_model": self.tokenizer.serialize_model()}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
    """Single-encoder tokenizer wrapper exposing Gemma-2 2B under the
    "gemma2_2b" name for Lumina models."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Gemma2_2BModel(sd1_clip.SDClipModel):
    """Gemma-2 2B text encoder; by default returns the penultimate hidden layer
    and propagates attention masks through to the conditioning."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LuminaModel(sd1_clip.SD1ClipModel):
    """Lumina text-encoder stack: a single Gemma-2 2B encoder."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def te(dtype_llama=None, llama_scaled_fp8=None):
    """Return a LuminaModel subclass with the Gemma dtype and scaled-fp8
    settings captured in the closure; the loader instantiates it later."""
    class LuminaTEModel_(LuminaModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            # Inject the captured scaled-fp8 dtype unless the caller already set one.
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = {**model_options, "scaled_fp8": llama_scaled_fp8}
            # A captured dtype overrides whatever the loader passed in.
            dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return LuminaTEModel_
|
ComfyUI/comfy/text_encoders/mt5_config_xl.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"d_ff": 5120,
|
| 3 |
+
"d_kv": 64,
|
| 4 |
+
"d_model": 2048,
|
| 5 |
+
"decoder_start_token_id": 0,
|
| 6 |
+
"dropout_rate": 0.1,
|
| 7 |
+
"eos_token_id": 1,
|
| 8 |
+
"dense_act_fn": "gelu_pytorch_tanh",
|
| 9 |
+
"initializer_factor": 1.0,
|
| 10 |
+
"is_encoder_decoder": true,
|
| 11 |
+
"is_gated_act": true,
|
| 12 |
+
"layer_norm_epsilon": 1e-06,
|
| 13 |
+
"model_type": "mt5",
|
| 14 |
+
"num_decoder_layers": 24,
|
| 15 |
+
"num_heads": 32,
|
| 16 |
+
"num_layers": 24,
|
| 17 |
+
"output_past": true,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"relative_attention_num_buckets": 32,
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"vocab_size": 250112
|
| 22 |
+
}
|
ComfyUI/comfy/text_encoders/omnigen2.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2Tokenizer
|
| 2 |
+
from comfy import sd1_clip
|
| 3 |
+
import comfy.text_encoders.llama
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
    """Qwen 2.5 3B tokenizer (no BOS/EOS, pad token 151643) used by OmniGen2."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # Tokenizer files live next to this module.
        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
        super().__init__(path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
    """OmniGen2 tokenizer: wraps the Qwen 2.5 3B tokenizer and applies a
    chat-style system/user template around the prompt before tokenizing."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
        self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
        # Use the default system-prompt template unless the caller supplies one.
        template = self.llama_template if llama_template is None else llama_template
        return super().tokenize_with_weights(template.format(text), return_word_ids=return_word_ids, **kwargs)
|
| 24 |
+
|
| 25 |
+
class Qwen25_3BModel(sd1_clip.SDClipModel):
    """Qwen 2.5 3B text encoder; returns the last hidden layer and passes
    attention masks through to the conditioning."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Omnigen2Model(sd1_clip.SD1ClipModel):
    """OmniGen2 text-encoder stack: a single Qwen 2.5 3B encoder."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def te(dtype_llama=None, llama_scaled_fp8=None):
    """Return an Omnigen2Model subclass with the Qwen dtype and scaled-fp8
    settings captured in the closure; the loader instantiates it later."""
    class Omnigen2TEModel_(Omnigen2Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            # Inject the captured scaled-fp8 dtype unless the caller already set one.
            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                model_options = {**model_options, "scaled_fp8": llama_scaled_fp8}
            # A captured dtype overrides whatever the loader passed in.
            dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return Omnigen2TEModel_
|
ComfyUI/comfy/text_encoders/pixart_t5.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from comfy import sd1_clip
|
| 4 |
+
import comfy.text_encoders.t5
|
| 5 |
+
import comfy.text_encoders.sd3_clip
|
| 6 |
+
from comfy.sd1_clip import gen_empty_tokens
|
| 7 |
+
|
| 8 |
+
from transformers import T5TokenizerFast
|
| 9 |
+
|
| 10 |
+
class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    """T5-XXL encoder variant for PixArt.

    Identical to the SD3 T5-XXL model except that empty (negative) prompts are
    encoded as all pad tokens rather than pad tokens plus an end token.
    The redundant pass-through __init__(**kwargs) was removed; the inherited
    constructor is used directly (same signature, same behavior).
    """

    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
        # PixArt expects the negative to be all pad tokens
        special_tokens = special_tokens.copy()
        special_tokens.pop("end")
        return gen_empty_tokens(special_tokens, *args, **kwargs)
|
| 19 |
+
|
| 20 |
+
class PixArtT5XXL(sd1_clip.SD1ClipModel):
    """PixArt text-encoder stack: a single T5-XXL encoder."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
|
| 23 |
+
|
| 24 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for PixArt; unlike other pipelines it applies no
    minimum-length padding (min_length=1)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)  # no padding
|
| 28 |
+
|
| 29 |
+
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    """Single-encoder tokenizer wrapper exposing the PixArt T5-XXL tokenizer
    under the "t5xxl" clip name."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
| 32 |
+
|
| 33 |
+
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    """Return a PixArtT5XXL subclass with the T5 dtype and scaled-fp8 settings
    captured in the closure; the loader instantiates it later."""
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            # Inject the captured scaled-fp8 dtype unless the caller already set one.
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            # The captured T5 dtype is only a fallback when no dtype was passed.
            dtype = dtype_t5 if dtype is None else dtype
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return PixArtTEModel_
|
ComfyUI/comfy/text_encoders/sd2_clip.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
class SD2ClipHModel(sd1_clip.SDClipModel):
    """CLIP ViT-H text encoder used by SD 2.x.

    The "penultimate" layer name is translated to hidden layer index -2,
    matching the open_clip convention SD2 checkpoints were trained with.
    """

    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, model_options={}):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2

        config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=config_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
| 12 |
+
|
| 13 |
+
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
    """CLIP-H tokenizer for SD 2.x (1024-dim embeddings, "clip_h" embedding key)."""

    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
|
| 16 |
+
|
| 17 |
+
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    """Single-encoder tokenizer wrapper exposing CLIP-H under the "h" clip name."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
| 20 |
+
|
| 21 |
+
class SD2ClipModel(sd1_clip.SD1ClipModel):
    """SD2 text-encoder stack: a single CLIP-H encoder."""

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
ComfyUI/comfy/text_encoders/sd2_clip_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CLIPTextModel"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"eos_token_id": 49407,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_size": 1024,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 4096,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"max_position_embeddings": 77,
|
| 16 |
+
"model_type": "clip_text_model",
|
| 17 |
+
"num_attention_heads": 16,
|
| 18 |
+
"num_hidden_layers": 24,
|
| 19 |
+
"pad_token_id": 1,
|
| 20 |
+
"projection_dim": 1024,
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"vocab_size": 49408
|
| 23 |
+
}
|
ComfyUI/comfy/text_encoders/sd3_clip.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy import sd1_clip
|
| 2 |
+
from comfy import sdxl_clip
|
| 3 |
+
from transformers import T5TokenizerFast
|
| 4 |
+
import comfy.text_encoders.t5
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
import comfy.model_management
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder.

    If model_options carries "t5xxl_scaled_fp8", it is mapped onto the generic
    "scaled_fp8" option so only the T5 sub-model is affected by it.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
        config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
        if t5xxl_scaled_fp8 is not None:
            model_options = model_options.copy()
            model_options["scaled_fp8"] = t5xxl_scaled_fp8

        model_options = {**model_options, "model_name": "t5xxl"}
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config_path, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def t5_xxl_detect(state_dict, prefix=""):
    """Inspect a state dict for T5-XXL weights under *prefix*.

    Returns a dict that may contain:
      - "dtype_t5": dtype of the encoder weights, if a T5 encoder is present
      - "t5xxl_scaled_fp8": dtype of the scaled-fp8 marker tensor, if present
    """
    detected = {}

    norm_weight = state_dict.get(prefix + "encoder.final_layer_norm.weight", None)
    if norm_weight is not None:
        detected["dtype_t5"] = norm_weight.dtype

    scaled_fp8 = state_dict.get(prefix + "scaled_fp8", None)
    if scaled_fp8 is not None:
        detected["t5xxl_scaled_fp8"] = scaled_fp8.dtype

    return detected
|
| 33 |
+
|
| 34 |
+
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """T5-XXL tokenizer for SD3; pads to a configurable minimum length
    (default 77 to line up with the CLIP encoders)."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SD3Tokenizer:
    """Joint tokenizer for SD3's three text encoders (CLIP-L, CLIP-G, T5-XXL).

    tokenize_with_weights tokenizes the same prompt once per encoder and
    returns a dict keyed by encoder name.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        return {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        # CLIP-G is used as the reference tokenizer for reversing tokens.
        return self.clip_g.untokenize(token_weight_pair)

    def state_dict(self):
        return {}
|
| 58 |
+
|
| 59 |
+
class SD3ClipModel(torch.nn.Module):
    """Combined SD3 text encoder: any subset of CLIP-L, CLIP-G and T5-XXL.

    Disabled sub-encoders are stored as None and their contributions are
    replaced with zeros when encoding.
    """
    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
        super().__init__()
        # Set of weight dtypes in use across the enabled sub-models.
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
            # T5 may run at a different weight dtype than the CLIP models.
            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
            self.t5_attention_mask = t5_attention_mask
            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))

    def set_clip_options(self, options):
        # Forward options to whichever sub-encoders are enabled.
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)

    def reset_clip_options(self):
        # Restore default options on whichever sub-encoders are enabled.
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode per-encoder token/weight pairs into (cond, pooled, extra).

        CLIP-L (768) and CLIP-G (1280) hidden states are concatenated on the
        feature axis and zero-padded up to T5's 4096 channels; the T5 output is
        then concatenated along the sequence axis. Missing encoders contribute
        zeros so the output shapes stay consistent.
        """
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        lg_out = None
        pooled = None
        out = None
        extra = {}

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
            if self.clip_l is not None:
                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
            else:
                # No CLIP-L: zero pooled vector of CLIP-L's width.
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
                if lg_out is not None:
                    # Sequence lengths can differ; truncate both to the shorter.
                    cut_to = min(lg_out.shape[1], g_out.shape[1])
                    lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
                else:
                    # No CLIP-L output: left-pad features with 768 zeros in its place.
                    lg_out = torch.nn.functional.pad(g_out, (768, 0))
            else:
                g_out = None
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

            if lg_out is not None:
                # Pad the combined CLIP features up to T5's 4096 channels.
                lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
                out = lg_out
            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        if self.t5xxl is not None:
            t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            t5_out, t5_pooled = t5_output[:2]
            if self.t5_attention_mask:
                extra["attention_mask"] = t5_output[2]["attention_mask"]

            if lg_out is not None:
                # Append T5 tokens after the CLIP tokens along the sequence axis.
                out = torch.cat([lg_out, t5_out], dim=-2)
            else:
                out = t5_out

        if out is None:
            out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        return out, pooled, extra

    def load_sd(self, sd):
        # Route a raw state dict to the matching sub-model: layer 30 only exists
        # in CLIP-G, layer 1 in the smaller CLIP-L; anything else is assumed T5.
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.t5xxl.load_sd(sd)
|
| 158 |
+
|
| 159 |
+
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
    """Return an SD3ClipModel subclass with the encoder selection and T5
    settings captured in the closure; the loader instantiates it later."""
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            # Inject the captured scaled-fp8 dtype unless the caller already set one.
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = {**model_options, "t5xxl_scaled_fp8": t5xxl_scaled_fp8}
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
    return SD3ClipModel_
|
ComfyUI/comfy/text_encoders/t5.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
from comfy.ldm.modules.attention import optimized_attention_for_device
|
| 4 |
+
import comfy.ops
|
| 5 |
+
|
| 6 |
+
class T5LayerNorm(torch.nn.Module):
    """T5-style RMSNorm: scale by 1/sqrt(mean(x^2) + eps) times a learned
    weight — no mean subtraction and no bias.

    The `operations` parameter is accepted for interface consistency with the
    other layers in this file but is not used here.
    """
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # cast_to_input moves/casts the weight to x's device and dtype before scaling.
        return comfy.ops.cast_to_input(self.weight, x) * x
|
| 16 |
+
|
| 17 |
+
# Feed-forward activation functions, keyed by the name used in T5/mT5 configs
# ("dense_act_fn"). "gelu_pytorch_tanh" is the tanh-approximated GELU.
activations = {
    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
    "relu": torch.nn.functional.relu,
}
|
| 21 |
+
|
| 22 |
+
class T5DenseActDense(torch.nn.Module):
    """Non-gated T5 feed-forward: wo(act(wi(x))). Dropout from the reference
    implementation is intentionally omitted (inference only)."""

    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
        super().__init__()
        self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.act = activations[ff_activation]

    def forward(self, x):
        # Project up, apply the activation, project back down.
        return self.wo(self.act(self.wi(x)))
|
| 35 |
+
|
| 36 |
+
class T5DenseGatedActDense(torch.nn.Module):
    """Gated T5 feed-forward (GEGLU-style): wo(act(wi_0(x)) * wi_1(x)).
    Dropout from the reference implementation is intentionally omitted."""

    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
        super().__init__()
        self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.act = activations[ff_activation]

    def forward(self, x):
        # Gate the linear branch with the activated branch, then project down.
        gated = self.act(self.wi_0(x)) * self.wi_1(x)
        return self.wo(gated)
|
| 52 |
+
|
| 53 |
+
class T5LayerFF(torch.nn.Module):
    """Pre-norm feed-forward block with residual: x + FF(norm(x)).

    gated_act selects between the gated (mT5/flan style) and plain dense
    feed-forward. Dropout from the reference implementation is omitted.
    """
    def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
        super().__init__()
        if gated_act:
            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
        else:
            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)

        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        # NOTE: in-place residual add — mutates the input tensor, presumably to
        # save memory during inference; do not pass tensors that must stay intact.
        x += forwarded_states
        return x
|
| 70 |
+
|
| 71 |
+
class T5Attention(torch.nn.Module):
    """T5 self-attention with optional learned relative-position bias.

    Only the first block of a T5 stack owns the bias table
    (relative_attention_bias=True); later blocks reuse the bias computed there
    via the past_bias argument.
    """
    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = operations.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = operations.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads

        self.relative_attention_bias = None
        if relative_attention_bias:
            # Bucket/distance constants match the standard T5 configuration.
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Split the bucket space: one half for positive offsets, one for negative.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device, dtype):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values.contiguous()

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            # First block computes the shared bias; it is returned so later
            # blocks can reuse it as past_bias.
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

        if past_bias is not None:
            # Fold the position bias into the additive attention mask.
            if mask is not None:
                mask = mask + past_bias
            else:
                mask = past_bias

        # Scaling k by sqrt(head_dim) presumably cancels the 1/sqrt(head_dim)
        # applied inside optimized_attention, since T5 uses unscaled dot-product
        # attention (see the Mesh TF init comment above) — confirm against the
        # attention backend.
        out = optimized_attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
        return self.o(out), past_bias
|
| 166 |
+
|
| 167 |
+
class T5LayerSelfAttention(torch.nn.Module):
    """Pre-norm self-attention block with residual: x + Attn(norm(x)).

    Returns the (possibly newly computed) relative-position bias so the caller
    can thread it through the remaining blocks of the stack.
    """
    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()
        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device, operations)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
        # NOTE: in-place residual add — mutates the input tensor, presumably to
        # save memory during inference.
        x += output
        return x, past_bias
|
| 179 |
+
|
| 180 |
+
class T5Block(torch.nn.Module):
    """One T5 encoder block: self-attention sub-layer followed by feed-forward."""
    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
        super().__init__()
        sublayers = torch.nn.ModuleList()
        sublayers.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
        sublayers.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
        self.layer = sublayers

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
        """Run attention then feed-forward; propagate the relative-position bias."""
        hidden, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
        hidden = self.layer[-1](hidden)
        return hidden, past_bias
|
| 191 |
+
|
| 192 |
+
class T5Stack(torch.nn.Module):
    """Stack of T5 encoder blocks with a final layer norm.

    When `relative_attention` is True (plain "t5"), only the first block owns
    a relative attention bias and it is shared down the stack via `past_bias`;
    when False ("umt5"), every block computes its own bias.
    """
    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention, dtype, device, operations):
        super().__init__()

        self.block = torch.nn.ModuleList(
            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=((not relative_attention) or (i == 0)), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
        # self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        # Expand the (batch, seq) 0/1 attention mask to an additive bias of
        # shape (batch, 1, seq, seq): masked positions become -finfo.max.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        intermediate = None
        optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
        past_bias = None

        # Normalize a negative layer index to its positive equivalent.
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.block) + intermediate_output

        for i, l in enumerate(self.block):
            x, past_bias = l(x, mask, past_bias, optimized_attention)
            if i == intermediate_output:
                # Snapshot this layer's hidden states for callers that want
                # an intermediate representation.
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate
|
| 224 |
+
|
| 225 |
+
class T5(torch.nn.Module):
    """Encoder-only T5/umT5 text encoder built from an HF-style config dict."""
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        model_dim = config_dict["d_model"]
        # Attention inner dimension = per-head dim * number of heads.
        inner_dim = config_dict["d_kv"] * config_dict["num_heads"]

        # model_type "umt5" => per-layer relative attention bias; plain "t5"
        # shares the first layer's bias across the stack.
        self.encoder = T5Stack(self.num_layers, model_dim, inner_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
        self.dtype = dtype
        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

    def get_input_embeddings(self):
        # Token embedding table.
        return self.shared

    def set_input_embeddings(self, embeddings):
        self.shared = embeddings

    def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
        # Accept either raw token ids or precomputed embeddings.
        if input_ids is None:
            x = embeds
        else:
            x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
        if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            x = torch.nan_to_num(x) #Fix for fp8 T5 base
        return self.encoder(x, attention_mask=attention_mask, **kwargs)
|
ComfyUI/comfy/text_encoders/t5_config_base.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"d_ff": 3072,
|
| 3 |
+
"d_kv": 64,
|
| 4 |
+
"d_model": 768,
|
| 5 |
+
"decoder_start_token_id": 0,
|
| 6 |
+
"dropout_rate": 0.1,
|
| 7 |
+
"eos_token_id": 1,
|
| 8 |
+
"dense_act_fn": "relu",
|
| 9 |
+
"initializer_factor": 1.0,
|
| 10 |
+
"is_encoder_decoder": true,
|
| 11 |
+
"is_gated_act": false,
|
| 12 |
+
"layer_norm_epsilon": 1e-06,
|
| 13 |
+
"model_type": "t5",
|
| 14 |
+
"num_decoder_layers": 12,
|
| 15 |
+
"num_heads": 12,
|
| 16 |
+
"num_layers": 12,
|
| 17 |
+
"output_past": true,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"relative_attention_num_buckets": 32,
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"vocab_size": 32128
|
| 22 |
+
}
|
ComfyUI/comfy/text_encoders/t5_config_xxl.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"d_ff": 10240,
|
| 3 |
+
"d_kv": 64,
|
| 4 |
+
"d_model": 4096,
|
| 5 |
+
"decoder_start_token_id": 0,
|
| 6 |
+
"dropout_rate": 0.1,
|
| 7 |
+
"eos_token_id": 1,
|
| 8 |
+
"dense_act_fn": "gelu_pytorch_tanh",
|
| 9 |
+
"initializer_factor": 1.0,
|
| 10 |
+
"is_encoder_decoder": true,
|
| 11 |
+
"is_gated_act": true,
|
| 12 |
+
"layer_norm_epsilon": 1e-06,
|
| 13 |
+
"model_type": "t5",
|
| 14 |
+
"num_decoder_layers": 24,
|
| 15 |
+
"num_heads": 64,
|
| 16 |
+
"num_layers": 24,
|
| 17 |
+
"output_past": true,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"relative_attention_num_buckets": 32,
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"vocab_size": 32128
|
| 22 |
+
}
|
ComfyUI/comfy/text_encoders/t5_old_config_xxl.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"d_ff": 65536,
|
| 3 |
+
"d_kv": 128,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"decoder_start_token_id": 0,
|
| 6 |
+
"dropout_rate": 0.1,
|
| 7 |
+
"eos_token_id": 1,
|
| 8 |
+
"dense_act_fn": "relu",
|
| 9 |
+
"initializer_factor": 1.0,
|
| 10 |
+
"is_encoder_decoder": true,
|
| 11 |
+
"is_gated_act": false,
|
| 12 |
+
"layer_norm_epsilon": 1e-06,
|
| 13 |
+
"model_type": "t5",
|
| 14 |
+
"num_decoder_layers": 24,
|
| 15 |
+
"num_heads": 128,
|
| 16 |
+
"num_layers": 24,
|
| 17 |
+
"output_past": true,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"relative_attention_num_buckets": 32,
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"vocab_size": 32128
|
| 22 |
+
}
|
ComfyUI/comfy/text_encoders/umt5_config_base.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"d_ff": 2048,
|
| 3 |
+
"d_kv": 64,
|
| 4 |
+
"d_model": 768,
|
| 5 |
+
"decoder_start_token_id": 0,
|
| 6 |
+
"dropout_rate": 0.1,
|
| 7 |
+
"eos_token_id": 1,
|
| 8 |
+
"dense_act_fn": "gelu_pytorch_tanh",
|
| 9 |
+
"initializer_factor": 1.0,
|
| 10 |
+
"is_encoder_decoder": true,
|
| 11 |
+
"is_gated_act": true,
|
| 12 |
+
"layer_norm_epsilon": 1e-06,
|
| 13 |
+
"model_type": "umt5",
|
| 14 |
+
"num_decoder_layers": 12,
|
| 15 |
+
"num_heads": 12,
|
| 16 |
+
"num_layers": 12,
|
| 17 |
+
"output_past": true,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"relative_attention_num_buckets": 32,
|
| 20 |
+
"tie_word_embeddings": false,
|
| 21 |
+
"vocab_size": 256384
|
| 22 |
+
}
|
ComfyUI/comfy/utils.py
ADDED
|
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of ComfyUI.
|
| 3 |
+
Copyright (C) 2024 Comfy
|
| 4 |
+
|
| 5 |
+
This program is free software: you can redistribute it and/or modify
|
| 6 |
+
it under the terms of the GNU General Public License as published by
|
| 7 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 8 |
+
(at your option) any later version.
|
| 9 |
+
|
| 10 |
+
This program is distributed in the hope that it will be useful,
|
| 11 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 12 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 13 |
+
GNU General Public License for more details.
|
| 14 |
+
|
| 15 |
+
You should have received a copy of the GNU General Public License
|
| 16 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import math
|
| 22 |
+
import struct
|
| 23 |
+
import comfy.checkpoint_pickle
|
| 24 |
+
import safetensors.torch
|
| 25 |
+
import numpy as np
|
| 26 |
+
from PIL import Image
|
| 27 |
+
import logging
|
| 28 |
+
import itertools
|
| 29 |
+
from torch.nn.functional import interpolate
|
| 30 |
+
from einops import rearrange
|
| 31 |
+
from comfy.cli_args import args
|
| 32 |
+
|
| 33 |
+
# Command-line flags controlling how checkpoint tensors are read from disk.
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap

# On pytorch >= 2.4, register the handful of globals that legitimate
# lightning-era checkpoints reference so every torch.load can run with
# weights_only=True (no arbitrary pickle code execution).
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
    # Stand-in class so checkpoints pickled by pytorch_lightning resolve.
    class ModelCheckpoint:
        pass
    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

    from numpy.core.multiarray import scalar
    from numpy import dtype
    from numpy.dtypes import Float64DType
    from _codecs import encode

    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
    ALWAYS_SAFE_LOAD = True
    logging.info("Checkpoint files will always be loaded safely.")
else:
    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
| 52 |
+
|
| 53 |
+
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    """Load a checkpoint (.safetensors/.sft or pickle-based .ckpt/.pt) as a state dict.

    Returns the state dict, or (state_dict, metadata) when return_metadata is
    True; metadata is only populated for safetensors files, otherwise None.
    Raises ValueError with a descriptive message for corrupt safetensors files.
    """
    if device is None:
        device = torch.device("cpu")
    metadata = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        try:
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {}
                for k in f.keys():
                    tensor = f.get_tensor(k)
                    if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
                        # Force a real copy so the tensor no longer backs onto the mapped file.
                        tensor = tensor.to(device=device, copy=True)
                    sd[k] = tensor
                if return_metadata:
                    metadata = f.metadata()
        except Exception as e:
            # Re-raise known safetensors failure modes with a friendlier message.
            if len(e.args) > 0:
                message = e.args[0]
                if "HeaderTooLarge" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
                if "MetadataIncompleteBuffer" in message:
                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
            raise e
    else:
        torch_args = {}
        if MMAP_TORCH_FILES:
            torch_args["mmap"] = True

        if safe_load or ALWAYS_SAFE_LOAD:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
        else:
            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            if len(pl_sd) == 1:
                # A single top-level key usually wraps the actual state dict.
                key = list(pl_sd.keys())[0]
                sd = pl_sd[key]
                if not isinstance(sd, dict):
                    sd = pl_sd
            else:
                sd = pl_sd
    return (sd, metadata) if return_metadata else sd
|
| 97 |
+
|
| 98 |
+
def save_torch_file(sd, ckpt, metadata=None):
    """Write state dict `sd` to `ckpt` in safetensors format, with optional metadata."""
    if metadata is None:
        safetensors.torch.save_file(sd, ckpt)
    else:
        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
|
| 103 |
+
|
| 104 |
+
def calculate_parameters(sd, prefix=""):
    """Total number of elements across tensors in `sd` whose key starts with `prefix`."""
    return sum(tensor.nelement() for key, tensor in sd.items() if key.startswith(prefix))
|
| 111 |
+
|
| 112 |
+
def weight_dtype(sd, prefix=""):
    """Dominant dtype (by element count) among tensors whose key starts with `prefix`.

    Returns None when no key matches.
    """
    counts = {}
    for key, tensor in sd.items():
        if key.startswith(prefix):
            counts[tensor.dtype] = counts.get(tensor.dtype, 0) + tensor.numel()

    if not counts:
        return None
    return max(counts, key=counts.get)
|
| 123 |
+
|
| 124 |
+
def state_dict_key_replace(state_dict, keys_to_replace):
    """Rename keys of `state_dict` in place per the old->new map `keys_to_replace`."""
    for old_key, new_key in keys_to_replace.items():
        if old_key in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    return state_dict
|
| 129 |
+
|
| 130 |
+
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    """Rewrite key prefixes per the old->new map `replace_prefix`.

    Matched keys are popped from `state_dict`. With filter_keys=True the
    returned dict contains only the renamed entries; otherwise renamed
    entries are put back into `state_dict`, which is returned.
    """
    out = {} if filter_keys else state_dict
    for old_prefix, new_prefix in replace_prefix.items():
        matched = [key for key in list(state_dict.keys()) if key.startswith(old_prefix)]
        for key in matched:
            out["{}{}".format(new_prefix, key[len(old_prefix):])] = state_dict.pop(key)
    return out
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def transformers_convert(sd, prefix_from, prefix_to, number):
    """Convert `number` CLIP resblocks from OpenAI naming to transformers naming.

    Renames the top-level embedding/norm keys, the per-resblock norm/mlp/attn
    keys, and splits each fused attn.in_proj_{weight,bias} into separate
    q/k/v projections. Mutates and returns `sd`.
    """
    keys_to_replace = {
        "{}positional_embedding": "{}embeddings.position_embedding.weight",
        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
        "{}ln_final.weight": "{}final_layer_norm.weight",
        "{}ln_final.bias": "{}final_layer_norm.bias",
    }

    for k in keys_to_replace:
        x = k.format(prefix_from)
        if x in sd:
            sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)

    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }

    for resblock in range(number):
        for x in resblock_to_replace:
            for y in ["weight", "bias"]:
                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                if k in sd:
                    sd[k_to] = sd.pop(k)

        for y in ["weight", "bias"]:
            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
            if k_from in sd:
                # Fused qkv projection: first dim is 3 * hidden; slice thirds
                # into q, k, v in that order.
                weights = sd.pop(k_from)
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                    sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]

    return sd
|
| 183 |
+
|
| 184 |
+
def clip_text_transformers_convert(sd, prefix_from, prefix_to):
    """Convert a CLIP text-model state dict from OpenAI layout to transformers layout."""
    sd = transformers_convert(sd, prefix_from, "{}text_model.".format(prefix_to), 32)

    target_key = "{}text_projection.weight".format(prefix_to)

    source = "{}text_projection.weight".format(prefix_from)
    if source in sd:
        sd[target_key] = sd.pop(source)

    # Some checkpoints store the projection without ".weight" and transposed.
    source = "{}text_projection".format(prefix_from)
    if source in sd:
        sd[target_key] = sd.pop(source).transpose(0, 1).contiguous()
    return sd
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Parameter keys shared by every attention wrapper module in the UNet.
UNET_MAP_ATTENTIONS = {
    "proj_in.weight",
    "proj_in.bias",
    "proj_out.weight",
    "proj_out.bias",
    "norm.weight",
    "norm.bias",
}

# Parameter keys of a single BasicTransformerBlock inside an attention wrapper.
TRANSFORMER_BLOCKS = {
    "norm1.weight",
    "norm1.bias",
    "norm2.weight",
    "norm2.bias",
    "norm3.weight",
    "norm3.bias",
    "attn1.to_q.weight",
    "attn1.to_k.weight",
    "attn1.to_v.weight",
    "attn1.to_out.0.weight",
    "attn1.to_out.0.bias",
    "attn2.to_q.weight",
    "attn2.to_k.weight",
    "attn2.to_v.weight",
    "attn2.to_out.0.weight",
    "attn2.to_out.0.bias",
    "ff.net.0.proj.weight",
    "ff.net.0.proj.bias",
    "ff.net.2.weight",
    "ff.net.2.bias",
}

# ldm resnet key -> diffusers resnet key.
UNET_MAP_RESNET = {
    "in_layers.2.weight": "conv1.weight",
    "in_layers.2.bias": "conv1.bias",
    "emb_layers.1.weight": "time_emb_proj.weight",
    "emb_layers.1.bias": "time_emb_proj.bias",
    "out_layers.3.weight": "conv2.weight",
    "out_layers.3.bias": "conv2.bias",
    "skip_connection.weight": "conv_shortcut.weight",
    "skip_connection.bias": "conv_shortcut.bias",
    "in_layers.0.weight": "norm1.weight",
    "in_layers.0.bias": "norm1.bias",
    "out_layers.0.weight": "norm2.weight",
    "out_layers.0.bias": "norm2.bias",
}

# NOTE: despite the "MAP" name this is a *set* of (ldm_key, diffusers_key)
# pairs; the same ldm key may appear with several diffusers aliases.
UNET_MAP_BASIC = {
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias")
}
|
| 264 |
+
|
| 265 |
+
def unet_to_diffusers(unet_config):
    """Build a diffusers-key -> ldm-key map for a UNet described by `unet_config`.

    Walks down blocks, the middle block, and up blocks in the same order the
    ldm `input_blocks`/`middle_block`/`output_blocks` indices are assigned,
    consuming the per-block transformer depths as it goes. Returns an empty
    map when the config has no "num_res_blocks" (not a conv UNet).
    """
    if "num_res_blocks" not in unet_config:
        return {}
    num_res_blocks = unet_config["num_res_blocks"]
    channel_mult = unet_config["channel_mult"]
    # Copies: these lists are consumed with pop() below.
    transformer_depth = unet_config["transformer_depth"][:]
    transformer_depth_output = unet_config["transformer_depth_output"][:]
    num_blocks = len(channel_mult)

    transformers_mid = unet_config.get("transformer_depth_middle", None)

    diffusers_unet_map = {}
    for x in range(num_blocks):
        # n = running ldm input_blocks index; index 0 is conv_in.
        n = 1 + (num_res_blocks[x] + 1) * x
        for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
            num_transformers = transformer_depth.pop(0)
            if num_transformers > 0:
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            n += 1
        # Downsampler occupies the next input_blocks slot.
        for k in ["weight", "bias"]:
            diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)

    # Middle block: attention wrapper is middle_block.1.
    i = 0
    for b in UNET_MAP_ATTENTIONS:
        diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
    for t in range(transformers_mid):
        for b in TRANSFORMER_BLOCKS:
            diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)

    # Middle-block resnets sit at middle_block.0 and middle_block.2.
    for i, n in enumerate([0, 2]):
        for b in UNET_MAP_RESNET:
            diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

    # Up path mirrors the down path, so res-block counts are reversed.
    num_res_blocks = list(reversed(num_res_blocks))
    for x in range(num_blocks):
        n = (num_res_blocks[x] + 1) * x
        l = num_res_blocks[x] + 1
        for i in range(l):
            # c = sub-module index within this output block (resnet, attn, upsampler).
            c = 0
            for b in UNET_MAP_RESNET:
                diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
            c += 1
            num_transformers = transformer_depth_output.pop()
            if num_transformers > 0:
                c += 1
                for b in UNET_MAP_ATTENTIONS:
                    diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
                for t in range(num_transformers):
                    for b in TRANSFORMER_BLOCKS:
                        diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
            if i == l - 1:
                # Last sub-block of each up block carries the upsampler.
                for k in ["weight", "bias"]:
                    diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
            n += 1

    for k in UNET_MAP_BASIC:
        diffusers_unet_map[k[1]] = k[0]

    return diffusers_unet_map
|
| 330 |
+
|
| 331 |
+
def swap_scale_shift(weight):
    """Swap the (shift, scale) halves of an adaLN modulation tensor to (scale, shift)."""
    halves = weight.chunk(2, dim=0)
    return torch.cat(halves[::-1], dim=0)
|
| 335 |
+
|
| 336 |
+
# NOTE: a *set* of (mmdit_key, diffusers_key[, transform_fn]) tuples; the
# optional third element is applied to the tensor during conversion
# (e.g. swap_scale_shift for adaLN parameters stored in the other order).
MMDIT_MAP_BASIC = {
    ("context_embedder.bias", "context_embedder.bias"),
    ("context_embedder.weight", "context_embedder.weight"),
    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("pos_embed", "pos_embed.pos_embed"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
}

# Per-joint-block key pairs: (mmdit_key, diffusers_key).
MMDIT_MAP_BLOCK = {
    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
    ("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
    ("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
    ("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
    ("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
    ("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
    ("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
    ("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
    ("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
}
|
| 382 |
+
|
| 383 |
+
def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    """Build a state-dict key map from diffusers MMDiT names to comfy names.

    Args:
        mmdit_config: dict with "depth" (and optionally "num_blocks") describing
            the number of transformer blocks.
        output_prefix: prefix prepended to every destination (comfy) key.

    Returns:
        dict mapping each diffusers key to either a destination key string, or a
        tuple describing a slice of a fused destination tensor.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    num_blocks = mmdit_config.get("num_blocks", depth)
    for i in range(num_blocks):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}joint_blocks.{}".format(output_prefix, i)

        # NOTE(review): offset is loop-invariant (depth * 64); presumably the
        # per-head-size times depth, i.e. the hidden size — confirm against the
        # model config this map is used with.
        offset = depth * 64

        for end in ("weight", "bias"):
            # Separate diffusers q/k/v projections map into one fused qkv tensor.
            # The (0, start, length) tuples presumably describe (dim, offset, size)
            # into the fused tensor — confirm against the consumer of this map.
            k = "{}.attn.".format(block_from)
            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

            k = "{}.attn2.".format(block_from)
            qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))

        # Straight one-to-one renames for the rest of the block.
        for k in MMDIT_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    map_basic = MMDIT_MAP_BASIC.copy()
    # The last block's context adaLN weights additionally need swap_scale_shift applied.
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
    map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))

    for k in map_basic:
        if len(k) > 2:
            # 3-tuple entries carry a transform function as the third element.
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
|
| 426 |
+
|
| 427 |
+
# Top-level (non-per-block) parameter-name pairs for PixArt:
# (comfy name, diffusers name). Consumed by pixart_to_diffusers.
PIXART_MAP_BASIC = {
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("y_embedder.y_embedding", "caption_projection.y_embedding"),
    ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
    ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
    ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
    ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
    ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
    ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
    ("t_block.1.weight", "adaln_single.linear.weight"),
    ("t_block.1.bias", "adaln_single.linear.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.scale_shift_table", "scale_shift_table"),
}
|
| 453 |
+
|
| 454 |
+
# Per-transformer-block parameter-name pairs for PixArt:
# (comfy name, diffusers name). Consumed by pixart_to_diffusers.
PIXART_MAP_BLOCK = {
    ("scale_shift_table", "scale_shift_table"),
    ("attn.proj.weight", "attn1.to_out.0.weight"),
    ("attn.proj.bias", "attn1.to_out.0.bias"),
    ("mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("mlp.fc2.weight", "ff.net.2.weight"),
    ("mlp.fc2.bias", "ff.net.2.bias"),
    ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
    ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
}
|
| 465 |
+
|
| 466 |
+
def pixart_to_diffusers(mmdit_config, output_prefix=""):
    """Build a state-dict key map from diffusers PixArt names to comfy names.

    Args:
        mmdit_config: dict with "depth" (block count) and "hidden_size"
            (per-projection width used to slice fused qkv/kv tensors).
        output_prefix: prefix prepended to every destination (comfy) key.

    Returns:
        dict mapping each diffusers key to a destination key string, or to a
        tuple describing a slice of a fused destination tensor.
    """
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    offset = mmdit_config.get("hidden_size", 1152)

    for i in range(depth):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}blocks.{}".format(output_prefix, i)

        for end in ("weight", "bias"):
            # Self-attention: separate q/k/v fold into one fused qkv tensor.
            # The (0, start, length) tuples presumably mean (dim, offset, size)
            # into the fused tensor — confirm against the consumer of this map.
            s = "{}.attn1.".format(block_from)
            qkv = "{}.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))

            # Cross-attention: q stays separate, k/v fold into a fused kv tensor.
            s = "{}.attn2.".format(block_from)
            q = "{}.cross_attn.q_linear.{}".format(block_to, end)
            kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)

            key_map["{}to_q.{}".format(s, end)] = q
            key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
            key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))

        # Straight one-to-one renames for the rest of the block.
        for k in PIXART_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    for k in PIXART_MAP_BASIC:
        key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
|
| 498 |
+
|
| 499 |
+
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    """Build a state-dict key map from diffusers AuraFlow names to comfy names.

    Args:
        mmdit_config: dict with "n_double_layers" and "n_layers"; layers with
            index < n_double_layers are joint (double) blocks, the rest are
            single blocks.
        output_prefix: prefix prepended to every destination (comfy) key.

    Returns:
        dict mapping each diffusers key to a destination key string, or to a
        (dest, None, transform) tuple when a transform must be applied.
    """
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

    key_map = {}
    for i in range(n_layers):
        if i < n_double_layers:
            index = i
            prefix_from = "joint_transformer_blocks"
            prefix_to = "{}double_layers".format(output_prefix)
            # diffusers name -> comfy name, relative to the block prefix.
            block_map = {
                "attn.to_q.weight": "attn.w2q.weight",
                "attn.to_k.weight": "attn.w2k.weight",
                "attn.to_v.weight": "attn.w2v.weight",
                "attn.to_out.0.weight": "attn.w2o.weight",
                "attn.add_q_proj.weight": "attn.w1q.weight",
                "attn.add_k_proj.weight": "attn.w1k.weight",
                "attn.add_v_proj.weight": "attn.w1v.weight",
                "attn.to_add_out.weight": "attn.w1o.weight",
                "ff.linear_1.weight": "mlpX.c_fc1.weight",
                "ff.linear_2.weight": "mlpX.c_fc2.weight",
                "ff.out_projection.weight": "mlpX.c_proj.weight",
                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
                "norm1.linear.weight": "modX.1.weight",
                "norm1_context.linear.weight": "modC.1.weight",
            }
        else:
            # Single blocks are numbered after the double blocks in diffusers.
            index = i - n_double_layers
            prefix_from = "single_transformer_blocks"
            prefix_to = "{}single_layers".format(output_prefix)

            block_map = {
                "attn.to_q.weight": "attn.w1q.weight",
                "attn.to_k.weight": "attn.w1k.weight",
                "attn.to_v.weight": "attn.w1v.weight",
                "attn.to_out.0.weight": "attn.w1o.weight",
                "norm1.linear.weight": "modCX.1.weight",
                "ff.linear_1.weight": "mlp.c_fc1.weight",
                "ff.linear_2.weight": "mlp.c_fc2.weight",
                "ff.out_projection.weight": "mlp.c_proj.weight"
            }

        for k in block_map:
            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])

    # Non-per-block renames: (comfy name, diffusers name[, transform]).
    MAP_BASIC = {
        ("positional_encoding", "pos_embed.pos_embed"),
        ("register_tokens", "register_tokens"),
        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
        ("cond_seq_linear.weight", "context_embedder.weight"),
        ("init_x_linear.weight", "pos_embed.proj.weight"),
        ("init_x_linear.bias", "pos_embed.proj.bias"),
        ("final_linear.weight", "proj_out.weight"),
        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            # 3-tuple entries carry a transform function as the third element.
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
|
| 567 |
+
|
| 568 |
+
def flux_to_diffusers(mmdit_config, output_prefix=""):
    """Build a state-dict key map from diffusers Flux names to comfy names.

    Args:
        mmdit_config: dict with "depth" (double blocks), "depth_single_blocks"
            (single blocks) and "hidden_size" (used to slice fused tensors).
        output_prefix: prefix prepended to every destination (comfy) key.

    Returns:
        dict mapping each diffusers key to a destination key string, or to a
        tuple describing a slice of a fused destination tensor, or to a
        (dest, None, transform) tuple when a transform must be applied.
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # Image-stream attention: q/k/v fold into one fused qkv tensor.
            # The (0, start, length) tuples presumably mean (dim, offset, size)
            # into the fused tensor — confirm against the consumer of this map.
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

            # Text-stream attention ("add_*" projections in diffusers).
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        # diffusers name -> comfy name, relative to the block prefix.
        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
            "norm1.linear.weight": "img_mod.lin.weight",
            "norm1.linear.bias": "img_mod.lin.bias",
            "norm1_context.linear.weight": "txt_mod.lin.weight",
            "norm1_context.linear.bias": "txt_mod.lin.bias",
            "attn.to_add_out.weight": "txt_attn.proj.weight",
            "attn.to_add_out.bias": "txt_attn.proj.bias",
            "ff.net.0.proj.weight": "img_mlp.0.weight",
            "ff.net.0.proj.bias": "img_mlp.0.bias",
            "ff.net.2.weight": "img_mlp.2.weight",
            "ff.net.2.bias": "img_mlp.2.bias",
            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
            "ff_context.net.2.weight": "txt_mlp.2.weight",
            "ff_context.net.2.bias": "txt_mlp.2.bias",
            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            # Single blocks fuse q/k/v AND the mlp input projection into linear1.
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

        block_map = {
            "norm.linear.weight": "modulation.lin.weight",
            "norm.linear.bias": "modulation.lin.bias",
            "proj_out.weight": "linear2.weight",
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    # Non-per-block renames: (comfy name, diffusers name[, transform]).
    # The controlnet_x_embedder entries only match controlnet checkpoints.
    MAP_BASIC = {
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.linear.weight", "proj_out.weight"),
        ("img_in.bias", "x_embedder.bias"),
        ("img_in.weight", "x_embedder.weight"),
        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
        ("txt_in.bias", "context_embedder.bias"),
        ("txt_in.weight", "context_embedder.weight"),
        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            # 3-tuple entries carry a transform function as the third element.
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map
|
| 673 |
+
|
| 674 |
+
def repeat_to_batch_size(tensor, batch_size, dim=0):
    """Return *tensor* resized along *dim* to exactly *batch_size* entries.

    If the dimension is too large the tensor is truncated; if too small it is
    tiled (whole-tensor repeats) and then truncated. Returned unchanged when
    the size already matches.
    """
    current = tensor.shape[dim]
    if current > batch_size:
        return tensor.narrow(dim, 0, batch_size)
    if current < batch_size:
        # Repeat only along `dim`, just enough times to cover batch_size.
        reps = [1] * len(tensor.shape)
        reps[dim] = math.ceil(batch_size / current)
        return tensor.repeat(reps).narrow(dim, 0, batch_size)
    return tensor
|
| 680 |
+
|
| 681 |
+
def resize_to_batch_size(tensor, batch_size):
    """Resample *tensor* along the batch dimension to exactly *batch_size* rows.

    Rows are chosen by nearest-index selection: endpoint-preserving when
    shrinking, centered sampling when growing. Returns the input unchanged
    when the size already matches; truncates for batch_size <= 1.
    """
    in_batch = tensor.shape[0]
    if in_batch == batch_size:
        return tensor

    if batch_size <= 1:
        return tensor[:batch_size]

    # Precompute which source row feeds each destination row.
    if batch_size < in_batch:
        ratio = (in_batch - 1) / (batch_size - 1)
        picks = [min(round(i * ratio), in_batch - 1) for i in range(batch_size)]
    else:
        ratio = in_batch / batch_size
        picks = [min(math.floor((i + 0.5) * ratio), in_batch - 1) for i in range(batch_size)]

    result = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
    for dst, src in enumerate(picks):
        result[dst] = tensor[src]
    return result
|
| 700 |
+
|
| 701 |
+
def resize_list_to_batch_size(l, batch_size):
    """Resample list *l* to *batch_size* entries by nearest-index selection.

    Uses the same index scheme as resize_to_batch_size: endpoint-preserving
    when shrinking, centered sampling when growing. Empty lists and exact
    matches are returned unchanged; batch_size <= 1 truncates.
    """
    n = len(l)
    if n == batch_size or n == 0:
        return l

    if batch_size <= 1:
        return l[:batch_size]

    if batch_size < n:
        ratio = (n - 1) / (batch_size - 1)
        return [l[min(round(i * ratio), n - 1)] for i in range(batch_size)]

    ratio = n / batch_size
    return [l[min(math.floor((i + 0.5) * ratio), n - 1)] for i in range(batch_size)]
|
| 720 |
+
|
| 721 |
+
def convert_sd_to(state_dict, dtype):
    """Cast every tensor in *state_dict* to *dtype* in place.

    Returns the same dict object for convenience.
    """
    for key in list(state_dict.keys()):
        state_dict[key] = state_dict[key].to(dtype)
    return state_dict
|
| 726 |
+
|
| 727 |
+
def safetensors_header(safetensors_path, max_size=100*1024*1024):
    """Read the raw header bytes of a safetensors file.

    The file starts with a little-endian uint64 giving the header length,
    followed by the header itself. Returns None when the declared length
    exceeds *max_size* (guards against corrupt/malicious files).
    """
    with open(safetensors_path, "rb") as fh:
        (header_len,) = struct.unpack('<Q', fh.read(8))
        if header_len > max_size:
            return None
        return fh.read(header_len)
|
| 734 |
+
|
| 735 |
+
def set_attr(obj, attr, value):
    """Set a nested attribute given in dot notation and return the old value.

    Example: set_attr(model, "layer1.conv.weight", w) assigns
    model.layer1.conv.weight = w and returns the previous weight.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    previous = getattr(target, leaf)
    setattr(target, leaf, value)
    return previous
|
| 742 |
+
|
| 743 |
+
def set_attr_param(obj, attr, value):
    # Like set_attr, but wraps value in a frozen (requires_grad=False) Parameter
    # before assigning; returns the previous attribute value.
    return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
| 745 |
+
|
| 746 |
+
def copy_to_param(obj, attr, value):
    """Copy *value* into the tensor at dotted path *attr*, in place.

    Unlike set_attr this keeps the existing tensor object (its identity and
    any views of it survive); only its data is overwritten.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    getattr(target, leaf).data.copy_(value)
|
| 753 |
+
|
| 754 |
+
def get_attr(obj, attr: str):
    """Retrieve a nested attribute from an object using dot notation.

    Args:
        obj: The object to get the attribute from.
        attr (str): The attribute path using dot notation (e.g. "model.layer.weight").

    Returns:
        The value of the requested attribute.

    Example:
        model = MyModel()
        weight = get_attr(model, "layer1.conv.weight")
        # Equivalent to: model.layer1.conv.weight

    Important:
        Always prefer `comfy.model_patcher.ModelPatcher.get_model_object` when
        accessing nested model objects under `ModelPatcher.model`.
    """
    node = obj
    for part in attr.split("."):
        node = getattr(node, part)
    return node
|
| 777 |
+
|
| 778 |
+
def bislerp(samples, width, height):
    """Resize a NCHW batch to (height, width) using bilinear-style interpolation
    where each pair of neighboring pixels is blended with slerp (spherical
    interpolation over the channel vector) instead of a plain lerp.

    Runs one horizontal pass then one vertical pass; computes in float32 and
    casts back to the input dtype at the end.
    """
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero (avoids NaN rows from the division above)
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        # rescale by linearly interpolated magnitudes
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new, device):
        # For each output coordinate, find the two neighboring source indices
        # (coords_1, coords_2) and the fractional blend ratio between them.
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
        coords_2[:,:,:,-1] -= 1  # clamp last neighbor index into range
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    orig_dtype = samples.dtype
    samples = samples.float()
    n,c,h,w = samples.shape
    h_new, w_new = (height, width)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    # gather neighbor columns, flatten to (pixels, channels) for slerp
    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)
|
| 853 |
+
|
| 854 |
+
def lanczos(samples, width, height):
    # Resize a batch of CHW float images (values in [0, 1]) via PIL's Lanczos
    # resampler: tensor -> uint8 HWC PIL image -> resize -> float tensor.
    # Round-tripping through uint8 quantizes values to 1/255 steps.
    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
    result = torch.stack(images)
    # restore original device/dtype
    return result.to(samples.device, samples.dtype)
|
| 860 |
+
|
| 861 |
+
def common_upscale(samples, width, height, upscale_method, crop):
    """Resize *samples* (N,C,H,W — or higher rank; extra dims are folded into
    the batch) to (height, width).

    Args:
        samples: input tensor; last two dims are spatial.
        upscale_method: "bislerp", "lanczos", or any mode accepted by
            torch.nn.functional.interpolate (e.g. "bilinear", "nearest").
        crop: "center" to first center-crop to the target aspect ratio,
            anything else to resize without cropping.
    """
    orig_shape = tuple(samples.shape)
    if len(orig_shape) > 4:
        # Fold all dims between channel and spatial into the batch dimension.
        samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
        samples = samples.movedim(2, 1)
        samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
    if crop == "center":
        old_width = samples.shape[-1]
        old_height = samples.shape[-2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        # Trim the dimension that is proportionally too large.
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
    else:
        s = samples

    if upscale_method == "bislerp":
        out = bislerp(s, width, height)
    elif upscale_method == "lanczos":
        out = lanczos(s, width, height)
    else:
        out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

    if len(orig_shape) == 4:
        return out

    # Unfold the batch back into the original extra dimensions.
    out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
    return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
|
| 894 |
+
|
| 895 |
+
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Return how many tiles a (width x height) input yields with the given
    tile size and overlap — i.e. the number of progress steps tiled_scale
    will take for one image."""
    def steps_along(size, tile):
        # An axis that fits inside one tile needs a single step; otherwise
        # each additional tile advances by (tile - overlap).
        if size <= tile:
            return 1
        return math.ceil((size - overlap) / (tile - overlap))

    return steps_along(height, tile_y) * steps_along(width, tile_x)
|
| 899 |
+
|
| 900 |
+
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
    """Apply *function* to *samples* tile by tile and blend the results.

    Works for any number of spatial dims (len(tile)). Overlapping tile borders
    are feathered with a linear ramp and the accumulated output is normalized
    by the accumulated mask, so seams average smoothly.

    Args:
        samples: input of shape (batch, channels, *spatial).
        function: callable mapping a (1, C, *tile) slice to its scaled output.
        tile: per-dim tile size; overlap: per-dim (or scalar) overlap in input units.
        upscale_amount: per-dim scale factor (number or callable) applied by
            *function* to sizes; with downscale=True sizes are divided instead.
        index_formulas: optional per-dim formula for mapping input tile
            positions to output positions (defaults to upscale_amount).
        out_channels: channel count of *function*'s output.
        output_device: device the result is accumulated on.
        pbar: optional progress bar; .update(1) is called once per tile.

    Returns:
        Tensor of shape (batch, out_channels, *scaled_spatial) on output_device.
    """
    dims = len(tile)

    # Normalize scalar arguments to one value per spatial dim.
    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims

    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims

    if index_formulas is None:
        index_formulas = upscale_amount

    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims

    def get_upscale(dim, val):
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale(dim, val):
        up = upscale_amount[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    def get_upscale_pos(dim, val):
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return up * val

    def get_downscale_pos(dim, val):
        up = index_formulas[dim]
        if callable(up):
            return up(val)
        else:
            return val / up

    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos

    def mult_list_upscale(a):
        # Map a list of input sizes to the corresponding output sizes.
        out = []
        for i in range(len(a)):
            out.append(round(get_scale(i, a[i])))
        return out

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b+1]

        # handle entire input fitting in a single tile
        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
            output[b:b+1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        # Accumulators: weighted sum of tile outputs and sum of weights.
        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)

        # Tile start positions along each spatial dim (stride = tile - overlap).
        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]

        for it in itertools.product(*positions):
            s_in = s
            upscaled = []

            for d in range(dims):
                # Clamp so the tile never runs past the input edge.
                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(get_pos(d, pos)))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)

            # Feather the tile borders with a linear ramp of width = scaled overlap.
            for d in range(2, dims + 2):
                feather = round(get_scale(d - 2, overlap[d - 2]))
                if feather >= mask.shape[d]:
                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            # Views into the accumulators at this tile's output location.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

            o.add_(ps * mask)
            o_d.add_(mask)

            if pbar is not None:
                pbar.update(1)

        # Normalize overlapping regions by their total weight.
        output[b:b+1] = out/out_div
    return output
|
| 1010 |
+
|
| 1011 |
+
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
    # 2D convenience wrapper around tiled_scale_multidim; tile is (height, width).
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
| 1013 |
+
|
| 1014 |
+
# Module-level flag toggled here and read elsewhere — presumably by code that
# decides whether to display progress bars; confirm against callers.
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
    # Globally enable or disable progress bar reporting.
    global PROGRESS_BAR_ENABLED
    PROGRESS_BAR_ENABLED = enabled
|
| 1018 |
+
|
| 1019 |
+
# Global callback snapshot used by ProgressBar instances; set to a callable
# taking (current, total, preview, node_id=...) or None to disable.
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
    # Install the global progress callback picked up by new ProgressBar objects.
    global PROGRESS_BAR_HOOK
    PROGRESS_BAR_HOOK = function
|
| 1023 |
+
|
| 1024 |
+
class ProgressBar:
    """Tracks progress of a multi-step operation and forwards updates to the
    globally registered hook (snapshotted at construction time)."""

    def __init__(self, total, node_id=None):
        global PROGRESS_BAR_HOOK
        # total steps expected; current steps done so far
        self.total = total
        self.current = 0
        # snapshot the global hook so later hook changes don't affect this bar
        self.hook = PROGRESS_BAR_HOOK
        self.node_id = node_id

    def update_absolute(self, value, total=None, preview=None):
        """Set progress to an absolute *value* (clamped to total).

        Optionally replaces the total and passes a preview payload through to
        the hook.
        """
        if total is not None:
            self.total = total
        self.current = min(value, self.total)
        if self.hook is not None:
            self.hook(self.current, self.total, preview, node_id=self.node_id)

    def update(self, value):
        """Advance progress by *value* steps."""
        self.update_absolute(self.current + value)
|
| 1043 |
+
|
| 1044 |
+
def reshape_mask(input_mask, output_shape):
    """Interpolate *input_mask* to match *output_shape* (batch, channels, *spatial).

    Handles 1D/2D/3D spatial targets (linear/bilinear/trilinear), then repeats
    the mask across channels and batch as needed.
    """
    dims = len(output_shape) - 2

    if dims == 1:
        # assumes input_mask is already shaped (N, C, L) for 1D — TODO confirm
        scale_mode = "linear"

    if dims == 2:
        # collapse any leading dims into batch, single channel
        input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "bilinear"

    if dims == 3:
        if len(input_mask.shape) < 5:
            # treat extra leading dim as depth: (1, 1, D, H, W)
            input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
        scale_mode = "trilinear"

    mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
    if mask.shape[1] < output_shape[1]:
        # tile across channels, then trim to the exact channel count
        mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
    mask = repeat_to_batch_size(mask, output_shape[0])
    return mask
|
| 1064 |
+
|
| 1065 |
+
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
    """Rescale the image portion of a DiT attention mask to a new image size.

    The mask covers txt+img tokens along both query and key axes, i.e. shape
    [b, q, k] or [q, k] with q == k == txt_tokens + h*w.  Text rows/columns
    are kept unchanged; image rows/columns are bilinearly interpolated from
    (hi, wi) to (ho, wo).

    Args:
        mask: attention mask of shape [b, q, k] or [q, k].
        img_size_in: (height, width) of the image grid the mask was built for.
        img_size_out: (height, width) of the target image grid.

    Returns:
        Mask of shape [b, txt_tokens + ho*wo, txt_tokens + ho*wo].

    Raises:
        ValueError: if the mask is not 2- or 3-dimensional.
    """
    hi, wi = img_size_in
    ho, wo = img_size_out
    # if it's already the correct size, no need to do anything
    if (hi, wi) == (ho, wo):
        return mask
    if mask.ndim == 2:
        # promote [q, k] to a batch of one: [1, q, k]
        mask = mask.unsqueeze(0)
    if mask.ndim != 3:
        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
    # leading tokens that are text; the remainder is the flattened hi*wi image grid
    txt_tokens = mask.shape[1] - (hi * wi)
    # quadrants of the mask
    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
    img_to_img = mask[:, txt_tokens:, txt_tokens:]
    img_to_txt = mask[:, txt_tokens:, :txt_tokens]

    # convert to 1d x 2d, interpolate, then back to 1d x 1d
    txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
    txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
    # this one is hard because we have to do it twice
    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
    # (first pass rescales the key axis, second pass rescales the query axis)
    img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
    # convert to 2d x 1d, interpolate, then back to 1d x 1d
    img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
    img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")

    # reassemble the mask from blocks
    out = torch.cat([
        torch.cat([txt_to_txt, txt_to_img], dim=2),
        torch.cat([img_to_txt, img_to_img], dim=2)],
        dim=1
    )
    return out
ComfyUI/comfy/weight_adapter/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import WeightAdapterBase, WeightAdapterTrainBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter
from .oft import OFTAdapter
from .boft import BOFTAdapter


# All adapter implementations known to the loader, in the order they are
# probed when detecting the format of a loaded weight dict.
adapters: list[type[WeightAdapterBase]] = [
    LoRAAdapter,
    LoHaAdapter,
    LoKrAdapter,
    GLoRAAdapter,
    OFTAdapter,
    BOFTAdapter,
]
# Mapping from user-facing algorithm name to adapter class.
adapter_maps: dict[str, type[WeightAdapterBase]] = {
    "LoRA": LoRAAdapter,
    "LoHa": LoHaAdapter,
    "LoKr": LoKrAdapter,
    "OFT": OFTAdapter,
    ## We disable not implemented algo for now
    # "GLoRA": GLoRAAdapter,
    # "BOFT": BOFTAdapter,
}


# Public API: the base classes, the registries, and every adapter class name.
__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
    "adapters",
    "adapter_maps",
] + [a.__name__ for a in adapters]
ComfyUI/comfy/weight_adapter/boft.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import comfy.model_management
|
| 6 |
+
from .base import WeightAdapterBase, weight_decompose
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BOFTAdapter(WeightAdapterBase):
    """Weight adapter applying BOFT (butterfly orthogonal fine-tuning) blocks."""
    name = "boft"

    def __init__(self, loaded_keys, weights):
        # Keys from the lora dict that this adapter consumed.
        self.loaded_keys = loaded_keys
        # Packed tuple (blocks, rescale, alpha, dora_scale) as built by load().
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["BOFTAdapter"]:
        """Build a BOFTAdapter for key prefix *x* from *lora*, or return None.

        Returns None when the dict holds no 4-D ``{x}.oft_blocks`` tensor
        (4-D distinguishes BOFT from plain OFT, whose blocks are 3-D).
        Consumed key names are added to *loaded_keys*.
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            # BOFT blocks are 4-D (m, block_num, b, b); anything else is not BOFT.
            if blocks.ndim == 4:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the BOFT transform to *weight* and return the patched tensor.

        On any exception the error is logged and the (possibly partially
        modified) weight is returned unchanged, matching the other adapters'
        best-effort behavior.
        """
        v = self.weights
        blocks = v[0]       # (boft_m, block_num, boft_b, boft_b) butterfly factors
        rescale = v[1]      # optional per-output rescale, or None
        alpha = v[2]        # norm constraint for the skew-symmetric part
        dora_scale = v[3]   # optional DoRA magnitude vector, or None

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        boft_m, block_num, boft_b, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T  (skew-symmetric part of the learned blocks)
            q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0: # alpha in boft/bboft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            # Cayley transform (I + Q)(I - Q)^-1 of a skew-symmetric Q yields
            # an orthogonal rotation matrix per block.
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            inp = org = weight

            r_b = boft_b//2
            for i in range(boft_m):
                bi = r[i]
                g = 2
                k = 2**i * r_b  # butterfly stride doubles at each level
                if strength != 1:
                    # blend each rotation toward identity for partial strength
                    bi = bi * strength + (1-strength) * I
                # regroup rows into boft_b-sized blocks for this butterfly level
                inp = (
                    inp.unflatten(0, (-1, g, k))
                    .transpose(1, 2)
                    .flatten(0, 2)
                    .unflatten(0, (-1, boft_b))
                )
                # rotate each block, then undo the regrouping
                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                inp = (
                    inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
                )

            if rescale is not None:
                inp = inp * rescale

            lora_diff = inp - org
            lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
ComfyUI/comfy_api/feature_flags.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feature flags module for ComfyUI WebSocket protocol negotiation.
|
| 3 |
+
|
| 4 |
+
This module handles capability negotiation between frontend and backend,
|
| 5 |
+
allowing graceful protocol evolution while maintaining backward compatibility.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict
|
| 9 |
+
|
| 10 |
+
from comfy.cli_args import args
|
| 11 |
+
|
| 12 |
+
# Default server capabilities advertised to connecting clients during
# feature-flag negotiation.
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
    "supports_preview_metadata": True,
    "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
}
def get_connection_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str,
    default: Any = False
) -> Any:
    """
    Look up a feature-flag value for one connection.

    Args:
        sockets_metadata: Dictionary of socket metadata
        sid: Session ID of the connection
        feature_name: Name of the feature to look up
        default: Value returned when the connection or feature is unknown

    Returns:
        The feature's value, or *default* when not found
    """
    try:
        metadata = sockets_metadata[sid]
    except KeyError:
        return default
    return metadata.get("feature_flags", {}).get(feature_name, default)
def supports_feature(
    sockets_metadata: Dict[str, Dict[str, Any]],
    sid: str,
    feature_name: str
) -> bool:
    """
    Tell whether the connection *sid* declares support for a feature.

    Only an explicit boolean ``True`` counts as support; truthy non-boolean
    values do not.

    Args:
        sockets_metadata: Dictionary of socket metadata
        sid: Session ID of the connection
        feature_name: Name of the feature to check

    Returns:
        True if the connection advertises the feature, else False
    """
    flag_value = get_connection_feature(sockets_metadata, sid, feature_name, False)
    return flag_value is True
def get_server_features() -> Dict[str, Any]:
    """
    Return a snapshot of this server's feature flags.

    Returns:
        A shallow copy of the feature-flag dictionary, so callers cannot
        mutate the module-level defaults.
    """
    return dict(SERVER_FEATURE_FLAGS)
|
ComfyUI/comfy_api_nodes/README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI API Nodes
|
| 2 |
+
|
| 3 |
+
## Introduction
|
| 4 |
+
|
| 5 |
+
Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
|
| 6 |
+
|
| 7 |
+
## Development
|
| 8 |
+
|
| 9 |
+
While developing, you should be testing against the Staging environment. To test against staging:
|
| 10 |
+
|
| 11 |
+
**Install ComfyUI_frontend**
|
| 12 |
+
|
| 13 |
+
Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
|
| 14 |
+
|
| 15 |
+
> **Hint:** If you use the `--front-end-version` argument for ComfyUI, it will use production authentication.
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
python main.py --comfy-api-base https://stagingapi.comfy.org
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
To authenticate to staging, please log in and then ask a member of the Comfy Org team to whitelist you for access to staging.
|
| 22 |
+
|
| 23 |
+
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
| 24 |
+
|
| 25 |
+
### Redocly Instructions
|
| 26 |
+
|
| 27 |
+
**Tip**
|
| 28 |
+
When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
|
| 29 |
+
|
| 30 |
+
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
# Download the OpenAPI file from staging server.
|
| 34 |
+
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
| 35 |
+
|
| 36 |
+
# Filter out unneeded API definitions.
|
| 37 |
+
npm install -g @redocly/cli
|
| 38 |
+
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
|
| 39 |
+
|
| 40 |
+
# Generate the pydantic datamodels for validation.
|
| 41 |
+
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Merging to Master
|
| 47 |
+
|
| 48 |
+
Before merging to comfyanonymous/ComfyUI master, follow these steps:
|
| 49 |
+
|
| 50 |
+
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
|
| 51 |
+
1. Make sure the ComfyUI API is deployed to prod with your changes.
|
| 52 |
+
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
# Download the OpenAPI file from prod server.
|
| 56 |
+
curl -o openapi.yaml https://api.comfy.org/openapi
|
| 57 |
+
|
| 58 |
+
# Filter out unneeded API definitions.
|
| 59 |
+
npm install -g @redocly/cli
|
| 60 |
+
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
| 61 |
+
|
| 62 |
+
# Generate the pydantic datamodels for validation.
|
| 63 |
+
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
| 64 |
+
|
| 65 |
+
```
|