| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| import comfy.utils
|
| import comfy.ops
|
| import comfy.model_management
|
| import folder_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
def conv(n_in, n_out, **kwargs):
    """Return a 3x3 Conv2d with padding=1 from ``comfy.ops.disable_weight_init``.

    Extra keyword arguments (e.g. ``stride``, ``bias``) are forwarded unchanged.
    The ``disable_weight_init`` variant presumably skips default initialization,
    i.e. the weights are expected to come from a checkpoint.
    """
    conv2d_cls = comfy.ops.disable_weight_init.Conv2d
    kernel = 3
    return conv2d_cls(n_in, n_out, kernel, padding=1, **kwargs)
|
|
|
|
|
class Clamp(nn.Module):
    """Soft clamp that smoothly squashes its input into (-3, 3) as ``3 * tanh(x / 3)``."""

    def forward(self, x):
        return x.div(3).tanh().mul(3)
|
|
|
|
|
class Block(nn.Module):
    """Residual block: three 3x3 convs (ReLU between) plus a skip path, fused by ReLU.

    The skip path is a bias-free 1x1 conv when ``n_in != n_out``, otherwise identity.
    With ``use_midblock_gn`` an additional ``pool`` branch
    (1x1 conv -> GroupNorm(4) -> ReLU -> 1x1 conv, hidden width ``4 * n_in``) is added
    residually to the input before the main path. Submodule names (``conv``, ``skip``,
    ``fuse``, ``pool``) define the state-dict keys and must not change.
    """

    def __init__(self, n_in, n_out, use_midblock_gn=False):
        super().__init__()
        ops = comfy.ops.disable_weight_init

        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = ops.Conv2d(n_in, n_out, 1, bias=False)
        self.fuse = nn.ReLU()

        self.pool = None
        if use_midblock_gn:
            hidden = n_in * 4
            self.pool = nn.Sequential(
                ops.Conv2d(n_in, hidden, 1, bias=False),
                ops.GroupNorm(4, hidden),
                nn.ReLU(inplace=True),
                ops.Conv2d(hidden, n_in, 1, bias=False),
            )

    def forward(self, x):
        if self.pool is not None:
            x = x + self.pool(x)
        main = self.conv(x)
        return self.fuse(main + self.skip(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_encoder_scale8(latent_channels, pool_blocks):
    """Build the 8x-downsampling encoder as an ``nn.Sequential``.

    Index layout (indices must match checkpoint keys):
    0 conv(3->64), 1 Block, then three stages of [stride-2 conv, Block x3] at
    2-5, 6-9 and 10-13, and finally 14 conv(64->latent_channels).
    A Block placed at index ``i`` gets the mid-block GroupNorm branch iff
    ``i in pool_blocks``.
    """
    layers = []

    def add_block():
        # The Block's position in the Sequential decides whether it has a pool branch.
        position = len(layers)
        layers.append(Block(64, 64, use_midblock_gn=(position in pool_blocks)))

    layers.append(conv(3, 64))
    add_block()
    for _stage in range(3):
        layers.append(conv(64, 64, stride=2, bias=False))
        for _ in range(3):
            add_block()
    layers.append(conv(64, latent_channels))
    return nn.Sequential(*layers)
|
|
|
|
|
def build_decoder_scale8(latent_channels, pool_blocks):
    """Build the 8x-upsampling decoder as an ``nn.Sequential``.

    Index layout (indices must match checkpoint keys):
    0 Clamp, 1 conv(latent_channels->64), 2 ReLU, then three stages of
    [Block x3, Upsample(2x), conv(64->64, no bias)] at 3-7, 8-12 and 13-17,
    then 18 Block and 19 conv(64->3).
    A Block placed at index ``i`` gets the mid-block GroupNorm branch iff
    ``i in pool_blocks``.
    """
    layers = [Clamp(), conv(latent_channels, 64), nn.ReLU()]

    def add_block():
        # The Block's position in the Sequential decides whether it has a pool branch.
        position = len(layers)
        layers.append(Block(64, 64, use_midblock_gn=(position in pool_blocks)))

    for _stage in range(3):
        for _ in range(3):
            add_block()
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(64, 64, bias=False))
    add_block()
    layers.append(conv(64, 3))
    return nn.Sequential(*layers)
|
|
|
|
|
def build_encoder_scale16(latent_channels, pool_blocks):
    """Build the 16x-downsampling encoder as an ``nn.Sequential``.

    Index layout (indices must match checkpoint keys):
    0 conv(3->64), 1 Block, then four stages of [stride-2 conv, Block x3] at
    2-5, 6-9, 10-13 and 14-17, and finally 18 conv(64->latent_channels).
    A Block placed at index ``i`` gets the mid-block GroupNorm branch iff
    ``i in pool_blocks``.
    """
    layers = []

    def add_block():
        # The Block's position in the Sequential decides whether it has a pool branch.
        position = len(layers)
        layers.append(Block(64, 64, use_midblock_gn=(position in pool_blocks)))

    layers.append(conv(3, 64))
    add_block()
    for _stage in range(4):
        layers.append(conv(64, 64, stride=2, bias=False))
        for _ in range(3):
            add_block()
    layers.append(conv(64, latent_channels))
    return nn.Sequential(*layers)
|
|
|
|
|
def build_decoder_scale16(latent_channels, pool_blocks):
    """Build the 16x-upsampling decoder as an ``nn.Sequential``.

    Index layout (indices must match checkpoint keys):
    0 Clamp, 1 conv(latent_channels->64), 2 ReLU, then four stages of
    [Block x3, Upsample(2x), conv(64->64, no bias)] at 3-7, 8-12, 13-17 and 18-22,
    then 23 Block and 24 conv(64->3).
    A Block placed at index ``i`` gets the mid-block GroupNorm branch iff
    ``i in pool_blocks``.
    """
    layers = [Clamp(), conv(latent_channels, 64), nn.ReLU()]

    def add_block():
        # The Block's position in the Sequential decides whether it has a pool branch.
        position = len(layers)
        layers.append(Block(64, 64, use_midblock_gn=(position in pool_blocks)))

    for _stage in range(4):
        for _ in range(3):
            add_block()
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(64, 64, bias=False))
    add_block()
    layers.append(conv(64, 3))
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
def unpack_packed_latents(x, latent_channels):
    """Undo 2x2 patch packing: (B, 4*C, H, W) -> (B, C, 2H, 2W).

    Packed channel ``c*4 + 2*i + j`` at ``(h, w)`` becomes pixel ``(2h+i, 2w+j)`` of
    channel ``c``. Inputs that are not 4-D or do not have exactly
    ``4 * latent_channels`` channels are returned unchanged (same object).
    """
    if x.ndim != 4 or x.shape[1] != latent_channels * 4:
        return x
    batch, _, height, width = x.shape
    grid = x.reshape(batch, latent_channels, 2, 2, height, width)
    grid = grid.permute(0, 1, 4, 2, 5, 3)
    return grid.reshape(batch, latent_channels, height * 2, width * 2)
|
|
|
|
|
def pack_packed_latents(z, latent_channels):
    """Pack 2x2 spatial patches into channels: (B, C, H, W) -> (B, 4*C, ceil(H/2), ceil(W/2)).

    Odd height/width is first padded by one row/column on the bottom/right using
    replicate padding. Pixel ``(2h+i, 2w+j)`` of channel ``c`` lands in packed channel
    ``c*4 + 2*i + j`` at ``(h, w)``. Inputs that are not 4-D or do not have exactly
    ``latent_channels`` channels are returned unchanged (same object).
    """
    if z.ndim != 4 or z.shape[1] != latent_channels:
        return z

    height, width = z.shape[-2], z.shape[-1]
    extra_h, extra_w = height % 2, width % 2
    if extra_h or extra_w:
        z = F.pad(z, (0, extra_w, 0, extra_h), mode="replicate")
        height, width = z.shape[-2], z.shape[-1]

    batch = z.shape[0]
    half_h, half_w = height // 2, width // 2
    patches = z.reshape(batch, latent_channels, half_h, 2, half_w, 2)
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, latent_channels * 4, half_h, half_w)
|
|
|
|
|
def pad_nchw_to_multiple(x, multiple):
    """Replicate-pad the bottom/right of a 4-D tensor so H and W are multiples of ``multiple``.

    Returns ``x`` itself (same object) when no padding is needed.
    """
    height, width = x.shape[2], x.shape[3]
    # (-n) % m is the distance from n up to the next multiple of m.
    extra_h = -height % multiple
    extra_w = -width % multiple
    if not (extra_h or extra_w):
        return x
    return F.pad(x, (0, extra_w, 0, extra_h), mode="replicate")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _offset_decoder_key(rest):
    """Map ``"<n>[.<tail>]"`` to ``"taesd_decoder.<n+1>[.<tail>]"``; return None if ``<n>`` is not a digit string.

    The +1 shift accounts for the parameter-free ``Clamp`` module sitting at index 0 of
    the decoders this file builds, which source layouts starting with a conv lack.
    """
    head, _, tail = rest.partition(".")
    if not head.isdigit():
        return None
    return f"taesd_decoder.{int(head) + 1}" + (("." + tail) if tail else "")


def normalize_state_dict(sd_raw):
    """Rewrite a VAE state dict so its keys use ``taesd_encoder.`` / ``taesd_decoder.`` prefixes.

    Recognized layouts, checked in order:

    * keys already prefixed ``taesd_encoder.`` / ``taesd_decoder.``: ``sd_raw`` is
      returned as-is (same object);
    * ``encoder.layers.N...`` / ``decoder.layers.N...``: encoder indices are kept, and
      decoder indices are shifted by +1. A decoder key whose ``N`` is not a digit string
      keeps its original key;
    * ``encoder.N...`` / ``decoder.N...``: decoder indices are shifted by +1 only when
      ``decoder.0.weight`` is a 4-D tensor with 64 output channels and a 3x3 kernel
      (i.e. the first decoder entry is a conv, not a Clamp);
    * anything else: ``sd_raw`` is returned as-is.

    In the rewritten cases, keys outside the encoder/decoder namespaces (for example
    ``vae_scale``) are copied through unchanged. A new dict is returned; ``sd_raw`` is
    never mutated.
    """
    keys = list(sd_raw.keys())

    if any(k.startswith(("taesd_encoder.", "taesd_decoder.")) for k in keys):
        return sd_raw

    if any(k.startswith(("encoder.layers.", "decoder.layers.")) for k in keys):
        out = {}
        for k, v in sd_raw.items():
            if k.startswith("encoder.layers."):
                out["taesd_encoder." + k[len("encoder.layers."):]] = v
            elif k.startswith("decoder.layers."):
                new_key = _offset_decoder_key(k[len("decoder.layers."):])
                # Unparseable layer indices keep their original key (and are later
                # ignored by split_encoder_decoder).
                out[k if new_key is None else new_key] = v
            else:
                out[k] = v
        return out

    if any(k.startswith(("encoder.", "decoder.")) for k in keys):
        w0 = sd_raw.get("decoder.0.weight")
        # A 64-out 3x3 conv at decoder index 0 means the Clamp slot is absent.
        decoder_needs_offset = (
            isinstance(w0, torch.Tensor)
            and w0.ndim == 4
            and w0.shape[0] == 64
            and w0.shape[2:] == (3, 3)
        )
        out = {}
        for k, v in sd_raw.items():
            if k.startswith("encoder."):
                out["taesd_encoder." + k[len("encoder."):]] = v
            elif k.startswith("decoder."):
                rest = k[len("decoder."):]
                new_key = _offset_decoder_key(rest) if decoder_needs_offset else None
                out[new_key if new_key is not None else "taesd_decoder." + rest] = v
            else:
                out[k] = v
        return out

    return sd_raw
|
|
|
|
|
def split_encoder_decoder(sd):
    """Split a normalized state dict into ``(encoder_sd, decoder_sd)`` with prefixes removed.

    Keys starting with ``taesd_encoder.`` go to the first dict and ``taesd_decoder.`` to
    the second; all other keys are dropped.
    """
    enc_prefix = "taesd_encoder."
    dec_prefix = "taesd_decoder."
    encoder_sd, decoder_sd = {}, {}
    for key, value in sd.items():
        if key.startswith(enc_prefix):
            encoder_sd[key[len(enc_prefix):]] = value
        elif key.startswith(dec_prefix):
            decoder_sd[key[len(dec_prefix):]] = value
    return encoder_sd, decoder_sd
|
|
|
|
|
def pool_blocks_from_sd(part_sd):
    """Return the set of top-level indices whose Block has a ``pool`` (mid-block GN) branch.

    A key such as ``"5.pool.0.weight"`` or ``"5.pool.0.bias"`` marks index 5. Keys whose
    first dotted component is not a digit string are ignored.
    """
    markers = (".pool.0.weight", ".pool.0.bias")
    found = set()
    for key in part_sd:
        if not any(marker in key for marker in markers):
            continue
        head, _, _ = key.partition(".")
        if head.isdigit():
            found.add(int(head))
    return found
|
|
|
|
|
def infer_latent_channels_from_decoder(dec_sd):
    """Infer the latent channel count from a (prefix-stripped) decoder state dict.

    Among 4-D tensors under a numeric top-level index with 64 output channels and a 3x3
    kernel, picks the one with the smallest index (ties go to the first in dict order)
    and returns its input-channel count.

    Raises:
        RuntimeError: if no such tensor exists.
    """
    best = None  # (index, in_channels) of the earliest matching conv seen so far
    for key, tensor in dec_sd.items():
        if not (isinstance(tensor, torch.Tensor) and tensor.ndim == 4):
            continue
        head = key.split(".", 1)[0]
        if not head.isdigit() or tensor.shape[0] != 64 or tensor.shape[2:] != (3, 3):
            continue
        index = int(head)
        if best is None or index < best[0]:
            best = (index, int(tensor.shape[1]))
    if best is None:
        raise RuntimeError("Could not infer latent_channels from decoder weights.")
    return best[1]
|
|
|
|
|
def detect_layout(enc_sd, latent_channels):
    """Classify an encoder state dict as ``"scale8"`` or ``"scale16"``.

    First probes the expected final-projection slots directly: ``14.weight`` for scale8,
    then ``18.weight`` for scale16. Each must be a 4-D tensor mapping 64 ->
    ``latent_channels``. Otherwise it falls back to the smallest numeric index holding a
    3x3 conv of that shape: <= 14 means scale8, anything larger scale16.

    Raises:
        RuntimeError: if no candidate projection conv is found.
    """
    def projects_to_latent(t):
        return (
            isinstance(t, torch.Tensor)
            and t.ndim == 4
            and t.shape[0] == latent_channels
            and t.shape[1] == 64
        )

    for probe_key, layout in (("14.weight", "scale8"), ("18.weight", "scale16")):
        if probe_key in enc_sd and projects_to_latent(enc_sd[probe_key]):
            return layout

    lowest = None
    for key, tensor in enc_sd.items():
        if not (isinstance(tensor, torch.Tensor) and tensor.ndim == 4):
            continue
        head = key.split(".", 1)[0]
        if head.isdigit() and projects_to_latent(tensor) and tensor.shape[2:] == (3, 3):
            index = int(head)
            if lowest is None or index < lowest:
                lowest = index

    if lowest is None:
        raise RuntimeError("Could not detect encoder layout (scale8 vs scale16).")
    return "scale8" if lowest <= 14 else "scale16"
|
|
|
|
|
|
|
|
|
|
|
|
|
class TAESDCore(nn.Module):
    """Encoder/decoder pair plus a scalar latent affine (``vae_scale``, ``vae_shift``).

    ``decode`` takes latents (unpacking 2x2-packed ones when they have
    ``4 * latent_channels`` channels), applies ``(x - shift) * scale``, runs the decoder
    and maps its output through ``(y - 0.5) * 2``. ``encode`` maps images via
    ``x * 0.5 + 0.5``, runs the encoder, applies ``z / scale + shift`` and packs the
    result when ``is_taef2`` is set. Both run under ``torch.inference_mode``.
    """

    def __init__(self, encoder, decoder, latent_channels, is_taef2):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.latent_channels = int(latent_channels)
        self.is_taef2 = bool(is_taef2)

        # Identity affine by default (scale 1, shift 0).
        self.vae_scale = nn.Parameter(torch.tensor(1.0))
        self.vae_shift = nn.Parameter(torch.tensor(0.0))

    @torch.inference_mode()
    def decode(self, x):
        latents = unpack_packed_latents(x, self.latent_channels)
        latents = (latents - self.vae_shift) * self.vae_scale
        decoded = self.decoder(latents)
        return (decoded - 0.5) * 2.0

    @torch.inference_mode()
    def encode(self, x):
        shifted = x * 0.5 + 0.5
        z = self.encoder(shifted) / self.vae_scale + self.vae_shift
        return pack_packed_latents(z, self.latent_channels) if self.is_taef2 else z
|
|
|
|
|
def _load_part_weights(module, part_sd, part_name):
    """Load ``part_sd`` into ``module`` non-strictly, but refuse to leave weights unset.

    The modules are built from ``comfy.ops.disable_weight_init`` layers, so a parameter
    not covered by the checkpoint would presumably keep uninitialized memory and silently
    produce garbage output. Missing keys are therefore a hard error; unexpected keys are
    only reported.

    Raises:
        RuntimeError: if ``part_sd`` lacks any key that ``module`` expects.
    """
    result = module.load_state_dict(part_sd, strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"[TAEF2] {part_name} checkpoint is missing {len(result.missing_keys)} key(s): "
            f"{result.missing_keys[:20]}"
        )
    if result.unexpected_keys:
        print(
            f"[TAEF2] Warning: ignoring {len(result.unexpected_keys)} unexpected "
            f"{part_name} key(s): {result.unexpected_keys[:20]}"
        )


def _load_latent_affine(core, sd):
    """Copy optional scalar ``vae_scale`` / ``vae_shift`` entries from ``sd`` into ``core``.

    Checkpoints without these keys keep the identity defaults set by ``TAESDCore``.
    """
    with torch.no_grad():
        for name in ("vae_scale", "vae_shift"):
            value = sd.get(name)
            if isinstance(value, torch.Tensor) and value.numel() == 1:
                getattr(core, name).copy_(value.reshape(()))


def load_core(path, device, dtype):
    """Load a TAESD-family checkpoint from ``path`` and return a ready-to-use ``TAESDCore``.

    Steps:
    1. Read the file and normalize key prefixes, then split encoder/decoder weights.
    2. Infer the latent channel count, the positions of mid-block GroupNorm branches
       and the 8x/16x layout.
    3. Build the matching modules and load their weights, plus an optional latent
       scale/shift.
    4. Move the result to ``device``/``dtype`` in eval mode with gradients disabled.

    The returned module carries ``_base_downscale`` (8 or 16).

    Raises:
        RuntimeError: if encoder/decoder weights cannot be separated, the latent channels
            or layout cannot be inferred, or the checkpoint does not cover every weight
            of the built modules.
    """
    sd_raw = comfy.utils.load_torch_file(path, safe_load=True)
    sd = normalize_state_dict(sd_raw)
    enc_sd, dec_sd = split_encoder_decoder(sd)

    if not enc_sd or not dec_sd:
        sample = list(sd_raw.keys())[:40]
        raise RuntimeError(
            "Could not split encoder/decoder weights.\n"
            "Use Dump VAE Keys node and paste first ~40 keys.\n"
            f"First keys: {sample}"
        )

    enc_pool = pool_blocks_from_sd(enc_sd)
    dec_pool = pool_blocks_from_sd(dec_sd)

    latent_channels = infer_latent_channels_from_decoder(dec_sd)
    layout = detect_layout(enc_sd, latent_channels)

    # TAEF2 is identified by its 32-channel latent combined with mid-block GroupNorm.
    has_midblock_gn = bool(enc_pool or dec_pool)
    is_taef2 = latent_channels == 32 and has_midblock_gn

    if layout == "scale8":
        encoder = build_encoder_scale8(latent_channels, enc_pool)
        decoder = build_decoder_scale8(latent_channels, dec_pool)
        base_downscale = 8
    else:
        encoder = build_encoder_scale16(latent_channels, enc_pool)
        decoder = build_decoder_scale16(latent_channels, dec_pool)
        base_downscale = 16

    core = TAESDCore(encoder, decoder, latent_channels, is_taef2)
    _load_part_weights(core.encoder, enc_sd, "encoder")
    _load_part_weights(core.decoder, dec_sd, "decoder")
    _load_latent_affine(core, sd)

    core = core.to(device=device, dtype=dtype).eval()
    for p in core.parameters():
        p.requires_grad_(False)

    core._base_downscale = base_downscale
    return core
|
|
|
|
|
|
|
|
|
|
|
|
|
class TAEF2VAE:
    """VAE-like wrapper exposing a TAESD-family tiny autoencoder through ComfyUI's VAE methods.

    ``decode`` returns float32 images in [0, 1], permuted channels-last (presumably BHWC,
    as ComfyUI expects). ``encode`` takes channels-last pixels, uses the first three
    channels, and returns float32 latents. ``downscale_ratio`` is the decoder's spatial
    factor, doubled for TAEF2 because of 2x2 latent packing.
    """

    def __init__(self, weights_path, device, dtype):
        self.device = device
        self.dtype = dtype
        self.core = load_core(weights_path, device=device, dtype=dtype)

        core = self.core
        # TAEF2 latents are 2x2-packed, which doubles the effective spatial factor.
        pack_factor = 2 if core.is_taef2 else 1
        self.downscale_ratio = core._base_downscale * pack_factor

        print(
            f"[TAEF2] Loaded: {weights_path} | latent_channels={core.latent_channels} "
            f"| is_taef2={core.is_taef2} | base_downscale={core._base_downscale} "
            f"| effective_downscale={self.downscale_ratio}"
        )

    @torch.inference_mode()
    def decode(self, latents):
        decoded = self.core.decode(latents.to(device=self.device, dtype=self.dtype))
        # [-1, 1] -> [0, 1] in the working dtype, then widen to float32.
        images = (decoded.clamp(-1, 1) + 1.0) * 0.5
        return images.float().permute(0, 2, 3, 1).contiguous()

    @torch.inference_mode()
    def encode(self, pixels):
        rgb = pixels[..., :3].permute(0, 3, 1, 2).contiguous()
        rgb = rgb.to(device=self.device, dtype=self.dtype)
        signed = rgb.clamp(0, 1) * 2.0 - 1.0
        # Pad so the padded image maps to a whole number of (packed) latent pixels.
        signed = pad_nchw_to_multiple(signed, self.downscale_ratio)
        return self.core.encode(signed).float()

    # Tiling options are accepted for interface compatibility but ignored.
    def decode_tiled(self, latents, **kwargs):
        return self.decode(latents)

    def encode_tiled(self, pixels, **kwargs):
        return self.encode(pixels)

    def spacial_compression_decode(self):
        return self.downscale_ratio

    def spacial_compression_encode(self):
        return self.downscale_ratio

    # No temporal axis: image-only autoencoder.
    def temporal_compression_decode(self):
        return None

    def temporal_compression_encode(self):
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
def _list_vae_files():
    """Return the sorted, de-duplicated union of filenames in the ``vae`` and ``vae_approx`` folders."""
    names = set(folder_paths.get_filename_list("vae"))
    names.update(folder_paths.get_filename_list("vae_approx"))
    return sorted(names)
|
|
|
def _resolve_vae_path(fname):
    """Return the full path of ``fname``, preferring ``vae_approx`` over ``vae``; None if absent from both."""
    for folder in ("vae_approx", "vae"):
        full_path = folder_paths.get_full_path(folder, fname)
        if full_path is not None:
            return full_path
    return None
|
|
|
|
|
class LoadTAEF2VAE:
    """ComfyUI node that loads a TAESD-family weights file and returns it as a ``VAE``.

    The model is placed on ComfyUI's torch device in the selected precision
    (``bf16``/``fp16``; any other choice means float32).
    """

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load"
    CATEGORY = "latent/vae"

    @classmethod
    def INPUT_TYPES(cls):
        precision_choices = ["bf16", "fp16", "fp32"]
        required = {
            "weights": (_list_vae_files(),),
            "dtype": (precision_choices, {"default": "bf16"}),
        }
        return {"required": required}

    def load(self, weights, dtype):
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        device = comfy.model_management.get_torch_device()
        torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(dtype, torch.float32)
        return (TAEF2VAE(path, device=device, dtype=torch_dtype),)
|
|
|
|
|
class DumpVAEKeys:
    """Debug node: list the keys of a VAE weights file, optionally with shapes and dtypes.

    The listing is returned both as the node's STRING output and under ``ui.text``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "weights": (_list_vae_files(),),
                "include_shapes": ("BOOLEAN", {"default": True}),
                "sort_keys": ("BOOLEAN", {"default": True}),
                "max_lines": ("INT", {"default": 0, "min": 0, "max": 200000}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "dump"
    CATEGORY = "utils/debug"
    # ComfyUI only executes non-output nodes when a downstream node consumes their
    # result; marking this as an output node lets it run on its own, which is how the
    # loader's error message tells users to use it.
    OUTPUT_NODE = True

    @staticmethod
    def _format_lines(sd, include_shapes, sort_keys, max_lines):
        """Return one line per key of ``sd``, truncated to ``max_lines`` (0 = unlimited).

        With ``include_shapes``, each tensor line is ``key<TAB>shape<TAB>dtype`` and each
        non-tensor line is ``key<TAB>type``. When truncated, a trailing marker line
        reports the total count.
        """
        keys = list(sd.keys())
        if sort_keys:
            keys.sort()

        if include_shapes:
            lines = []
            for k in keys:
                v = sd[k]
                if isinstance(v, torch.Tensor):
                    lines.append(f"{k}\t{tuple(v.shape)}\t{str(v.dtype)}")
                else:
                    lines.append(f"{k}\t{type(v)}")
        else:
            lines = [str(k) for k in keys]

        if max_lines and len(lines) > max_lines:
            total = len(lines)
            lines = lines[:max_lines]
            lines.append(f"... TRUNCATED: total_keys={total} (showing first {max_lines}) ...")
        return lines

    def dump(self, weights, include_shapes, sort_keys, max_lines):
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        sd = comfy.utils.load_torch_file(path, safe_load=True)
        text = "\n".join(self._format_lines(sd, include_shapes, sort_keys, max_lines))
        return {"ui": {"text": [text]}, "result": (text,)}
|
|
|
|
|
# Node registration: internal node id -> node class.
NODE_CLASS_MAPPINGS = dict(
    LoadTAEF2VAE=LoadTAEF2VAE,
    DumpVAEKeys=DumpVAEKeys,
)

# Node registration: internal node id -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoadTAEF2VAE="Load TAEF2 (Flux2 Tiny VAE)",
    DumpVAEKeys="Dump VAE Keys (as String)",
)
|
|
|