import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.utils
import comfy.ops
import comfy.model_management
import folder_paths


# ============================================================
# Layers (Comfy-style: disable_weight_init)
# ============================================================
# Layers are built from comfy.ops.disable_weight_init, which (per its name)
# presumably skips parameter initialization. Every parameter is therefore
# expected to come from the loaded checkpoint; load_core enforces this.


def conv(n_in, n_out, **kwargs):
    """Return a 3x3, padding=1 Conv2d from comfy's disable_weight_init ops.

    Extra keyword arguments (e.g. ``stride``, ``bias``) go to Conv2d.
    """
    return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    """Soft clamp ``tanh(x / 3) * 3`` (range (-3, 3)); has no weights."""

    def forward(self, x):
        return torch.tanh(x / 3) * 3


class Block(nn.Module):
    """Residual block: three 3x3 convs plus a skip path, fused by a ReLU.

    With ``use_midblock_gn=True`` an extra residual branch
    (1x1 conv -> GroupNorm(4) -> ReLU -> 1x1 conv, width ``n_in * 4``) is
    added to the input first; its weights live under ``pool.*``.
    """

    def __init__(self, n_in, n_out, use_midblock_gn=False):
        super().__init__()
        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )
        # A 1x1 projection is only needed when the channel count changes.
        self.skip = (
            comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False)
            if n_in != n_out
            else nn.Identity()
        )
        self.fuse = nn.ReLU()
        self.pool = None
        if use_midblock_gn:
            def conv1x1(a, b):
                return comfy.ops.disable_weight_init.Conv2d(a, b, 1, bias=False)

            n_gn = n_in * 4
            self.pool = nn.Sequential(
                conv1x1(n_in, n_gn),
                comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
                nn.ReLU(inplace=True),
                conv1x1(n_gn, n_in),
            )

    def forward(self, x):
        if self.pool is not None:
            # The GN branch modifies the input seen by BOTH conv and skip paths.
            x = x + self.pool(x)
        return self.fuse(self.conv(x) + self.skip(x))


# ============================================================
# TAESD scale8/scale16 builders
# scale8: the encoder's final conv sits at index 14 (three stride-2 convs).
# scale16: the encoder's final conv sits at index 18 (four stride-2 convs).
# ``pool_blocks`` = Sequential indices whose Block gets the mid-block GN branch.
# ============================================================


def build_encoder_scale8(latent_channels, pool_blocks):
    """Build the 8x-downscaling encoder (3 -> ``latent_channels`` channels)."""
    def B(idx):
        return Block(64, 64, use_midblock_gn=(idx in pool_blocks))

    return nn.Sequential(
        conv(3, 64),                         # 0
        B(1),                                # 1
        conv(64, 64, stride=2, bias=False),  # 2
        B(3), B(4), B(5),                    # 3-5
        conv(64, 64, stride=2, bias=False),  # 6
        B(7), B(8), B(9),                    # 7-9
        conv(64, 64, stride=2, bias=False),  # 10
        B(11), B(12), B(13),                 # 11-13
        conv(64, latent_channels),           # 14
    )


def build_decoder_scale8(latent_channels, pool_blocks):
    """Build the 8x-upscaling decoder (``latent_channels`` -> 3 channels)."""
    def B(idx):
        return Block(64, 64, use_midblock_gn=(idx in pool_blocks))

    return nn.Sequential(
        Clamp(),                     # 0 (no weights)
        conv(latent_channels, 64),   # 1
        nn.ReLU(),                   # 2
        B(3), B(4), B(5),            # 3-5
        nn.Upsample(scale_factor=2), # 6
        conv(64, 64, bias=False),    # 7
        B(8), B(9), B(10),           # 8-10
        nn.Upsample(scale_factor=2), # 11
        conv(64, 64, bias=False),    # 12
        B(13), B(14), B(15),         # 13-15
        nn.Upsample(scale_factor=2), # 16
        conv(64, 64, bias=False),    # 17
        B(18),                       # 18
        conv(64, 3),                 # 19
    )


def build_encoder_scale16(latent_channels, pool_blocks):
    """Build the 16x-downscaling encoder (3 -> ``latent_channels`` channels)."""
    def B(idx):
        return Block(64, 64, use_midblock_gn=(idx in pool_blocks))

    return nn.Sequential(
        conv(3, 64),                         # 0
        B(1),                                # 1
        conv(64, 64, stride=2, bias=False),  # 2
        B(3), B(4), B(5),                    # 3-5
        conv(64, 64, stride=2, bias=False),  # 6
        B(7), B(8), B(9),                    # 7-9
        conv(64, 64, stride=2, bias=False),  # 10
        B(11), B(12), B(13),                 # 11-13
        conv(64, 64, stride=2, bias=False),  # 14
        B(15), B(16), B(17),                 # 15-17
        conv(64, latent_channels),           # 18
    )


def build_decoder_scale16(latent_channels, pool_blocks):
    """Build the 16x-upscaling decoder (``latent_channels`` -> 3 channels)."""
    def B(idx):
        return Block(64, 64, use_midblock_gn=(idx in pool_blocks))

    return nn.Sequential(
        Clamp(),                     # 0
        conv(latent_channels, 64),   # 1
        nn.ReLU(),                   # 2
        B(3), B(4), B(5),            # 3-5
        nn.Upsample(scale_factor=2), # 6
        conv(64, 64, bias=False),    # 7
        B(8), B(9), B(10),           # 8-10
        nn.Upsample(scale_factor=2), # 11
        conv(64, 64, bias=False),    # 12
        B(13), B(14), B(15),         # 13-15
        nn.Upsample(scale_factor=2), # 16
        conv(64, 64, bias=False),    # 17
        B(18), B(19), B(20),         # 18-20
        nn.Upsample(scale_factor=2), # 21
        conv(64, 64, bias=False),    # 22
        B(23),                       # 23
        conv(64, 3),                 # 24
    )


# ============================================================
# Packed latents (odd sizes are auto-padded so packing never errors)
# ============================================================


def unpack_packed_latents(x, latent_channels):
    """Unpack ``[B, C*4, H, W]`` into ``[B, C, H*2, W*2]``.

    Packed channel ``c*4 + ph*2 + pw`` becomes pixel offset ``(ph, pw)`` of
    each 2x2 cell. Inputs that are not 4-D with ``C*4`` channels are returned
    unchanged.
    """
    if x.ndim == 4 and x.shape[1] == latent_channels * 4:
        b, _, h, w = x.shape
        return (
            x.reshape(b, latent_channels, 2, 2, h, w)
            .permute(0, 1, 4, 2, 5, 3)
            .reshape(b, latent_channels, h * 2, w * 2)
        )
    return x


def pack_packed_latents(z, latent_channels):
    """Pack ``[B, C, H, W]`` into ``[B, C*4, ceil(H/2), ceil(W/2)]``.

    This is the inverse of ``unpack_packed_latents``. An odd H or W is first
    replicate-padded on the bottom/right. Inputs that are not 4-D with ``C``
    channels are returned unchanged.
    """
    if z.ndim == 4 and z.shape[1] == latent_channels:
        h, w = z.shape[-2], z.shape[-1]
        pad_h = h & 1
        pad_w = w & 1
        if pad_h or pad_w:
            z = F.pad(z, (0, pad_w, 0, pad_h), mode="replicate")
            h, w = z.shape[-2], z.shape[-1]
        return (
            z.reshape(z.shape[0], latent_channels, h // 2, 2, w // 2, 2)
            .permute(0, 1, 3, 5, 2, 4)
            .reshape(z.shape[0], latent_channels * 4, h // 2, w // 2)
        )
    return z


def pad_nchw_to_multiple(x, multiple):
    """Replicate-pad an NCHW tensor on the bottom/right so H and W are multiples of ``multiple``."""
    _, _, h, w = x.shape
    pad_h = (multiple - (h % multiple)) % multiple
    pad_w = (multiple - (w % multiple)) % multiple
    if pad_h or pad_w:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
    return x


# ============================================================
# Key conversion into the internal layout:
#   taesd_encoder.N.* / taesd_decoder.N.*
# Decoder indices from files without a Clamp slot are shifted by +1 because
# our decoder has Clamp() at index 0.
# ============================================================


def normalize_state_dict(sd_raw):
    """Rename checkpoint keys to ``taesd_encoder.N.*`` / ``taesd_decoder.N.*``.

    Recognised inputs:
    - Keys already prefixed ``taesd_encoder.`` / ``taesd_decoder.`` are
      returned as-is.
    - ``encoder.layers.N.*`` / ``decoder.layers.N.*`` keys are renamed, and
      decoder ``N`` becomes ``N + 1``. A non-integer ``N`` keeps its
      original key.
    - Plain ``encoder.N.*`` / ``decoder.N.*`` keys are renamed. Decoder
      indices are shifted by +1 only when ``decoder.0.weight`` looks like a
      ``[64, C, 3, 3]`` input conv, meaning the file has no Clamp slot.

    Within a recognised layout, other keys are copied through unchanged. An
    unrecognised layout returns ``sd_raw`` itself.
    """
    keys = list(sd_raw.keys())

    # Already comfy split format?
    if any(k.startswith(("taesd_encoder.", "taesd_decoder.")) for k in keys):
        return sd_raw

    out = {}

    # Diffusers "encoder.layers.* / decoder.layers.*"
    if any(k.startswith(("encoder.layers.", "decoder.layers.")) for k in keys):
        for k, v in sd_raw.items():
            if k.startswith("encoder.layers."):
                # encoder.layers.N.xxx -> taesd_encoder.N.xxx
                out["taesd_encoder." + k[len("encoder.layers."):]] = v
            elif k.startswith("decoder.layers."):
                # decoder.layers.N.xxx -> taesd_decoder.(N+1).xxx (Clamp at 0)
                head, _, tail = k[len("decoder.layers."):].partition(".")
                try:
                    n = int(head)
                except ValueError:
                    # Non-numeric layer name: keep the original key.
                    out[k] = v
                    continue
                out[f"taesd_decoder.{n + 1}" + (f".{tail}" if tail else "")] = v
            else:
                out[k] = v
        return out

    # Fallback: numeric encoder./decoder. keys. If decoder.0.weight looks like
    # the [64, C, 3, 3] input conv, the file lacks the Clamp slot -> offset.
    if any(k.startswith(("encoder.", "decoder.")) for k in keys):
        w0 = sd_raw.get("decoder.0.weight")
        decoder_needs_offset = (
            isinstance(w0, torch.Tensor)
            and w0.ndim == 4
            and w0.shape[0] == 64
            and w0.shape[2:] == (3, 3)
        )
        for k, v in sd_raw.items():
            if k.startswith("encoder."):
                out["taesd_encoder." + k[len("encoder."):]] = v
            elif k.startswith("decoder."):
                rest = k[len("decoder."):]
                head, _, tail = rest.partition(".")
                if decoder_needs_offset and head.isdigit():
                    out[f"taesd_decoder.{int(head) + 1}" + (f".{tail}" if tail else "")] = v
                else:
                    out["taesd_decoder." + rest] = v
            else:
                out[k] = v
        return out

    # Unknown layout: return as-is (the Dump node will show the keys).
    return sd_raw


def split_encoder_decoder(sd):
    """Split a normalized state dict into (encoder, decoder) dicts with the prefixes stripped."""
    enc = {k[len("taesd_encoder."):]: v for k, v in sd.items() if k.startswith("taesd_encoder.")}
    dec = {k[len("taesd_decoder."):]: v for k, v in sd.items() if k.startswith("taesd_decoder.")}
    return enc, dec


def pool_blocks_from_sd(part_sd):
    """Return the set of top-level indices whose Block carries ``pool.0`` (mid-block GN) weights."""
    blocks = set()
    for k in part_sd:
        if ".pool.0.weight" in k or ".pool.0.bias" in k:
            head = k.split(".", 1)[0]
            if head.isdigit():
                blocks.add(int(head))
    return blocks


def infer_latent_channels_from_decoder(dec_sd):
    """Infer latent channel count C from the lowest-index ``[64, C, 3, 3]`` decoder conv.

    Raises:
        RuntimeError: if no such conv weight exists.
    """
    candidates = []
    for k, v in dec_sd.items():
        if not isinstance(v, torch.Tensor) or v.ndim != 4:
            continue
        head = k.split(".", 1)[0]
        if head.isdigit() and v.shape[0] == 64 and v.shape[2:] == (3, 3):
            candidates.append((int(head), int(v.shape[1])))
    if not candidates:
        raise RuntimeError("Could not infer latent_channels from decoder weights.")
    return min(candidates, key=lambda t: t[0])[1]


def detect_layout(enc_sd, latent_channels):
    """Return ``"scale8"`` or ``"scale16"`` based on where the encoder's final conv sits.

    The final conv is the ``[latent_channels, 64, ...]`` weight at index 14
    (scale8) or index 18 (scale16).

    Raises:
        RuntimeError: if no plausible final conv is found.
    """
    def is_final_conv(w):
        return (
            isinstance(w, torch.Tensor)
            and w.ndim == 4
            and w.shape[0] == latent_channels
            and w.shape[1] == 64
        )

    # Check the deeper slot first. A scale8 encoder has no index 18. A
    # scale16 encoder's index-14 stride conv is [64, 64, 3, 3], which would
    # be mistaken for a scale8 final conv when latent_channels == 64.
    if is_final_conv(enc_sd.get("18.weight")):
        return "scale16"
    if is_final_conv(enc_sd.get("14.weight")):
        return "scale8"

    # Fallback: the final conv is the DEEPEST [C, 64, 3, 3] conv. When C != 64
    # it is the only candidate. When C == 64 the stride convs also match, so
    # picking the earliest index would be wrong.
    best = None
    for k, v in enc_sd.items():
        if not isinstance(v, torch.Tensor) or v.ndim != 4:
            continue
        head = k.split(".", 1)[0]
        if head.isdigit() and v.shape[0] == latent_channels and v.shape[1] == 64 and v.shape[2:] == (3, 3):
            idx = int(head)
            best = idx if best is None else max(best, idx)
    if best is None:
        raise RuntimeError("Could not detect encoder layout (scale8 vs scale16).")
    return "scale8" if best <= 14 else "scale16"


# ============================================================
# Core model (PR behavior: decode -> [-1,1], encode -> packed for taef2)
# ============================================================


class TAESDCore(nn.Module):
    """Encoder/decoder pair with scale/shift latent normalization.

    ``vae_scale`` / ``vae_shift`` start at 1.0 / 0.0. ``load_core`` does not
    load them from the checkpoint, so they act as identity unless changed.
    """

    def __init__(self, encoder, decoder, latent_channels, is_taef2):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_channels = int(latent_channels)
        self.is_taef2 = bool(is_taef2)
        self.vae_scale = nn.Parameter(torch.tensor(1.0))
        self.vae_shift = nn.Parameter(torch.tensor(0.0))

    @torch.inference_mode()
    def decode(self, x):
        """Decode latents (packed ``C*4`` channels are unpacked first) to an NCHW image in [-1, 1]."""
        x = unpack_packed_latents(x, self.latent_channels)
        x = (x - self.vae_shift) * self.vae_scale
        x_sample = self.decoder(x)
        # Decoder output is treated as [0, 1]; map to [-1, 1].
        return x_sample.sub(0.5).mul(2.0)

    @torch.inference_mode()
    def encode(self, x):
        """Encode an NCHW image in [-1, 1]; the result is packed to ``C*4`` channels when ``is_taef2``."""
        # The encoder expects [0, 1] input.
        z = (self.encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
        if self.is_taef2:
            z = pack_packed_latents(z, self.latent_channels)
        return z


def _load_part(module, part_sd, part_name):
    """Load ``part_sd`` into ``module``, tolerating extra keys but not missing ones.

    Raises:
        RuntimeError: if the checkpoint lacks parameters the built architecture
            needs. Those parameters would otherwise keep whatever values the
            layers were constructed with, and the model would produce garbage.
    """
    result = module.load_state_dict(part_sd, strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"[TAEF2] {part_name} checkpoint is missing {len(result.missing_keys)} key(s) "
            f"required by the detected architecture; first: {result.missing_keys[:20]}"
        )
    if result.unexpected_keys:
        print(
            f"[TAEF2] Ignoring {len(result.unexpected_keys)} unexpected {part_name} key(s): "
            f"{result.unexpected_keys[:20]}"
        )


def load_core(path, device, dtype):
    """Load a TAESD/TAEF2 checkpoint from ``path`` into an eval-mode, frozen ``TAESDCore``.

    Steps: normalize the key layout, infer latent channels, detect scale8 or
    scale16, and build matching encoder/decoder stacks (with mid-block GN
    where the weights have it). The weights are loaded in fp32 and the model
    is then cast to ``device`` / ``dtype``. The base downscale factor (8 or 16)
    is stored on the returned module as ``_base_downscale``.

    Raises:
        RuntimeError: if the weights cannot be split, interpreted, or fully loaded.
    """
    sd_raw = comfy.utils.load_torch_file(path, safe_load=True)
    sd = normalize_state_dict(sd_raw)
    enc_sd, dec_sd = split_encoder_decoder(sd)

    if not enc_sd or not dec_sd:
        sample = list(sd_raw.keys())[:40]
        raise RuntimeError(
            "Could not split encoder/decoder weights.\n"
            "Use Dump VAE Keys node and paste first ~40 keys.\n"
            f"First keys: {sample}"
        )

    enc_pool = pool_blocks_from_sd(enc_sd)
    dec_pool = pool_blocks_from_sd(dec_sd)

    latent_channels = infer_latent_channels_from_decoder(dec_sd)
    layout = detect_layout(enc_sd, latent_channels)

    # Flux2 taef2 packed-latents heuristic: 32 latent channels plus any
    # mid-block GroupNorm branch.
    has_midblock_gn = bool(enc_pool) or bool(dec_pool)
    is_taef2 = (latent_channels == 32) and has_midblock_gn

    if layout == "scale8":
        encoder = build_encoder_scale8(latent_channels, enc_pool)
        decoder = build_decoder_scale8(latent_channels, dec_pool)
        base_downscale = 8
    else:
        encoder = build_encoder_scale16(latent_channels, enc_pool)
        decoder = build_decoder_scale16(latent_channels, dec_pool)
        base_downscale = 16

    # Load in fp32 first, then cast. Missing weights are an error (see _load_part).
    core = TAESDCore(encoder, decoder, latent_channels, is_taef2)
    _load_part(core.encoder, enc_sd, "encoder")
    _load_part(core.decoder, dec_sd, "decoder")

    core = core.to(device=device, dtype=dtype).eval()
    for p in core.parameters():
        p.requires_grad_(False)

    core._base_downscale = base_downscale
    return core


# ============================================================
# Comfy VAE interface object
# ============================================================


class TAEF2VAE:
    """Minimal VAE-like wrapper around ``TAESDCore`` for ComfyUI nodes.

    Images are NHWC float tensors in [0, 1]. Latents are NCHW and packed when
    the model is taef2.
    """

    def __init__(self, weights_path, device, dtype):
        self.device = device
        self.dtype = dtype
        self.core = load_core(weights_path, device=device, dtype=dtype)

        # Packed latents halve latent H/W again, so the effective downscale doubles.
        self.downscale_ratio = self.core._base_downscale * (2 if self.core.is_taef2 else 1)

        print(
            f"[TAEF2] Loaded: {weights_path} | latent_channels={self.core.latent_channels} "
            f"| is_taef2={self.core.is_taef2} | base_downscale={self.core._base_downscale} "
            f"| effective_downscale={self.downscale_ratio}"
        )

    @torch.inference_mode()
    def decode(self, latents):
        """Decode NCHW latents to an NHWC float32 image in [0, 1]."""
        x = latents.to(device=self.device, dtype=self.dtype)
        img = self.core.decode(x)  # NCHW in [-1, 1]
        img = img.clamp(-1, 1).add(1.0).mul(0.5)  # -> [0, 1]
        return img.to(torch.float32).permute(0, 2, 3, 1).contiguous()  # NHWC float32

    @torch.inference_mode()
    def encode(self, pixels):
        """Encode NHWC pixels in [0, 1] (extra channels beyond RGB are dropped).

        The input is replicate-padded bottom/right to a multiple of
        ``downscale_ratio``. The returned latent covers the padded image, and
        nothing here crops it back.
        """
        x = pixels[..., :3].permute(0, 3, 1, 2).contiguous()
        x = x.to(device=self.device, dtype=self.dtype).clamp(0, 1).mul(2.0).sub(1.0)  # -> [-1, 1]
        # Behave like the base VAE: pad to the required multiple so any size works.
        x = pad_nchw_to_multiple(x, self.downscale_ratio)
        z = self.core.encode(x)  # packed if taef2
        return z.to(torch.float32)

    def decode_tiled(self, latents, **kwargs):
        """Untiled fallback: tiling kwargs are ignored."""
        return self.decode(latents)

    def encode_tiled(self, pixels, **kwargs):
        """Untiled fallback: tiling kwargs are ignored."""
        return self.encode(pixels)

    def spacial_compression_decode(self):
        return self.downscale_ratio

    def spacial_compression_encode(self):
        return self.downscale_ratio

    def temporal_compression_decode(self):
        return None

    def temporal_compression_encode(self):
        return None


# ============================================================
# Nodes
# ============================================================


def _list_vae_files():
    """Return the sorted, de-duplicated file names from the ``vae`` and ``vae_approx`` folders."""
    vae_files = folder_paths.get_filename_list("vae")
    approx_files = folder_paths.get_filename_list("vae_approx")
    return sorted(set(vae_files + approx_files))


def _resolve_vae_path(fname):
    """Resolve ``fname`` to a full path, preferring ``vae_approx`` over ``vae``; None if not found."""
    path = folder_paths.get_full_path("vae_approx", fname)
    if path is None:
        path = folder_paths.get_full_path("vae", fname)
    return path


class LoadTAEF2VAE:
    """Node: load a TAESD/TAEF2 checkpoint as a VAE on the Comfy torch device."""

    # Any unrecognised choice falls back to fp32.
    _DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "weights": (_list_vae_files(),),
                "dtype": (["bf16", "fp16", "fp32"], {"default": "bf16"}),
            }
        }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load"
    CATEGORY = "latent/vae"

    def load(self, weights, dtype):
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        device = comfy.model_management.get_torch_device()
        tdtype = self._DTYPES.get(dtype, torch.float32)
        return (TAEF2VAE(path, device=device, dtype=tdtype),)


class DumpVAEKeys:
    """Node: list a weights file's keys (optionally with shape and dtype) as a string."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "weights": (_list_vae_files(),),
                "include_shapes": ("BOOLEAN", {"default": True}),
                "sort_keys": ("BOOLEAN", {"default": True}),
                "max_lines": ("INT", {"default": 0, "min": 0, "max": 200000}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "dump"
    CATEGORY = "utils/debug"

    def dump(self, weights, include_shapes, sort_keys, max_lines):
        """Return the key listing; ``max_lines == 0`` means no truncation."""
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        sd = comfy.utils.load_torch_file(path, safe_load=True)
        keys = list(sd.keys())
        if sort_keys:
            keys.sort()

        if include_shapes:
            lines = []
            for k in keys:
                v = sd[k]
                if isinstance(v, torch.Tensor):
                    lines.append(f"{k}\t{tuple(v.shape)}\t{str(v.dtype)}")
                else:
                    lines.append(f"{k}\t{type(v)}")
        else:
            lines = keys

        if max_lines and len(lines) > max_lines:
            head = lines[:max_lines]
            head.append(f"... TRUNCATED: total_keys={len(lines)} (showing first {max_lines}) ...")
            lines = head

        text = "\n".join(lines)
        return {"ui": {"text": [text]}, "result": (text,)}


NODE_CLASS_MAPPINGS = {
    "LoadTAEF2VAE": LoadTAEF2VAE,
    "DumpVAEKeys": DumpVAEKeys,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadTAEF2VAE": "Load TAEF2 (Flux2 Tiny VAE)",
    "DumpVAEKeys": "Dump VAE Keys (as String)",
}