ApacheOne's picture
Upload 2 files
7798da9 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.utils
import comfy.ops
import comfy.model_management
import folder_paths
# ============================================================
# Layers (Comfy-style: disable_weight_init)
# ============================================================
def conv(n_in, n_out, **kwargs):
    """Create a 3x3 convolution with padding 1 from ``comfy.ops.disable_weight_init``.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        **kwargs: Forwarded unchanged to the ``Conv2d`` constructor (the
            builders in this file pass ``stride`` and ``bias``).

    Returns:
        The constructed ``Conv2d`` module.
    """
    conv_cls = comfy.ops.disable_weight_init.Conv2d
    return conv_cls(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
    """Parameter-free soft clamp: squashes values smoothly into (-3, 3)."""

    def forward(self, x):
        """Return ``tanh(x / 3) * 3`` element-wise."""
        squashed = torch.tanh(x / 3)
        return squashed.mul(3)
class Block(nn.Module):
    """Residual block: three 3x3 convs plus a skip path, fused by a ReLU.

    When ``use_midblock_gn`` is set, an extra residual "pool" branch
    (1x1 conv -> GroupNorm -> ReLU -> 1x1 conv) is applied to the input
    before the main convolution stack.

    Submodule names (``conv``, ``skip``, ``fuse``, ``pool``) determine the
    state-dict keys, so they must not be renamed.
    """

    def __init__(self, n_in, n_out, use_midblock_gn=False):
        """Build the block.

        Args:
            n_in: Input channel count.
            n_out: Output channel count.
            use_midblock_gn: Whether to add the GroupNorm "pool" branch.
        """
        super().__init__()
        ops = comfy.ops.disable_weight_init

        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )

        # A 1x1 projection is only needed when the channel count changes.
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = ops.Conv2d(n_in, n_out, 1, bias=False)

        self.fuse = nn.ReLU()

        self.pool = None
        if use_midblock_gn:
            hidden = n_in * 4  # the branch widens to 4x the input channels
            self.pool = nn.Sequential(
                ops.Conv2d(n_in, hidden, 1, bias=False),
                ops.GroupNorm(4, hidden),
                nn.ReLU(inplace=True),
                ops.Conv2d(hidden, n_in, 1, bias=False),
            )

    def forward(self, x):
        """Apply the optional pool residual, then conv + skip followed by ReLU."""
        pool = self.pool
        if pool is not None:
            x = x + pool(x)
        summed = self.conv(x) + self.skip(x)
        return self.fuse(summed)
# ============================================================
# TAESD scale8/scale16 builders
# Your file is scale8 (encoder final conv at layers.14)
# ============================================================
def build_encoder_scale8(latent_channels, pool_blocks):
    """Assemble the encoder with three stride-2 stages (15 layers, indices 0-14).

    Args:
        latent_channels: Output channels of the final convolution (index 14).
        pool_blocks: Container of layer indices whose ``Block`` gets the
            GroupNorm pool branch.

    Returns:
        An ``nn.Sequential`` whose positional indices match the layer
        numbering noted in the comments below.
    """
    def make_block(index):
        return Block(64, 64, use_midblock_gn=index in pool_blocks)

    layers = [conv(3, 64), make_block(1)]              # 0, 1
    for down_index in (2, 6, 10):
        # stride-2 conv at `down_index`, then three blocks right after it
        layers.append(conv(64, 64, stride=2, bias=False))
        for offset in (1, 2, 3):
            layers.append(make_block(down_index + offset))
    layers.append(conv(64, latent_channels))           # 14
    return nn.Sequential(*layers)
def build_decoder_scale8(latent_channels, pool_blocks):
    """Assemble the decoder with three 2x upsampling stages (20 layers, indices 0-19).

    Args:
        latent_channels: Input channels of the first convolution (index 1).
        pool_blocks: Container of layer indices whose ``Block`` gets the
            GroupNorm pool branch.

    Returns:
        An ``nn.Sequential``; index 0 is the weightless ``Clamp``.
    """
    def make_block(index):
        return Block(64, 64, use_midblock_gn=index in pool_blocks)

    layers = [
        Clamp(),                        # 0 (no weights)
        conv(latent_channels, 64),      # 1
        nn.ReLU(),                      # 2
    ]
    for first_block in (3, 8, 13):
        # three blocks, then upsample + conv (5 layers per stage)
        for offset in range(3):
            layers.append(make_block(first_block + offset))
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(64, 64, bias=False))
    layers.append(make_block(18))       # 18
    layers.append(conv(64, 3))          # 19
    return nn.Sequential(*layers)
def build_encoder_scale16(latent_channels, pool_blocks):
    """Assemble the encoder with four stride-2 stages (19 layers, indices 0-18).

    Args:
        latent_channels: Output channels of the final convolution (index 18).
        pool_blocks: Container of layer indices whose ``Block`` gets the
            GroupNorm pool branch.

    Returns:
        An ``nn.Sequential`` whose positional indices match the layer
        numbering noted in the comments below.
    """
    def make_block(index):
        return Block(64, 64, use_midblock_gn=index in pool_blocks)

    layers = [conv(3, 64), make_block(1)]              # 0, 1
    for down_index in (2, 6, 10, 14):
        # stride-2 conv at `down_index`, then three blocks right after it
        layers.append(conv(64, 64, stride=2, bias=False))
        for offset in (1, 2, 3):
            layers.append(make_block(down_index + offset))
    layers.append(conv(64, latent_channels))           # 18
    return nn.Sequential(*layers)
def build_decoder_scale16(latent_channels, pool_blocks):
    """Assemble the decoder with four 2x upsampling stages (25 layers, indices 0-24).

    Args:
        latent_channels: Input channels of the first convolution (index 1).
        pool_blocks: Container of layer indices whose ``Block`` gets the
            GroupNorm pool branch.

    Returns:
        An ``nn.Sequential``; index 0 is the weightless ``Clamp``.
    """
    def make_block(index):
        return Block(64, 64, use_midblock_gn=index in pool_blocks)

    layers = [
        Clamp(),                        # 0
        conv(latent_channels, 64),      # 1
        nn.ReLU(),                      # 2
    ]
    for first_block in (3, 8, 13, 18):
        # three blocks, then upsample + conv (5 layers per stage)
        for offset in range(3):
            layers.append(make_block(first_block + offset))
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(64, 64, bias=False))
    layers.append(make_block(23))       # 23
    layers.append(conv(64, 3))          # 24
    return nn.Sequential(*layers)
# ============================================================
# Packed latents (auto-pad so it never errors)
# ============================================================
def unpack_packed_latents(x, latent_channels):
    """Expand packed latents ``[B, C*4, H, W]`` into ``[B, C, H*2, W*2]``.

    Each group of four consecutive channels is treated as a 2x2 spatial
    patch. Any input that is not 4-D with exactly ``latent_channels * 4``
    channels is returned unchanged.
    """
    if x.ndim != 4 or x.shape[1] != latent_channels * 4:
        return x

    batch = x.shape[0]
    height, width = x.shape[-2], x.shape[-1]
    # [B, C, 2, 2, H, W] -> [B, C, H, 2, W, 2] so the patch dims interleave
    patches = x.reshape(batch, latent_channels, 2, 2, height, width)
    interleaved = patches.permute(0, 1, 4, 2, 5, 3)
    return interleaved.reshape(batch, latent_channels, height * 2, width * 2)
def pack_packed_latents(z, latent_channels):
    """Fold latents ``[B, C, H, W]`` into packed form ``[B, C*4, H//2, W//2]``.

    Odd heights/widths are first extended by one row/column on the
    bottom/right with replicate padding so the 2x2 grouping always works.
    Any input that is not 4-D with exactly ``latent_channels`` channels is
    returned unchanged.
    """
    if z.ndim != 4 or z.shape[1] != latent_channels:
        return z

    extra_rows = z.shape[-2] % 2
    extra_cols = z.shape[-1] % 2
    if extra_rows or extra_cols:
        z = F.pad(z, (0, extra_cols, 0, extra_rows), mode="replicate")

    batch = z.shape[0]
    half_h = z.shape[-2] // 2
    half_w = z.shape[-1] // 2
    # [B, C, H/2, 2, W/2, 2] -> [B, C, 2, 2, H/2, W/2] so each patch becomes channels
    patches = z.reshape(batch, latent_channels, half_h, 2, half_w, 2)
    stacked = patches.permute(0, 1, 3, 5, 2, 4)
    return stacked.reshape(batch, latent_channels * 4, half_h, half_w)
def pad_nchw_to_multiple(x, multiple):
    """Replicate-pad a 4-D tensor on the right/bottom to a size multiple.

    Args:
        x: Tensor with four dimensions; the last two are treated as
            height and width.
        multiple: Height and width are grown to the next multiple of this.

    Returns:
        ``x`` itself when no padding is needed, otherwise the padded tensor.
    """
    height, width = x.shape[2], x.shape[3]
    extra_h = -height % multiple
    extra_w = -width % multiple
    if not (extra_h or extra_w):
        return x
    return F.pad(x, (0, extra_w, 0, extra_h), mode="replicate")
# ============================================================
# Key conversion for your file format:
# encoder.layers.N.* and decoder.layers.N.*
# decoder layers must shift +1 because our decoder has Clamp() at index 0.
# ============================================================
def _shifted_decoder_key(rest):
    """Map ``"N.tail"`` to ``"taesd_decoder.<N+1>.tail"``.

    Raises:
        ValueError: If the leading component of ``rest`` is not an integer.
    """
    head, _, tail = rest.partition(".")
    shifted = f"taesd_decoder.{int(head) + 1}"
    return f"{shifted}.{tail}" if tail else shifted


def normalize_state_dict(sd_raw):
    """Rename checkpoint keys to the ``taesd_encoder.*`` / ``taesd_decoder.*`` layout.

    Three input layouts are recognised, checked in this order:

    1. Keys already prefixed ``taesd_encoder.`` / ``taesd_decoder.``:
       ``sd_raw`` is returned as-is.
    2. ``encoder.layers.N.*`` / ``decoder.layers.N.*``: the prefix is
       replaced and decoder indices are shifted by +1, because the decoder
       built in this file has a weightless ``Clamp`` at index 0.
    3. ``encoder.N.*`` / ``decoder.N.*``: the prefix is replaced; decoder
       indices are shifted by +1 only when ``decoder.0.weight`` is a
       ``[64, C, 3, 3]`` tensor (i.e. the file has no slot for ``Clamp``).

    Keys that match none of the prefixes are copied through unchanged. If
    no layout is recognised, ``sd_raw`` itself is returned.

    Args:
        sd_raw: Mapping of parameter name to value.

    Returns:
        ``sd_raw`` (layouts 1 / unknown) or a new dict with renamed keys.
    """
    keys = list(sd_raw.keys())

    # Already comfy split format?
    if any(k.startswith(("taesd_encoder.", "taesd_decoder.")) for k in keys):
        return sd_raw

    out = {}

    # "encoder.layers.* / decoder.layers.*"
    enc_layers, dec_layers = "encoder.layers.", "decoder.layers."
    if any(k.startswith((enc_layers, dec_layers)) for k in keys):
        for k, v in sd_raw.items():
            if k.startswith(enc_layers):
                out["taesd_encoder." + k[len(enc_layers):]] = v
            elif k.startswith(dec_layers):
                try:
                    out_key = _shifted_decoder_key(k[len(dec_layers):])
                except ValueError:
                    # Non-numeric layer index: keep the key untouched.
                    out_key = k
                out[out_key] = v
            else:
                out[k] = v
        return out

    # Fallback: plain "encoder." / "decoder." prefixes with numeric indices.
    if any(k.startswith(("encoder.", "decoder.")) for k in keys):
        w0 = sd_raw.get("decoder.0.weight")
        # A 3x3 conv with 64 outputs at decoder index 0 means the file has no
        # Clamp placeholder, so every decoder index must move up by one.
        decoder_needs_offset = (
            isinstance(w0, torch.Tensor)
            and w0.ndim == 4
            and w0.shape[0] == 64
            and w0.shape[2:] == (3, 3)
        )
        for k, v in sd_raw.items():
            if k.startswith("encoder."):
                out["taesd_encoder." + k[len("encoder."):]] = v
            elif k.startswith("decoder."):
                rest = k[len("decoder."):]
                if decoder_needs_offset and rest.partition(".")[0].isdigit():
                    out[_shifted_decoder_key(rest)] = v
                else:
                    out["taesd_decoder." + rest] = v
            else:
                out[k] = v
        return out

    # Unknown layout: return as-is (Dump node will show keys)
    return sd_raw
def split_encoder_decoder(sd):
    """Split a normalized state dict into encoder and decoder parts.

    Keys starting with ``taesd_encoder.`` / ``taesd_decoder.`` are placed in
    the respective dict with that prefix removed; all other keys are dropped.

    Returns:
        Tuple ``(encoder_sd, decoder_sd)``.
    """
    enc_prefix = "taesd_encoder."
    dec_prefix = "taesd_decoder."
    enc, dec = {}, {}
    for key, value in sd.items():
        if key.startswith(enc_prefix):
            enc[key[len(enc_prefix):]] = value
        elif key.startswith(dec_prefix):
            dec[key[len(dec_prefix):]] = value
    return enc, dec
def pool_blocks_from_sd(part_sd):
    """Collect the layer indices that carry a ``pool`` branch.

    A key counts when it contains ``.pool.0.weight`` or ``.pool.0.bias`` and
    its first dot-separated component is a digit string.

    Returns:
        Set of integer layer indices (empty if none are found).
    """
    pool_markers = (".pool.0.weight", ".pool.0.bias")
    indices = set()
    for key in part_sd:
        if not any(marker in key for marker in pool_markers):
            continue
        layer = key.partition(".")[0]
        if layer.isdigit():
            indices.add(int(layer))
    return indices
def infer_latent_channels_from_decoder(dec_sd):
    """Infer the latent channel count from the decoder's input convolution.

    Looks at every 4-D tensor whose key starts with a numeric layer index
    and whose shape is ``[64, C, 3, 3]``; the ``C`` of the lowest-indexed
    match is returned (the first one encountered wins on ties).

    Raises:
        RuntimeError: If no tensor matches.
    """
    lowest = None  # (layer index, input channels) of the best match so far
    for key, tensor in dec_sd.items():
        if not (isinstance(tensor, torch.Tensor) and tensor.ndim == 4):
            continue
        layer = key.partition(".")[0]
        if not layer.isdigit():
            continue
        if tensor.shape[0] != 64 or tensor.shape[2:] != (3, 3):
            continue
        index = int(layer)
        if lowest is None or index < lowest[0]:
            lowest = (index, int(tensor.shape[1]))

    if lowest is None:
        raise RuntimeError("Could not infer latent_channels from decoder weights.")
    return lowest[1]
def detect_layout(enc_sd, latent_channels):
    """Decide whether the encoder weights follow the scale8 or scale16 layout.

    The final encoder convolution maps 64 channels to ``latent_channels``.
    It sits at index 14 in the scale8 builder and index 18 in the scale16
    builder, so those keys are probed first. Otherwise the lowest-indexed
    ``[latent_channels, 64, 3, 3]`` conv decides: index <= 14 means scale8.

    Returns:
        ``"scale8"`` or ``"scale16"``.

    Raises:
        RuntimeError: If no matching convolution is found.
    """
    def maps_64_to_latent(w):
        return (
            isinstance(w, torch.Tensor)
            and w.ndim == 4
            and w.shape[0] == latent_channels
            and w.shape[1] == 64
        )

    for key, layout in (("14.weight", "scale8"), ("18.weight", "scale16")):
        if key in enc_sd and maps_64_to_latent(enc_sd[key]):
            return layout

    # Fallback: find earliest [C,64,3,3] conv in encoder
    matches = []
    for key, tensor in enc_sd.items():
        layer = key.partition(".")[0]
        if maps_64_to_latent(tensor) and layer.isdigit() and tensor.shape[2:] == (3, 3):
            matches.append(int(layer))

    if not matches:
        raise RuntimeError("Could not detect encoder layout (scale8 vs scale16).")
    return "scale8" if min(matches) <= 14 else "scale16"
# ============================================================
# Core model (PR behavior: decode -> [-1,1], encode -> packed for taef2)
# ============================================================
class TAESDCore(nn.Module):
    """Encoder/decoder pair with latent scale/shift and optional latent packing.

    ``decode`` maps latents to images in ``[-1, 1]`` (the decoder output is
    taken to be in ``[0, 1]``). ``encode`` maps images in ``[-1, 1]`` to
    latents and packs them when ``is_taef2`` is set.
    """

    def __init__(self, encoder, decoder, latent_channels, is_taef2):
        """Store the sub-networks and create identity scale/shift parameters.

        Args:
            encoder: Module mapping images to latents.
            decoder: Module mapping latents to images.
            latent_channels: Unpacked latent channel count (coerced to ``int``).
            is_taef2: Whether ``encode`` returns packed latents (coerced to ``bool``).
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_channels = int(latent_channels)
        self.is_taef2 = bool(is_taef2)
        # Defaults make scale/shift a no-op: scale 1, shift 0.
        self.vae_scale = nn.Parameter(torch.tensor(1.0))
        self.vae_shift = nn.Parameter(torch.tensor(0.0))

    @torch.inference_mode()
    def decode(self, x):
        """Decode latents (packed or unpacked) to an image tensor in ``[-1, 1]``."""
        latents = unpack_packed_latents(x, self.latent_channels)
        latents = (latents - self.vae_shift) * self.vae_scale
        decoded = self.decoder(latents)
        # decoder output in [0,1] -> [-1,1]
        return (decoded - 0.5) * 2.0

    @torch.inference_mode()
    def encode(self, x):
        """Encode an image tensor in ``[-1, 1]``; pack the latents if ``is_taef2``."""
        # x is [-1,1] -> encoder expects [0,1]
        unit_range = x * 0.5 + 0.5
        z = (self.encoder(unit_range) / self.vae_scale) + self.vae_shift
        if not self.is_taef2:
            return z
        return pack_packed_latents(z, self.latent_channels)
def load_core(path, device, dtype):
    """Load a TAESD-style checkpoint and build a ready-to-use ``TAESDCore``.

    The architecture is inferred from the weights: latent channel count
    from the decoder's input conv, scale8 vs scale16 from the encoder, and
    which blocks carry a GroupNorm pool branch from the ``pool`` keys.

    Args:
        path: Checkpoint file path, read with ``comfy.utils.load_torch_file``.
        device: Target device for the model.
        dtype: Target dtype for the model.

    Returns:
        A ``TAESDCore`` in eval mode with gradients disabled and an extra
        ``_base_downscale`` attribute (8 or 16).

    Raises:
        RuntimeError: If the checkpoint cannot be split into encoder and
            decoder weights, or the layout / latent channels cannot be
            inferred by the helper functions.
    """
    sd_raw = comfy.utils.load_torch_file(path, safe_load=True)
    sd = normalize_state_dict(sd_raw)
    enc_sd, dec_sd = split_encoder_decoder(sd)
    if not enc_sd or not dec_sd:
        sample = list(sd_raw)[:40]
        raise RuntimeError(
            "Could not split encoder/decoder weights.\n"
            "Use Dump VAE Keys node and paste first ~40 keys.\n"
            f"First keys: {sample}"
        )

    enc_pool = pool_blocks_from_sd(enc_sd)
    dec_pool = pool_blocks_from_sd(dec_sd)
    latent_channels = infer_latent_channels_from_decoder(dec_sd)
    layout = detect_layout(enc_sd, latent_channels)

    # Flux2 taef2 packed-latents heuristic:
    # 32 latent channels together with any GroupNorm pool block.
    has_midblock_gn = bool(enc_pool) or bool(dec_pool)
    is_taef2 = (latent_channels == 32) and has_midblock_gn

    if layout == "scale8":
        encoder = build_encoder_scale8(latent_channels, enc_pool)
        decoder = build_decoder_scale8(latent_channels, dec_pool)
        base_downscale = 8
    else:
        encoder = build_encoder_scale16(latent_channels, enc_pool)
        decoder = build_decoder_scale16(latent_channels, dec_pool)
        base_downscale = 16

    # Load the weights first, then cast/move the whole model.
    core = TAESDCore(encoder, decoder, latent_channels, is_taef2)
    for part_name, module, part_sd in (
        ("encoder", core.encoder, enc_sd),
        ("decoder", core.decoder, dec_sd),
    ):
        # strict=False tolerates key mismatches, but the layers are built
        # with disable_weight_init, so anything not loaded would be left
        # without real weights. Surface mismatches instead of hiding them.
        result = module.load_state_dict(part_sd, strict=False)
        if result.missing_keys:
            print(f"[TAEF2] WARNING: {part_name} keys missing from checkpoint: {list(result.missing_keys)}")
        if result.unexpected_keys:
            print(f"[TAEF2] WARNING: {part_name} checkpoint keys not used: {list(result.unexpected_keys)}")

    core = core.to(device=device, dtype=dtype).eval()
    for p in core.parameters():
        p.requires_grad_(False)
    core._base_downscale = base_downscale
    return core
# ============================================================
# Comfy VAE interface object
# ============================================================
class TAEF2VAE:
    """VAE-like wrapper around ``TAESDCore`` working on NHWC images in ``[0, 1]``.

    Presumably shaped to match what ComfyUI expects from a ``VAE`` object
    (``decode``/``encode`` plus tiled and compression-query variants) —
    confirm against the ComfyUI version in use.
    """

    def __init__(self, weights_path, device, dtype):
        """Load the core model and derive the effective spatial downscale.

        Args:
            weights_path: Checkpoint path handed to ``load_core``.
            device: Device the model and inputs are moved to.
            dtype: Dtype the model runs in.
        """
        self.device = device
        self.dtype = dtype
        self.core = load_core(weights_path, device=device, dtype=dtype)
        core = self.core
        # packed latents halves latent H/W again -> effective downscale doubles
        packing_factor = 2 if core.is_taef2 else 1
        self.downscale_ratio = core._base_downscale * packing_factor
        print(
            f"[TAEF2] Loaded: {weights_path} | latent_channels={core.latent_channels} "
            f"| is_taef2={core.is_taef2} | base_downscale={core._base_downscale} "
            f"| effective_downscale={self.downscale_ratio}"
        )

    @torch.inference_mode()
    def decode(self, latents):
        """Decode latents to a float32 NHWC image tensor in ``[0, 1]``."""
        moved = latents.to(device=self.device, dtype=self.dtype)
        image = self.core.decode(moved)            # NCHW in [-1,1]
        image = (image.clamp(-1, 1) + 1.0) * 0.5   # -> [0,1]
        image = image.to(torch.float32)
        return image.permute(0, 2, 3, 1).contiguous()  # NHWC float32

    @torch.inference_mode()
    def encode(self, pixels):
        """Encode NHWC pixels in ``[0, 1]`` to float32 latents.

        Only the first three channels are used. The image is replicate-padded
        on the right/bottom to a multiple of ``downscale_ratio`` so any input
        size is accepted.
        """
        rgb = pixels[..., :3].permute(0, 3, 1, 2).contiguous()
        rgb = rgb.to(device=self.device, dtype=self.dtype)
        signed = rgb.clamp(0, 1).mul(2.0).sub(1.0)  # -> [-1,1]
        signed = pad_nchw_to_multiple(signed, self.downscale_ratio)
        latents = self.core.encode(signed)  # packed if taef2
        return latents.to(torch.float32)

    def decode_tiled(self, latents, **kwargs):
        """Same as ``decode``; tiling options in ``kwargs`` are ignored."""
        return self.decode(latents)

    def encode_tiled(self, pixels, **kwargs):
        """Same as ``encode``; tiling options in ``kwargs`` are ignored."""
        return self.encode(pixels)

    def spacial_compression_decode(self):
        """Return the effective spatial downscale ratio."""
        return self.downscale_ratio

    def spacial_compression_encode(self):
        """Return the effective spatial downscale ratio."""
        return self.downscale_ratio

    def temporal_compression_decode(self):
        """Return ``None``: no temporal compression is reported."""
        return None

    def temporal_compression_encode(self):
        """Return ``None``: no temporal compression is reported."""
        return None
# ============================================================
# Nodes
# ============================================================
def _list_vae_files():
    """Return the sorted, de-duplicated file names from the ``vae`` and ``vae_approx`` folders."""
    names = folder_paths.get_filename_list("vae") + folder_paths.get_filename_list("vae_approx")
    unique_names = set(names)
    return sorted(unique_names)
def _resolve_vae_path(fname):
    """Resolve ``fname`` to a full path, trying ``vae_approx`` before ``vae``.

    Returns:
        The first non-``None`` result of ``folder_paths.get_full_path``, or
        ``None`` when neither folder yields a path.
    """
    approx_path = folder_paths.get_full_path("vae_approx", fname)
    if approx_path is not None:
        return approx_path
    return folder_paths.get_full_path("vae", fname)
class LoadTAEF2VAE:
    """Node that loads a TAESD-style checkpoint and outputs a ``TAEF2VAE``."""

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load"
    CATEGORY = "latent/vae"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: a weights file choice and a compute dtype."""
        required = {
            "weights": (_list_vae_files(),),
            "dtype": (["bf16", "fp16", "fp32"], {"default": "bf16"}),
        }
        return {"required": required}

    def load(self, weights, dtype):
        """Build the VAE wrapper for the selected file.

        Args:
            weights: File name as listed by ``_list_vae_files``.
            dtype: ``"bf16"`` or ``"fp16"``; any other value selects float32.

        Returns:
            One-element tuple holding the ``TAEF2VAE`` instance.

        Raises:
            FileNotFoundError: If the file cannot be resolved to a path.
        """
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        device = comfy.model_management.get_torch_device()
        half_precision = {"bf16": torch.bfloat16, "fp16": torch.float16}
        tdtype = half_precision.get(dtype, torch.float32)
        vae = TAEF2VAE(path, device=device, dtype=tdtype)
        return (vae,)
class DumpVAEKeys:
    """Debug node that lists the keys of a checkpoint file as one string."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "dump"
    CATEGORY = "utils/debug"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the inputs: file choice, formatting switches and a line cap."""
        required = {
            "weights": (_list_vae_files(),),
            "include_shapes": ("BOOLEAN", {"default": True}),
            "sort_keys": ("BOOLEAN", {"default": True}),
            "max_lines": ("INT", {"default": 0, "min": 0, "max": 200000}),
        }
        return {"required": required}

    @staticmethod
    def _describe(key, value):
        """Format one entry: key plus shape/dtype for tensors, else key plus type."""
        if isinstance(value, torch.Tensor):
            return f"{key}\t{tuple(value.shape)}\t{str(value.dtype)}"
        return f"{key}\t{type(value)}"

    def dump(self, weights, include_shapes, sort_keys, max_lines):
        """Produce a newline-separated listing of the checkpoint's keys.

        Args:
            weights: File name as listed by ``_list_vae_files``.
            include_shapes: Append tab-separated shape and dtype (or the
                Python type for non-tensors) to each key.
            sort_keys: Sort keys alphabetically instead of file order.
            max_lines: When non-zero, keep only this many lines and append a
                truncation notice.

        Returns:
            Dict with the text under ``ui.text`` and as the single ``result``.

        Raises:
            FileNotFoundError: If the file cannot be resolved to a path.
        """
        path = _resolve_vae_path(weights)
        if path is None:
            raise FileNotFoundError(f"Could not find weights file: {weights}")

        sd = comfy.utils.load_torch_file(path, safe_load=True)
        keys = sorted(sd.keys()) if sort_keys else list(sd.keys())

        if include_shapes:
            lines = [self._describe(k, sd[k]) for k in keys]
        else:
            lines = keys

        if max_lines and len(lines) > max_lines:
            notice = f"... TRUNCATED: total_keys={len(lines)} (showing first {max_lines}) ..."
            lines = [*lines[:max_lines], notice]

        text = "\n".join(lines)
        return {"ui": {"text": [text]}, "result": (text,)}
# Node identifier -> implementing class.
# NOTE(review): presumably read by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = dict(
    LoadTAEF2VAE=LoadTAEF2VAE,
    DumpVAEKeys=DumpVAEKeys,
)

# Node identifier -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoadTAEF2VAE="Load TAEF2 (Flux2 Tiny VAE)",
    DumpVAEKeys="Dump VAE Keys (as String)",
)