import logging
from typing import Dict, Tuple

import torch

from src.Device import Device
from src.Utilities import util

logger = logging.getLogger(__name__)

# Spatial size (in latent pixels) of the zero latent returned when an input
# cannot be interpreted at all.
_FALLBACK_LATENT_SIZE = 64


class LatentFormat:
    """Base class for latent formats.

    Subclasses override the class attributes and, when needed, the
    ``process_in`` / ``process_out`` transforms.
    """

    scale_factor: float = 1.0
    latent_channels: int = 4
    downscale_factor: int = 8

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Scale latent for input (multiply by ``scale_factor``)."""
        return latent * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Scale latent for output (divide by ``scale_factor``)."""
        return latent / self.scale_factor


class SD15(LatentFormat):
    """SD1.5 latent format."""

    latent_channels: int = 4

    def __init__(self, scale_factor: float = 0.18215):
        self.scale_factor = scale_factor
        # One row of RGB preview weights per latent channel.
        self.latent_rgb_factors = [
            [0.3512, 0.2297, 0.3227],
            [0.3250, 0.4974, 0.2350],
            [-0.2829, 0.1762, 0.2721],
            [-0.2120, -0.2616, -0.7177],
        ]
        self.taesd_decoder_name = "taesd_decoder"


class SDXL(LatentFormat):
    """SDXL latent format."""

    latent_channels: int = 4
    scale_factor = 0.13025

    def __init__(self):
        self.latent_rgb_factors = [
            [0.3651, 0.4232, 0.4341],
            [-0.2533, -0.0042, 0.1068],
            [0.1076, 0.1111, -0.0362],
            [-0.3165, -0.2492, -0.2188],
        ]
        self.latent_rgb_factors_bias = [0.1084, -0.0175, -0.0011]
        self.taesd_decoder_name = "taesdxl_decoder"


class SDXL_Playground_2_5(LatentFormat):
    """SDXL Playground 2.5 with per-channel mean/std normalization."""

    latent_channels: int = 4

    def __init__(self):
        self.scale_factor = 0.5
        # Shaped (1, 4, 1, 1) so they broadcast over [B, 4, H, W] latents.
        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
        self.latent_rgb_factors = [
            [0.3920, 0.4054, 0.4549],
            [-0.2634, -0.0196, 0.0653],
            [0.0568, 0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076],
        ]
        self.taesd_decoder_name = "taesdxl_decoder"

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Normalize: ``(latent - mean) * scale_factor / std``."""
        mean = self.latents_mean.to(latent.device, latent.dtype)
        std = self.latents_std.to(latent.device, latent.dtype)
        return (latent - mean) * self.scale_factor / std

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Invert ``process_in``: ``latent * std / scale_factor + mean``."""
        mean = self.latents_mean.to(latent.device, latent.dtype)
        std = self.latents_std.to(latent.device, latent.dtype)
        return latent * std / self.scale_factor + mean


class SD3(LatentFormat):
    """SD3 latent format with shift factor."""

    latent_channels = 16

    def __init__(self):
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
        self.latent_rgb_factors = [
            [-0.0645, 0.0177, 0.1052],
            [0.0028, 0.0312, 0.0650],
            [0.1848, 0.0762, 0.0360],
            [0.0944, 0.0360, 0.0889],
            [0.0897, 0.0506, -0.0364],
            [-0.0020, 0.1203, 0.0284],
            [0.0855, 0.0118, 0.0283],
            [-0.0539, 0.0658, 0.1047],
            [-0.0057, 0.0116, 0.0700],
            [-0.0412, 0.0281, -0.0039],
            [0.1106, 0.1171, 0.1220],
            [-0.0248, 0.0682, -0.0481],
            [0.0815, 0.0846, 0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259],
        ]
        self.taesd_decoder_name = "taesd3_decoder"

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Shift then scale: ``(latent - shift_factor) * scale_factor``."""
        return (latent - self.shift_factor) * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Invert ``process_in``: ``latent / scale_factor + shift_factor``."""
        return (latent / self.scale_factor) + self.shift_factor


class Flux1(SD3):
    """Flux1 latent format (SD3-style shift/scale with its own constants)."""

    latent_channels = 16

    def __init__(self):
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.latent_rgb_factors = [
            [-0.0404, 0.0159, 0.0609],
            [0.0043, 0.0298, 0.0850],
            [0.0328, -0.0749, -0.0503],
            [-0.0245, 0.0085, 0.0549],
            [0.0966, 0.0894, 0.0530],
            [0.0035, 0.0399, 0.0123],
            [0.0583, 0.1184, 0.1262],
            [-0.0191, -0.0206, -0.0306],
            [-0.0324, 0.0055, 0.1001],
            [0.0955, 0.0659, -0.0545],
            [-0.0504, 0.0231, -0.0013],
            [0.0500, -0.0008, -0.0088],
            [0.0982, 0.0941, 0.0976],
            [-0.1233, -0.0280, -0.0897],
            [-0.0005, -0.0530, -0.0020],
            [-0.1273, -0.0932, -0.0680],
        ]
        self.taesd_decoder_name = "taef1_decoder"


class Flux2(LatentFormat):
    """Flux2 (Klein) latent format.

    Following ComfyUI's approach:
    - VAE shape: 32 channels, 8x downscale
    - Transformer shape: 128 channels, 16x downscale

    The pipeline works with VAE shape (32ch 8x). Conversion to Transformer
    shape is handled internally by the model forward pass.
    """

    latent_channels = 32
    downscale_factor = 8
    spacial_downscale_ratio = 8

    def __init__(self):
        # No scale/shift for Flux2 (identity transform)
        self.scale_factor = 1.0
        self.shift_factor = 0.0
        # RGB factors for latent preview: one row per VAE latent channel (32 rows).
        self.latent_rgb_factors = [
            [0.0058, 0.0113, 0.0073],
            [0.0495, 0.0443, 0.0836],
            [-0.0099, 0.0096, 0.0644],
            [0.2144, 0.3009, 0.3652],
            [0.0166, -0.0039, -0.0054],
            [0.0157, 0.0103, -0.0160],
            [-0.0398, 0.0902, -0.0235],
            [-0.0052, 0.0095, 0.0109],
            [-0.3527, -0.2712, -0.1666],
            [-0.0301, -0.0356, -0.0180],
            [-0.0107, 0.0078, 0.0013],
            [0.0746, 0.0090, -0.0941],
            [0.0156, 0.0169, 0.0070],
            [-0.0034, -0.0040, -0.0114],
            [0.0032, 0.0181, 0.0080],
            [-0.0939, -0.0008, 0.0186],
            [0.0018, 0.0043, 0.0104],
            [0.0284, 0.0056, -0.0127],
            [-0.0024, -0.0022, -0.0030],
            [0.1207, -0.0026, 0.0065],
            [0.0128, 0.0101, 0.0142],
            [0.0137, -0.0072, -0.0007],
            [0.0095, 0.0092, -0.0059],
            [0.0000, -0.0077, -0.0049],
            [-0.0465, -0.0204, -0.0312],
            [0.0095, 0.0012, -0.0066],
            [0.0290, -0.0034, 0.0025],
            [0.0220, 0.0169, -0.0048],
            [-0.0332, -0.0457, -0.0468],
            [-0.0085, 0.0389, 0.0609],
            [-0.0076, 0.0003, -0.0043],
            [-0.0111, -0.0460, -0.0614],
        ]
        self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
        self.taesd_decoder_name = None  # Flux2 doesn't use TAESD

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Identity - no scale/shift for Flux2."""
        return latent

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Identity - no scale/shift for Flux2."""
        return latent

    def unpatchify_for_vae(self, latent: torch.Tensor) -> torch.Tensor:
        """Convert patchified latent (128ch 16x) to VAE format (32ch 8x).

        Channel ``c * 4 + py * 2 + px`` at ``(i, j)`` becomes pixel
        ``(2 * i + py, 2 * j + px)`` of channel ``c``; this is the exact
        inverse of ``patchify_from_vae`` (for even spatial sizes).

        Args:
            latent: [B, 128, H/16, W/16] patchified latent

        Returns:
            [B, 32, H/8, W/8] VAE-compatible latent
        """
        # Same as ComfyUI: t.reshape(b, 32, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, 32, h*2, w*2)
        b, _, h, w = latent.shape
        latent = latent.reshape(b, 32, 2, 2, h, w)
        latent = latent.permute(0, 1, 4, 2, 5, 3)  # [B, 32, h, 2, w, 2]
        return latent.reshape(b, 32, h * 2, w * 2)

    def patchify_from_vae(self, latent: torch.Tensor) -> torch.Tensor:
        """Convert VAE format (32ch 8x) to patchified latent (128ch 16x).

        Each 2x2 spatial block is folded into channels, so the spatial
        dimensions must be even. An odd height or width (possible after
        cropping/resizing) is zero-padded on the bottom/right first.

        Args:
            latent: [B, 32, H/8, W/8] VAE-compatible latent

        Returns:
            [B, 128, H/16, W/16] patchified latent (uses padded dims when needed)
        """
        # [B, 32, h*2, w*2] -> [B, 32, h, 2, w, 2] -> [B, 32, 2, 2, h, w] -> [B, 128, h, w]
        b, c, h, w = latent.shape
        assert c == 32, f"Expected 32 channels, got {c}"

        # Pad to even spatial dims so the 2x2 grouping is valid. Per the
        # original author, Flux2.forward crops back to the original size.
        pad_h = h % 2
        pad_w = w % 2
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            latent = torch.nn.functional.pad(latent, (0, pad_w, 0, pad_h), mode='constant', value=0)
            h += pad_h
            w += pad_w

        latent = latent.reshape(b, 32, h // 2, 2, w // 2, 2)
        latent = latent.permute(0, 1, 3, 5, 2, 4)  # [B, 32, 2, 2, h//2, w//2]
        return latent.reshape(b, 128, h // 2, w // 2)


class EmptyLatentImage:
    """Generate empty (all-zero) latent images."""

    def __init__(self):
        self.device = Device.intermediate_device()

    def generate(
        self,
        width: int,
        height: int,
        batch_size: int = 1,
        channels: int = 4,
        downscale_factor: int = 8,
    ) -> Tuple[Dict[str, torch.Tensor]]:
        """Return ``({"samples": zeros},)`` shaped [batch_size, channels, H, W].

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            batch_size: Number of latents in the batch.
            channels: Latent channel count.
            downscale_factor: Pixel-to-latent spatial ratio (floor division);
                defaults to 8, the value used by every format in this module.
        """
        latent = torch.zeros(
            [batch_size, channels, height // downscale_factor, width // downscale_factor],
            device=self.device,
        )
        return ({"samples": latent},)


def _fallback_zero_latent(latent_channels: int) -> torch.Tensor:
    """Return a [1, latent_channels, 64, 64] zero latent on the intermediate device."""
    return torch.zeros(
        (1, latent_channels, _FALLBACK_LATENT_SIZE, _FALLBACK_LATENT_SIZE),
        device=Device.intermediate_device(),
    )


def _to_nchw(latent_image: torch.Tensor, latent_channels: int) -> torch.Tensor:
    """Reshape a tensor to 4-D [B, C, H, W]; ranks other than 2-4 become zeros."""
    ndim = latent_image.ndim
    if ndim == 4:
        return latent_image
    if ndim == 3:
        # Heuristic layout detection: [C, H, W], [H, W, C], else [B, H, W].
        if latent_image.shape[0] == latent_channels:
            return latent_image.unsqueeze(0)
        if latent_image.shape[-1] == latent_channels:
            return latent_image.permute(2, 0, 1).unsqueeze(0)
        return latent_image.unsqueeze(1)
    if ndim == 2:
        # [H, W] -> [1, 1, H, W]
        return latent_image.unsqueeze(0).unsqueeze(0)
    # 0-D, 1-D and rank > 4 cannot be mapped to [B, C, H, W]; the content is
    # discarded, so say so instead of doing it silently.
    logger.warning(
        "fix_empty_latent_channels: unsupported latent rank %d, returning zeros", ndim
    )
    return _fallback_zero_latent(latent_channels)


def fix_empty_latent_channels(model, latent_image):
    """Fix empty latent channels to match model requirements.

    Defensive: handles non-tensor inputs, unexpected dimensionality (channel-last
    vs channel-first), and MagicMock objects returned by broken/mocked VAEs in
    tests. Returns a 4-D tensor [B, C, H, W]. Only an all-zero latent has its
    channel count changed; a latent with content is returned unchanged even if
    its channel count differs from the model's.

    Args:
        model: Object exposing ``get_model_object("latent_format")`` whose result
            has a ``latent_channels`` attribute.
        latent_image: Latent tensor (or tensor-like) to normalize.

    Returns:
        The normalized latent tensor.
    """
    latent_channels = model.get_model_object("latent_format").latent_channels

    # Coerce to tensor when possible, otherwise fall back to a zero tensor with
    # the required channel count. This avoids TypeErrors from
    # torch.count_nonzero when the input is a MagicMock or other exotic type.
    if not isinstance(latent_image, torch.Tensor):
        logger.debug("fix_empty_latent_channels: non-tensor latent_image type=%r repr=%r", type(latent_image), repr(latent_image)[:200])
        try:
            latent_image = torch.as_tensor(latent_image)
        except Exception:
            logger.warning("fix_empty_latent_channels: failed to coerce latent to tensor, returning zeros")
            return _fallback_zero_latent(latent_channels)

    try:
        latent_image = _to_nchw(latent_image, latent_channels)
    except Exception:
        logger.warning("fix_empty_latent_channels: could not normalize latent to 4-D, returning zeros")
        return _fallback_zero_latent(latent_channels)

    curr_channels = int(latent_image.shape[1])
    if curr_channels == latent_channels:
        return latent_image

    try:
        # bool() inside the try so tensors whose comparison cannot be turned
        # into a Python bool also take the fallback path.
        is_zero = bool(torch.count_nonzero(latent_image) == 0)
    except Exception:
        # Keep the original behaviour: treat an uncheckable latent as empty.
        logger.debug("fix_empty_latent_channels: count_nonzero failed, assuming empty latent")
        is_zero = True

    if not is_zero:
        return latent_image

    # Empty latent with the wrong channel count: expand or recreate it.
    if curr_channels == 1:
        return util.repeat_to_batch_size(latent_image, latent_channels, dim=1)

    try:
        batch, _, height, width = (int(d) for d in latent_image.shape)
        # Keep floating dtypes (as the repeat branch above does); otherwise use
        # torch's default dtype.
        dtype = latent_image.dtype if latent_image.is_floating_point() else None
        return torch.zeros(
            (batch, latent_channels, height, width),
            device=latent_image.device,
            dtype=dtype,
        )
    except Exception:
        logger.warning("fix_empty_latent_channels: failed to recreate latent, returning default zeros")
        return _fallback_zero_latent(latent_channels)