Spaces:
Running on Zero
Running on Zero
| from typing import Dict, Tuple | |
| import torch | |
| from src.Device import Device | |
| from src.Utilities import util | |
class LatentFormat:
    """Common base for latent formats.

    A format converts latents into the scale the diffusion model works in
    (``process_in``) and back again (``process_out``). The base behaviour is
    a plain multiplication/division by ``scale_factor``; subclasses override
    the class-level defaults below and/or the two conversion methods.
    """

    # Multiplier applied by process_in and divided back out by process_out.
    scale_factor: float = 1.0
    # Number of channels a latent in this format carries.
    latent_channels: int = 4
    # Pixel-to-latent spatial reduction factor.
    downscale_factor: int = 8

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Return ``latent`` multiplied by ``scale_factor``."""
        factor = self.scale_factor
        return latent * factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Return ``latent`` divided by ``scale_factor`` (inverse of process_in)."""
        factor = self.scale_factor
        return latent / factor
class SD15(LatentFormat):
    """Stable Diffusion 1.5 latent format (4 channels, configurable scale)."""

    latent_channels: int = 4

    def __init__(self, scale_factor: float = 0.18215):
        self.taesd_decoder_name = "taesd_decoder"
        # One RGB contribution row per latent channel, used for previews.
        self.latent_rgb_factors = [
            [0.3512, 0.2297, 0.3227],
            [0.3250, 0.4974, 0.2350],
            [-0.2829, 0.1762, 0.2721],
            [-0.2120, -0.2616, -0.7177],
        ]
        # Per-instance scale that shadows the LatentFormat class default.
        self.scale_factor = scale_factor
class SDXL(LatentFormat):
    """Stable Diffusion XL latent format (4 channels, fixed scale 0.13025)."""

    latent_channels: int = 4
    scale_factor = 0.13025

    def __init__(self):
        self.taesd_decoder_name = "taesdxl_decoder"
        # Preview projection: one RGB row per latent channel plus an RGB bias.
        self.latent_rgb_factors_bias = [0.1084, -0.0175, -0.0011]
        self.latent_rgb_factors = [
            [0.3651, 0.4232, 0.4341],
            [-0.2533, -0.0042, 0.1068],
            [0.1076, 0.1111, -0.0362],
            [-0.3165, -0.2492, -0.2188],
        ]
class SDXL_Playground_2_5(LatentFormat):
    """Playground 2.5 (SDXL-based) latent format with per-channel mean/std.

    ``process_in(x) = (x - mean) * scale_factor / std`` and ``process_out``
    is its inverse, ``x * std / scale_factor + mean``.
    """

    latent_channels: int = 4

    def __init__(self):
        self.scale_factor = 0.5
        # Per-channel statistics viewed as (1, 4, 1, 1) so they broadcast
        # against 4-D latents.
        self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
        self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
        self.latent_rgb_factors = [
            [0.3920, 0.4054, 0.4549],
            [-0.2634, -0.0196, 0.0653],
            [0.0568, 0.1687, -0.0755],
            [-0.3112, -0.2359, -0.2076],
        ]
        self.taesd_decoder_name = "taesdxl_decoder"

    def _stats_like(self, latent: torch.Tensor):
        """Return (mean, std) moved to ``latent``'s device and dtype."""
        return (
            self.latents_mean.to(latent.device, latent.dtype),
            self.latents_std.to(latent.device, latent.dtype),
        )

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Normalise by the channel statistics, then scale."""
        mu, sigma = self._stats_like(latent)
        return (latent - mu) * self.scale_factor / sigma

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Undo process_in: rescale, then restore the channel statistics."""
        mu, sigma = self._stats_like(latent)
        return latent * sigma / self.scale_factor + mu
class SD3(LatentFormat):
    """Stable Diffusion 3 latent format: 16 channels with shift and scale.

    ``process_in(x) = (x - shift_factor) * scale_factor``; ``process_out``
    applies the inverse, ``x / scale_factor + shift_factor``.
    """

    latent_channels = 16

    def __init__(self):
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
        # One RGB contribution row per latent channel (16 rows), for previews.
        self.latent_rgb_factors = [
            [-0.0645, 0.0177, 0.1052],
            [0.0028, 0.0312, 0.0650],
            [0.1848, 0.0762, 0.0360],
            [0.0944, 0.0360, 0.0889],
            [0.0897, 0.0506, -0.0364],
            [-0.0020, 0.1203, 0.0284],
            [0.0855, 0.0118, 0.0283],
            [-0.0539, 0.0658, 0.1047],
            [-0.0057, 0.0116, 0.0700],
            [-0.0412, 0.0281, -0.0039],
            [0.1106, 0.1171, 0.1220],
            [-0.0248, 0.0682, -0.0481],
            [0.0815, 0.0846, 0.1207],
            [-0.0120, -0.0055, -0.0867],
            [-0.0749, -0.0634, -0.0456],
            [-0.1418, -0.1457, -0.1259],
        ]
        self.taesd_decoder_name = "taesd3_decoder"

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Subtract the shift, then multiply by the scale."""
        centred = latent - self.shift_factor
        return centred * self.scale_factor

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Divide by the scale, then add the shift back."""
        unscaled = latent / self.scale_factor
        return unscaled + self.shift_factor
class Flux1(SD3):
    """Flux.1 latent format.

    Reuses SD3's shift-then-scale ``process_in``/``process_out`` with
    Flux-specific constants (16 channels).
    """

    latent_channels = 16

    def __init__(self):
        self.taesd_decoder_name = "taef1_decoder"
        self.shift_factor = 0.1159
        self.scale_factor = 0.3611
        # One RGB contribution row per latent channel (16 rows), for previews.
        self.latent_rgb_factors = [
            [-0.0404, 0.0159, 0.0609],
            [0.0043, 0.0298, 0.0850],
            [0.0328, -0.0749, -0.0503],
            [-0.0245, 0.0085, 0.0549],
            [0.0966, 0.0894, 0.0530],
            [0.0035, 0.0399, 0.0123],
            [0.0583, 0.1184, 0.1262],
            [-0.0191, -0.0206, -0.0306],
            [-0.0324, 0.0055, 0.1001],
            [0.0955, 0.0659, -0.0545],
            [-0.0504, 0.0231, -0.0013],
            [0.0500, -0.0008, -0.0088],
            [0.0982, 0.0941, 0.0976],
            [-0.1233, -0.0280, -0.0897],
            [-0.0005, -0.0530, -0.0020],
            [-0.1273, -0.0932, -0.0680],
        ]
class Flux2(LatentFormat):
    """Flux2 (Klein) latent format.

    Two layouts exist, following ComfyUI's approach:

    - VAE layout: 32 channels at 8x spatial downscale.
    - Transformer layout: 128 channels at 16x spatial downscale.

    The pipeline works with the VAE layout (32ch, 8x). Conversion to the
    Transformer layout is handled internally by the model forward pass;
    ``patchify_from_vae`` and ``unpatchify_for_vae`` convert between the two.
    """

    latent_channels = 32
    downscale_factor = 8
    # Attribute name (including its spelling) is kept as-is for existing readers.
    spacial_downscale_ratio = 8

    def __init__(self):
        # Identity conversion: no scale or shift is applied for Flux2.
        self.shift_factor = 0.0
        self.scale_factor = 1.0
        # Flux2 does not use a TAESD decoder.
        self.taesd_decoder_name = None
        # Preview projection: one RGB row per VAE-layout channel (32 rows)
        # plus an RGB bias.
        self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
        self.latent_rgb_factors = [
            [0.0058, 0.0113, 0.0073], [0.0495, 0.0443, 0.0836],
            [-0.0099, 0.0096, 0.0644], [0.2144, 0.3009, 0.3652],
            [0.0166, -0.0039, -0.0054], [0.0157, 0.0103, -0.0160],
            [-0.0398, 0.0902, -0.0235], [-0.0052, 0.0095, 0.0109],
            [-0.3527, -0.2712, -0.1666], [-0.0301, -0.0356, -0.0180],
            [-0.0107, 0.0078, 0.0013], [0.0746, 0.0090, -0.0941],
            [0.0156, 0.0169, 0.0070], [-0.0034, -0.0040, -0.0114],
            [0.0032, 0.0181, 0.0080], [-0.0939, -0.0008, 0.0186],
            [0.0018, 0.0043, 0.0104], [0.0284, 0.0056, -0.0127],
            [-0.0024, -0.0022, -0.0030], [0.1207, -0.0026, 0.0065],
            [0.0128, 0.0101, 0.0142], [0.0137, -0.0072, -0.0007],
            [0.0095, 0.0092, -0.0059], [0.0000, -0.0077, -0.0049],
            [-0.0465, -0.0204, -0.0312], [0.0095, 0.0012, -0.0066],
            [0.0290, -0.0034, 0.0025], [0.0220, 0.0169, -0.0048],
            [-0.0332, -0.0457, -0.0468], [-0.0085, 0.0389, 0.0609],
            [-0.0076, 0.0003, -0.0043], [-0.0111, -0.0460, -0.0614],
        ]

    def process_in(self, latent: torch.Tensor) -> torch.Tensor:
        """Identity - no scale/shift for Flux2."""
        return latent

    def process_out(self, latent: torch.Tensor) -> torch.Tensor:
        """Identity - no scale/shift for Flux2."""
        return latent

    def unpatchify_for_vae(self, latent: torch.Tensor) -> torch.Tensor:
        """Convert a patchified latent (128ch, 16x) to VAE layout (32ch, 8x).

        Equivalent to ComfyUI's
        ``t.reshape(b, 32, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, 32, h*2, w*2)``.

        Args:
            latent: ``[B, 128, H/16, W/16]`` patchified latent.

        Returns:
            ``[B, 32, H/8, W/8]`` VAE-compatible latent.
        """
        n, _, rows, cols = latent.shape
        grouped = latent.reshape(n, 32, 2, 2, rows, cols)
        # (n, 32, dy, dx, rows, cols) -> (n, 32, rows, dy, cols, dx): interleave
        # each channel group's 2x2 sub-pixels back into the spatial grid.
        interleaved = grouped.permute(0, 1, 4, 2, 5, 3)
        return interleaved.reshape(n, 32, rows * 2, cols * 2)

    def patchify_from_vae(self, latent: torch.Tensor) -> torch.Tensor:
        """Convert a VAE-layout latent (32ch, 8x) to patchified layout (128ch, 16x).

        Each 2x2 spatial block is folded into the channel dimension, so the
        spatial size must be even. An odd height or width (possible after
        cropping/resizing) is zero-padded on the bottom/right first.

        Args:
            latent: ``[B, 32, H/8, W/8]`` VAE-compatible latent.

        Returns:
            ``[B, 128, H/16, W/16]`` patchified latent (using the padded size
            when padding was applied).
        """
        n, channels, rows, cols = latent.shape
        assert channels == 32, f"Expected 32 channels, got {channels}"
        # Zero-pad to even spatial dims. The padding is removed later by
        # Flux2.forward, which crops back to the original spatial size.
        extra_rows = rows % 2
        extra_cols = cols % 2
        if extra_rows or extra_cols:
            # F.pad order for the last two dims: (left, right, top, bottom).
            latent = torch.nn.functional.pad(latent, (0, extra_cols, 0, extra_rows), mode='constant', value=0)
            rows += extra_rows
            cols += extra_cols
        half_rows, half_cols = rows // 2, cols // 2
        blocks = latent.reshape(n, 32, half_rows, 2, half_cols, 2)
        # (n, 32, rows/2, dy, cols/2, dx) -> (n, 32, dy, dx, rows/2, cols/2)
        blocks = blocks.permute(0, 1, 3, 5, 2, 4)
        return blocks.reshape(n, 128, half_rows, half_cols)
class EmptyLatentImage:
    """Factory for all-zero ("empty") latent images."""

    def __init__(self):
        # Latents are allocated on the project's intermediate device.
        self.device = Device.intermediate_device()

    def generate(
        self,
        width: int,
        height: int,
        batch_size: int = 1,
        channels: int = 4,
        downscale_factor: int = 8,
    ) -> Tuple[Dict[str, torch.Tensor]]:
        """Create a zero latent for an image of ``width`` x ``height`` pixels.

        Args:
            width: Image width in pixels.
            height: Image height in pixels.
            batch_size: Number of latents in the batch.
            channels: Latent channel count (4 by default).
            downscale_factor: Pixel-to-latent spatial reduction. Defaults to 8,
                the value that used to be hard-coded here, so existing callers
                are unaffected. Must be positive.

        Returns:
            A 1-tuple ``({"samples": latent},)``. ``latent`` is a zero tensor of
            shape ``[batch_size, channels, height // downscale_factor,
            width // downscale_factor]`` on ``self.device``.

        Raises:
            ValueError: If ``downscale_factor`` is not positive.
        """
        if downscale_factor <= 0:
            raise ValueError(f"downscale_factor must be positive, got {downscale_factor}")
        latent = torch.zeros(
            [batch_size, channels, height // downscale_factor, width // downscale_factor],
            device=self.device,
        )
        return ({"samples": latent},)
def _default_empty_latent(latent_channels: int) -> torch.Tensor:
    """Return the fallback zero latent of shape [1, latent_channels, 64, 64] on the intermediate device."""
    return torch.zeros((1, latent_channels, 64, 64), device=Device.intermediate_device())


def _as_nchw(latent_image: torch.Tensor, latent_channels: int) -> torch.Tensor:
    """Bring a tensor to 4-D [B, C, H, W] using simple layout heuristics.

    - 4-D: returned unchanged (treated as channel-first).
    - 3-D: ``[C, H, W]`` if dim 0 equals ``latent_channels``; else ``[H, W, C]``
      if the last dim does; else ``[B, H, W]`` (a singleton channel dim is
      inserted at position 1).
    - 2-D: ``[H, W]`` becomes ``[1, 1, H, W]``.
    - Any other rank (0-D, 1-D, 5-D and above) is replaced by the default
      zero latent.
    """
    ndim = latent_image.ndim
    if ndim == 4:
        return latent_image
    if ndim == 3:
        if latent_image.shape[0] == latent_channels:
            return latent_image.unsqueeze(0)
        if latent_image.shape[-1] == latent_channels:
            return latent_image.permute(2, 0, 1).unsqueeze(0)
        return latent_image.unsqueeze(1)
    if ndim == 2:
        return latent_image.unsqueeze(0).unsqueeze(0)
    # NOTE(review): ranks >= 5 are discarded here even when non-empty; kept
    # for compatibility with the 4-D output guarantee.
    return _default_empty_latent(latent_channels)


def fix_empty_latent_channels(model, latent_image):
    """Make an *empty* latent match the model's expected channel count.

    The expected count is
    ``model.get_model_object("latent_format").latent_channels``.

    Defensive: non-tensor inputs (including MagicMock objects returned by
    broken/mocked VAEs in tests) are coerced with ``torch.as_tensor``. If that
    fails, a ``[1, C, 64, 64]`` zero latent is returned. The tensor is then
    normalized to 4-D ``[B, C, H, W]``. 4-D inputs are always treated as
    channel-first, because a channel-last 4-D tensor cannot be told apart
    reliably.

    If the channel count then differs from the expected one *and* the latent
    is all zeros, it is rebuilt:

    - a single-channel latent is expanded along dim 1 with
      ``util.repeat_to_batch_size``;
    - otherwise a zero tensor with the same batch/spatial size and device is
      created. A floating-point dtype is preserved (fp16/bf16 stay as they
      are); other dtypes fall back to torch's default dtype.

    Non-empty latents, and latents that already have the right channel count,
    are returned unchanged (after 4-D normalization).

    Args:
        model: Object exposing ``get_model_object("latent_format")``.
        latent_image: The latent tensor, or a tensor-like value.

    Returns:
        A 4-D tensor.
    """
    import logging

    logger = logging.getLogger(__name__)
    latent_channels = model.get_model_object("latent_format").latent_channels

    # Coerce exotic inputs up front so the tensor ops below cannot raise
    # TypeError (e.g. torch.count_nonzero on a MagicMock).
    if not isinstance(latent_image, torch.Tensor):
        logger.debug(
            "fix_empty_latent_channels: non-tensor latent_image type=%r repr=%r",
            type(latent_image),
            repr(latent_image)[:200],
        )
        try:
            latent_image = torch.as_tensor(latent_image)
        except Exception:
            logger.warning("fix_empty_latent_channels: failed to coerce latent to tensor, returning zeros")
            return _default_empty_latent(latent_channels)

    try:
        latent_image = _as_nchw(latent_image, latent_channels)
    except Exception:
        logger.debug("fix_empty_latent_channels: could not normalize latent to 4-D", exc_info=True)
        return _default_empty_latent(latent_channels)

    # _as_nchw always yields a 4-D tensor, so dim 1 is the channel dim.
    curr_channels = int(latent_image.shape[1])
    if curr_channels == latent_channels:
        return latent_image

    try:
        is_zero = bool(torch.count_nonzero(latent_image) == 0)
    except Exception:
        # Conservatively treat unreadable content as empty.
        is_zero = True
    if not is_zero:
        # Never touch a latent that carries real content.
        return latent_image

    if curr_channels == 1:
        return util.repeat_to_batch_size(latent_image, latent_channels, dim=1)

    try:
        batch, _, height, width = latent_image.shape
        dtype = latent_image.dtype if latent_image.is_floating_point() else None
        return torch.zeros(
            (batch, latent_channels, height, width),
            dtype=dtype,
            device=latent_image.device,
        )
    except Exception:
        logger.debug("fix_empty_latent_channels: failed to rebuild latent, using default", exc_info=True)
        return _default_empty_latent(latent_channels)