"""Flux2 Klein model adapter for LightDiffusion-Next. Provides a clean interface to the Flux2 Klein 4B model that inherits from AbstractModel and integrates with the LightDiffusion-Next model factory. This implementation uses ONLY native LightDiffusion-Next components, without any ComfyUI imports. File structure expected: - include/diffusion_model/flux-2-klein-4b.safetensors (or similar) - include/text_encoder/qwen_3_4b.safetensors - include/vae/ae.safetensors (Flux VAE) """ import logging import os from typing import TYPE_CHECKING, Any, Callable, Optional import torch from src.Core.AbstractModel import AbstractModel, ModelCapabilities from src.Utilities import util from src.Device import Device # Import modules that were previously lazy-loaded inside methods # This avoids KeyError: 'src' when running via uv run streamlit from src.NeuralNetwork.flux2.model import Flux2, Flux2Params from src.Model.ModelPatcher import ModelPatcher from src.clip.KleinEncoder import KleinCLIP, Qwen3_4BModel from src.AutoEncoders import VariationalAE from src.sample import sampling from src.Utilities import Latent from src.Model import LoRas if TYPE_CHECKING: from src.Core.Context import Context logger = logging.getLogger(__name__) # Default paths for Flux2 Klein components DEFAULT_DIFFUSION_MODEL_DIR = "./include/diffusion_model" DEFAULT_TEXT_ENCODER_DIR = "./include/text_encoder" DEFAULT_VAE_DIR = "./include/vae" class Flux2KleinModel(AbstractModel): """Flux2 Klein 4B model implementation. Wraps the Flux2 Klein model with the clean AbstractModel interface for use with the LightDiffusion-Next pipeline system. The Flux2 Klein model is a distilled version of the Flux2 architecture using the Klein (Qwen3 4B) text encoder. Unlike SD1.5/SDXL which use combined checkpoints, Flux2 Klein loads components separately: - Diffusion model from include/diffusion_model/ - Text encoder (Qwen3 4B) from include/text_encoder/ - VAE from include/vae/ """ def __init__( self, model_path: str = None, text_encoder_path: str = None, vae_path: str = None, quantization: str = None, # "fp8", "nvfp4", or None ): """Initialize the Flux2 Klein model adapter. 
    def __init__(
        self,
        model_path: str = None,
        text_encoder_path: str = None,
        vae_path: str = None,
        quantization: str = None,  # "fp8", "nvfp4", or None
    ):
        """Initialize the Flux2 Klein model adapter.

        Args:
            model_path: Path to diffusion model (safetensors)
            text_encoder_path: Path to Qwen3 text encoder (optional, auto-detected)
            vae_path: Path to VAE (optional, auto-detected)
            quantization: Quantization format to use ("fp8", "nvfp4", or None)
        """
        super().__init__(model_path)
        self._text_encoder = None
        self._tokenizer = None
        self._model_config = None
        self._text_encoder_path = text_encoder_path
        self._vae_path = vae_path
        self._raw_model = None  # The raw Flux2 nn.Module
        self.quantization = quantization

        # Device management
        self.load_device = Device.get_torch_device()
        self.offload_device = torch.device("cpu")

    def _create_capabilities(self) -> ModelCapabilities:
        """Create capabilities for Flux2 Klein model."""
        return ModelCapabilities(
            min_resolution=256,
            max_resolution=4096,
            preferred_resolution=1024,
            requires_resolution_multiple=16,  # Flux2 uses 16-pixel patches
            supports_hires_fix=True,
            supports_img2img=True,
            supports_inpainting=False,  # Not yet implemented for Flux2
            supports_controlnet=False,  # ControlNet support pending
            supports_stable_fast=False,  # May need special handling
            supports_deepcache=False,  # Architecture differs from UNet
            supports_tome=False,  # Token merging needs special implementation
            supports_lora=False,  # Flux2 LoRA format differs from SD
            uses_dual_clip=False,  # Uses single Klein (Qwen3) encoder
            requires_size_conditioning=False,
            is_flux=True,
            is_flux2=True,
        )

    def _find_diffusion_model(self) -> Optional[str]:
        """Auto-detect Flux2 diffusion model in default directory."""
        if os.path.exists(DEFAULT_DIFFUSION_MODEL_DIR):
            for f in os.listdir(DEFAULT_DIFFUSION_MODEL_DIR):
                f_lower = f.lower()
                if ("flux" in f_lower or "klein" in f_lower) and f.endswith((".safetensors", ".pt", ".pth")):
                    return os.path.join(DEFAULT_DIFFUSION_MODEL_DIR, f)
        return None

    def _find_text_encoder(self) -> Optional[str]:
        """Auto-detect Qwen3 text encoder in default directory."""
        if os.path.exists(DEFAULT_TEXT_ENCODER_DIR):
            for f in os.listdir(DEFAULT_TEXT_ENCODER_DIR):
                f_lower = f.lower()
                if ("qwen" in f_lower or "klein" in f_lower) and f.endswith((".safetensors", ".pt", ".pth")):
                    return os.path.join(DEFAULT_TEXT_ENCODER_DIR, f)
        return None

    def _find_vae(self) -> Optional[str]:
        """Auto-detect VAE in default directory."""
        if os.path.exists(DEFAULT_VAE_DIR):
            # Look for Flux-compatible VAE (ae.safetensors)
            for f in os.listdir(DEFAULT_VAE_DIR):
                if f.endswith((".safetensors", ".pt", ".pth")):
                    return os.path.join(DEFAULT_VAE_DIR, f)
        return None

    def load(self, model_path: str = None) -> "Flux2KleinModel":
        """Load the Flux2 Klein model components from disk.

        Components are loaded separately:
        - Diffusion model (Flux2 transformer)
        - Text encoder (Qwen3 4B via Klein tokenizer)
        - VAE

        Args:
            model_path: Optional override for the diffusion model path

        Returns:
            Self for method chaining
        """
        # Resolve paths
        diffusion_path = model_path or self.model_path or self._find_diffusion_model()

        # Guard: Don't reload if already loaded with same diffusion model
        if self._loaded and self.model_path == diffusion_path:
            logger.info("Flux2KleinModel: Already loaded, skipping redundant load")
            return self

        if diffusion_path is None:
            raise ValueError(
                "No Flux2 diffusion model found. Please place the model in "
                f"{DEFAULT_DIFFUSION_MODEL_DIR}/ with 'flux' or 'klein' in the filename."
            )
        self.model_path = diffusion_path

        # Resolve other paths only when loading is actually needed
        text_encoder_path = self._text_encoder_path or self._find_text_encoder()
        vae_path = self._vae_path or self._find_vae()

        logger.info("Flux2KleinModel: Loading components...")
        logger.info(f" Diffusion model: {diffusion_path}")
        logger.info(f" Text encoder: {text_encoder_path}")
        logger.info(f" VAE: {vae_path}")

        try:
            # Load diffusion model
            # self.model = self._load_diffusion_model(diffusion_path)  # Original line
            # New FP8 loading logic
            from src.NeuralNetwork.flux2.model import create_flux2_klein
            from src.Device import Device
            from src.FileManaging import Loader

            # Check for FP8 support and user preference/environment
            use_fp8 = Device.is_fp8_supported(self.load_device)
            # For 8GB cards, we force FP8 for Flux2 Klein 4B to avoid swapping
            total_vram = Device.get_total_memory(self.load_device) / (1024**3)
            if total_vram < 12.0:
                # If less than 12GB, FP8 is highly recommended for Flux
                use_fp8 = use_fp8 and True

            dtype = torch.bfloat16  # Base weight dtype

            # Create model with detected config
            config = self._detect_flux2_config(
                util.load_torch_file(diffusion_path, device=torch.device("cpu"))
            )  # Load temporarily to detect config
            params = Flux2Params(**config)
            self.model = Flux2(params=params, dtype=dtype, device=torch.device("cpu"))  # Create on CPU first
            self.model.eval()

            # Attach config for compatibility
            self._model_config = self._create_model_config()  # Ensure _model_config is set

            # Load weights
            sd = util.load_torch_file(diffusion_path, device=self.offload_device)

            # Sanitize NaN values in weights (some Flux2 checkpoints have NaN biases)
            nan_keys = []
            for key, value in sd.items():
                if isinstance(value, torch.Tensor) and torch.isnan(value).any():
                    nan_keys.append(key)
                    sd[key] = torch.where(torch.isnan(value), torch.zeros_like(value), value)
            if nan_keys:
                logger.warning(f"Sanitized NaN values in {len(nan_keys)} keys: {nan_keys[:5]}...")

            self.model.load_state_dict(sd, strict=False)
            del sd
            self._raw_model = self.model  # Store raw model

            # Create ModelPatcher
            self.model = ModelPatcher(self.model, self.load_device, self.offload_device)

            # Apply quantization if requested or needed
            quant_format = self.quantization
            if quant_format is None and use_fp8:
                quant_format = "fp8"

            if quant_format == "nvfp4":
                logger.info("Flux2: Applying NVFP4 (4-bit) weight-only quantization")
                self.model.weight_only_quantize("nvfp4")
                self.model.model_dtype = lambda: torch.float16  # Compute in FP16 for dequantization
            elif quant_format == "fp8":
                logger.info("Flux2: Applying FP8 weight-only quantization")
                self.model.weight_only_quantize(torch.float8_e4m3fn)
                self.model.model_dtype = lambda: torch.float8_e4m3fn  # Override

            # Load text encoder
            if text_encoder_path:
                self.clip = self._load_klein_text_encoder(text_encoder_path, quantize=quant_format)
                self._text_encoder = self.clip  # For internal reference
                self._tokenizer = self.clip.tokenizer
            else:
                logger.warning("No Qwen3 text encoder found - prompt encoding may fail")
                self.clip = None

            # Load VAE
            if vae_path:
                self.vae = self._load_vae(vae_path)
            else:
                logger.warning("No VAE found - image decoding may fail")
                self.vae = None

            # Store config for sampling
            self._model_config = self._create_model_config()

            # Attach model_sampling for sampler infrastructure
            from src.sample import sampling
            self.model.model_sampling = sampling.model_sampling(self._model_config, "flux2", flux=True, flux2=True)

            self._loaded = True
            logger.info("Flux2KleinModel: Successfully loaded all components")

        except Exception as e:
            logger.exception(f"Flux2KleinModel: Failed to load: {e}")
            raise

        return self
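    # A note on the weight-only quantization used in load() above: conceptually, each
    # Linear weight is stored in a low-precision format (e.g. torch.float8_e4m3fn) and
    # cast back up at matmul time. The internals live in ModelPatcher.weight_only_quantize;
    # the following is only an illustrative sketch of the idea, not that method's
    # implementation:
    #
    #     w8 = linear.weight.data.to(torch.float8_e4m3fn)               # ~2x smaller than bf16
    #     linear.weight.data = w8                                        # stored quantized
    #     y = torch.nn.functional.linear(x, linear.weight.to(x.dtype))  # dequantize to compute
    #
    # NVFP4 ("nvfp4") follows the same weight-only pattern at 4 bits per weight.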
logger.exception(f"Flux2KleinModel: Failed to load: {e}") raise return self def _load_diffusion_model(self, path: str): """Load the Flux2 diffusion model using native LightDiffusion-Next. Args: path: Path to diffusion model safetensors Returns: ModelPatcher wrapping the Flux2 model """ logger.info(f"Loading Flux2 diffusion model: {path}") # Load state dict using native utility sd = util.load_torch_file(path) # Sanitize NaN values in weights (some Flux2 checkpoints have NaN biases) nan_keys = [] for key, value in sd.items(): if isinstance(value, torch.Tensor) and torch.isnan(value).any(): nan_keys.append(key) sd[key] = torch.where(torch.isnan(value), torch.zeros_like(value), value) if nan_keys: logger.warning(f"Sanitized NaN values in {len(nan_keys)} keys: {nan_keys[:5]}...") # Detect model configuration from state dict config = self._detect_flux2_config(sd) # Determine dtype and device load_device = Device.get_torch_device() offload_device = Device.unet_offload_device() # Infer dtype from weights dtype = torch.bfloat16 for k, v in sd.items(): if isinstance(v, torch.Tensor) and v.dtype in (torch.float16, torch.bfloat16, torch.float32): dtype = v.dtype break logger.info(f"Flux2 model dtype: {dtype}") # Create model with detected config params = Flux2Params(**config) model = Flux2(params=params, dtype=dtype, device="cpu") # Attach config for compatibility model.model_config = self._create_model_config() # Load weights missing, unexpected = model.load_state_dict(sd, strict=False) if missing: logger.debug(f"Missing keys: {len(missing)}") if unexpected: logger.debug(f"Unexpected keys: {len(unexpected)}") self._raw_model = model # Wrap in ModelPatcher for compatibility with sampling infrastructure model_patcher = ModelPatcher.ModelPatcher( model, load_device=load_device, offload_device=offload_device, current_device=torch.device("cpu"), ) return model_patcher def _detect_flux2_config(self, sd: dict) -> dict: """Detect Flux2 model configuration from state dict. Args: sd: Model state dictionary Returns: Configuration dict for Flux2Params """ # Detect if this is Flux2 (has double_stream_modulation) or Flux1 is_flux2 = any("double_stream_modulation" in k for k in sd.keys()) if is_flux2: # Flux2 / Klein defaults (patch_size=1 unlike Flux1!) 
            config = {
                "patch_size": 1,  # CRITICAL: Flux2 uses patch_size=1 (no spatial patchification)
                "in_channels": 128,  # Direct channel input (no patch_size division)
                "out_channels": 128,  # Direct channel output
                "vec_in_dim": 768,
                "context_in_dim": 7680,  # Klein uses concatenated multi-layer output
                "hidden_size": 3072,
                "mlp_ratio": 3.0,  # Klein uses 3.0 with gated MLP
                "num_heads": 24,  # Flux2: hidden_size/sum(axes_dim) = 3072/128 = 24
                "depth": 19,
                "depth_single_blocks": 38,
                "axes_dim": [32, 32, 32, 32],  # Flux2 specific - sum=128
                "theta": 2000,  # Flux2 uses lower theta
                "qkv_bias": False,
                "guidance_embed": False,
                "gated_mlp": True,  # Klein uses gated MLP (SwiGLU)
                "global_modulation": True,  # Flux2 feature
                "mlp_silu_act": True,  # Flux2 feature
                "ops_bias": False,  # Flux2 feature
                "use_vector_in": False,  # Flux2/Klein doesn't use pooled conditioning
            }
            logger.info("Detected Flux2 model (has double_stream_modulation)")
        else:
            # Flux1 defaults
            config = {
                "in_channels": 16,
                "out_channels": 16,
                "vec_in_dim": 768,
                "context_in_dim": 7680,
                "hidden_size": 3072,
                "mlp_ratio": 4.0,
                "num_heads": 24,
                "depth": 19,
                "depth_single_blocks": 38,
                "axes_dim": [16, 56, 56],  # Flux1 specific
                "theta": 10000,
                "qkv_bias": True,
                "guidance_embed": True,
                "gated_mlp": False,
            }
            logger.info("Detected Flux1 model")

        # Detect depth from double_blocks
        double_blocks = [k for k in sd.keys() if "double_blocks" in k]
        if double_blocks:
            max_block = max(
                int(k.split("double_blocks.")[1].split(".")[0])
                for k in double_blocks
                if "double_blocks." in k
            )
            config["depth"] = max_block + 1

        # Detect single blocks depth
        single_blocks = [k for k in sd.keys() if "single_blocks" in k]
        if single_blocks:
            max_single = max(
                int(k.split("single_blocks.")[1].split(".")[0])
                for k in single_blocks
                if "single_blocks." in k
            )
            config["depth_single_blocks"] = max_single + 1

        # Detect hidden size and in_channels from img_in
        if "img_in.weight" in sd:
            config["hidden_size"] = sd["img_in.weight"].shape[0]
            # img_in input dim = in_channels * patch_size^2
            # For Flux2 with patch_size=1: in_channels = img_in_dim directly
            img_in_dim = sd["img_in.weight"].shape[1]
            patch_size = config.get("patch_size", 2)
            config["in_channels"] = img_in_dim // (patch_size ** 2)
            logger.info(f"Detected in_channels={config['in_channels']} from img_in (patch_size={patch_size})")

        # Detect out_channels from final_layer
        if "final_layer.linear.weight" in sd:
            # final_layer.linear maps hidden -> patch_size * patch_size * out_channels
            # For Flux2 with patch_size=1: out_channels = final.shape[0] directly
            final_out = sd["final_layer.linear.weight"].shape[0]
            patch_size = config.get("patch_size", 2)
            config["out_channels"] = final_out // (patch_size ** 2)
            logger.info(f"Detected out_channels={config['out_channels']} from final_layer")

        # Detect mlp_ratio and gated_mlp from double_blocks MLP weights
        # For gated MLP: img_mlp.0 maps hidden -> 2*intermediate (gate+up)
        #                img_mlp.2 maps intermediate -> hidden
        # So: mlp_0_out = 2 * intermediate, intermediate = mlp_2_in
        #     mlp_ratio = intermediate / hidden
        if "double_blocks.0.img_mlp.0.weight" in sd and "double_blocks.0.img_mlp.2.weight" in sd:
            mlp_0_out = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
            mlp_2_in = sd["double_blocks.0.img_mlp.2.weight"].shape[1]
            hidden = config["hidden_size"]
            # Check if it's gated MLP: mlp_0_out should be 2 * mlp_2_in
            if abs(mlp_0_out - 2 * mlp_2_in) < 10:  # Small tolerance
                # Gated MLP detected
                config["gated_mlp"] = True
                intermediate = mlp_2_in
                config["mlp_ratio"] = intermediate / hidden
                logger.info(f"Detected gated MLP: intermediate={intermediate}, mlp_ratio={config['mlp_ratio']}")
            else:
                # Standard MLP: mlp_0_out = mlp_2_in = hidden * mlp_ratio
                config["gated_mlp"] = False
                config["mlp_ratio"] = mlp_0_out / hidden

        # Calculate num_heads from hidden_size and axes_dim (ComfyUI approach)
        # num_heads = hidden_size // sum(axes_dim)
        axes_sum = sum(config["axes_dim"])
        config["num_heads"] = config["hidden_size"] // axes_sum
        logger.info(f"Calculated num_heads={config['num_heads']} from hidden_size={config['hidden_size']} / axes_sum={axes_sum}")

        # Detect context_in_dim from txt_in
        if "txt_in.weight" in sd:
            config["context_in_dim"] = sd["txt_in.weight"].shape[1]

        # Detect vec_in_dim from vector_in
        if "vector_in.in_layer.weight" in sd:
            config["vec_in_dim"] = sd["vector_in.in_layer.weight"].shape[1]
            config["use_vector_in"] = True  # Enable vector_in if weights exist
            logger.info(f"Detected vector_in with dim {config['vec_in_dim']}")

        # Detect guidance embedding
        if any("guidance_in" in k for k in sd.keys()):
            config["guidance_embed"] = True

        # Detect txt_norm (critical for some Flux2 variants)
        if any("txt_norm.scale" in k for k in sd.keys()):
            config["txt_norm"] = True
            logger.info("Detected txt_norm in model weights")

        logger.info(f"Detected Flux2 config: depth={config['depth']}, "
                    f"single_blocks={config['depth_single_blocks']}, "
                    f"hidden={config['hidden_size']}, mlp_ratio={config['mlp_ratio']}, "
                    f"gated_mlp={config.get('gated_mlp', False)}")

        return config

    def _load_klein_text_encoder(self, path: str, quantize: str = None):
        """Load the Klein (Qwen3-4B) text encoder.

        Args:
            path: Path to text encoder safetensors
            quantize: Quantization format ("fp8", "nvfp4", or None)

        Returns:
            KleinCLIP wrapper
        """
        logger.info(f"Loading Text Encoder: {path}")

        from src.clip.KleinEncoder import KleinCLIP, KleinTokenizer, Qwen3_4BModel, get_ops
        from src.Model.ModelPatcher import ModelPatcher

        # Determine paths
        sd_path = path
        tokenizer_path = os.path.join(os.path.dirname(path), "qwen25_tokenizer")
        if not os.path.exists(tokenizer_path):
            tokenizer_path = None  # Let KleinTokenizer find its default

        # Load weights
        sd = util.load_torch_file(sd_path, device=torch.device("cpu"))

        # Create model structure
        # Base dtype is BF16
        dtype = torch.bfloat16
        model = Qwen3_4BModel(dtype=dtype, device="cpu")

        # Load state dict
        model_sd = {}
        for k, v in sd.items():
            if k.startswith("model."):
                model_sd[k[6:]] = v
            else:
                model_sd[k] = v
        missing, unexpected = model.load_state_dict(model_sd, strict=False)

        # Apply quantization BEFORE moving to offload device if requested
        if quantize:
            logger.info(f"Flux2KleinModel: Quantizing Klein (Qwen3-4B) to {quantize}")
            # We must use ModelPatcher to correctly update comfy_cast_weights flags
            te_patcher = ModelPatcher(model, self.load_device, self.offload_device)
            if quantize == "nvfp4":
                te_patcher.weight_only_quantize("nvfp4")
            else:
                te_patcher.weight_only_quantize(torch.float8_e4m3fn)
            model = te_patcher.model

        # IMPORTANT: Keep model on CPU to save VRAM for diffusion model
        offload_device = Device.text_encoder_offload_device()
        model = model.to(offload_device)

        # Create wrapper
        tokenizer = KleinTokenizer(tokenizer_path)
        clip = KleinCLIP(
            tokenizer=tokenizer,
            model=model,
            dtype=dtype,
            device=self.load_device,
            offload_device=offload_device,
        )
        return clip
    def _load_vae(self, path: str):
        """Load the VAE for decoding latents using native LightDiffusion-Next.

        Following ComfyUI's VAE loading approach:
        - Detects z_channels from decoder.conv_in.weight.shape[1]
        - Uses post_quant_conv/quant_conv (flux=False) for standard VAE structure

        Args:
            path: Path to VAE safetensors

        Returns:
            VAE model
        """
        logger.info(f"Loading VAE: {path}")

        # Load state dict
        sd = util.load_torch_file(path)

        # Check for diffusers format and convert if needed (ComfyUI approach)
        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd:
            logger.info("Converting diffusers VAE format to SD format")
            sd = self._convert_diffusers_vae(sd)

        # Log VAE structure
        is_flux_vae = False
        if 'decoder.conv_in.weight' in sd:
            z_ch = sd['decoder.conv_in.weight'].shape[1]
            logger.info(f"VAE z_channels: {z_ch}")

        if 'post_quant_conv.weight' in sd:
            embed_dim = sd['post_quant_conv.weight'].shape[1]
            logger.info(f"VAE embed_dim: {embed_dim} (Standard VAE)")
            is_flux_vae = False
        else:
            logger.info("VAE missing post_quant_conv (Flux VAE)")
            is_flux_vae = True

        # Create VAE using native implementation
        # Set flux=True if it's a Flux VAE (skips post_quant_conv)
        # Use bfloat16 for better precision/memory balance on modern GPUs
        vae = VariationalAE.VAE(sd=sd, flux=is_flux_vae, dtype=torch.bfloat16)
        return vae

    def _convert_diffusers_vae(self, sd: dict) -> dict:
        """Convert diffusers VAE format to SD format (ComfyUI approach)."""
        # VAE conversion map from ComfyUI's diffusers_convert.py
        vae_conversion_map = [
            ("nin_shortcut", "conv_shortcut"),
            ("norm_out", "conv_norm_out"),
            ("mid.attn_1.", "mid_block.attentions.0."),
        ]
        for i in range(4):
            for j in range(2):
                hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
                sd_down_prefix = f"encoder.down.{i}.block.{j}."
                vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

            if i < 3:
                hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
                sd_downsample_prefix = f"down.{i}.downsample."
                vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

                hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
                sd_upsample_prefix = f"up.{3 - i}.upsample."
                vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

            for j in range(3):
                hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
                sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
                vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

        for i in range(2):
            hf_mid_res_prefix = f"mid_block.resnets.{i}."
            sd_mid_res_prefix = f"mid.block_{i + 1}."
            vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

        vae_conversion_map_attn = [
            ("norm.", "group_norm."),
            ("q.", "query."),
            ("k.", "key."),
            ("v.", "value."),
            ("q.", "to_q."),
            ("k.", "to_k."),
            ("v.", "to_v."),
            ("proj_out.", "to_out.0."),
            ("proj_out.", "proj_attn."),
        ]

        mapping = {k: k for k in sd.keys()}
        for k, v in mapping.items():
            for sd_part, hf_part in vae_conversion_map:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        for k, v in mapping.items():
            if "attentions" in k:
                for sd_part, hf_part in vae_conversion_map_attn:
                    v = v.replace(hf_part, sd_part)
                mapping[k] = v
        new_state_dict = {v: sd[k] for k, v in mapping.items()}

        # Reshape attention weights
        weights_to_convert = ["q", "k", "v", "proj_out"]
        for k, v in new_state_dict.items():
            for weight_name in weights_to_convert:
                if f"mid.attn_1.{weight_name}.weight" in k:
                    new_state_dict[k] = v.reshape(*v.shape, 1, 1)

        return new_state_dict

    def _create_model_config(self):
        """Create a model config object for sampling."""

        class Flux2KleinConfig:
            """Configuration for Flux2 Klein sampling."""

            sampling_settings = {
                "shift": 2.02,  # Flux2 default shift (different from Flux1's 1.15)
            }
            latent_format = Latent.Flux2()
            recommended_steps = 4
            recommended_cfg = 1.0

        return Flux2KleinConfig()
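    # The Klein tokenizer pads the text embedding to a fixed 512-token context
    # (see the docstring of encode_prompt below). For illustration only, a left-pad
    # on the sequence dimension of a [batch, seq, dim] embedding could be written as:
    #
    #     pad = 512 - emb.shape[1]
    #     emb = torch.nn.functional.pad(emb, (0, 0, pad, 0))  # prepend zeros along seq
    #
    # The actual padding is handled inside the tokenizer/encoder, not in this class.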
    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: int = None,
    ) -> tuple[Any, Any]:
        """Encode text prompts into conditioning tensors.

        For Flux2 Klein, this uses the Qwen3-based Klein text encoder which
        does not use traditional CLIP skip.

        CRITICAL: ComfyUI LEFT-PADS text embeddings to 512 tokens before
        passing to the diffusion model. This is essential for matching image
        quality because:
        1. The positional encoding (RoPE) depends on sequence length
        2. The model was trained with fixed 512-token text sequences

        Args:
            prompt: Positive prompt(s) to encode
            negative_prompt: Negative prompt(s) (may be ignored for Flux2)
            clip_skip: Not used for Klein encoder

        Returns:
            Tuple of (positive_conditioning, negative_conditioning)
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before encoding prompts")
        if self.clip is None:
            raise RuntimeError("No text encoder loaded")

        try:
            # Use Klein encoder directly
            if isinstance(prompt, list):
                # Encode each prompt in the batch
                all_hidden = []
                all_pooled = []
                for p in prompt:
                    tokens = self.clip.tokenizer.tokenize_with_weights(p)
                    h, pol, _ = self.clip.encode_token_weights(tokens)
                    all_hidden.append(h)
                    # Handle cases where pooled output might be None (common in Klein/Qwen encoders)
                    if pol is not None:
                        all_pooled.append(pol)
                hidden_states = torch.cat(all_hidden, dim=0)
                pooled = torch.cat(all_pooled, dim=0) if all_pooled else None
            else:
                # Single prompt
                tokens = self.clip.tokenizer.tokenize_with_weights(prompt)
                hidden_states, pooled, extra = self.clip.encode_token_weights(tokens)

            # Encode negative (or empty)
            neg_prompt = negative_prompt if negative_prompt is not None else ""
            if neg_prompt:
                if isinstance(neg_prompt, list):
                    # We usually only need one negative for the whole batch or match batch size
                    if len(neg_prompt) == 1:
                        neg_prompt = neg_prompt[0]
                    else:
                        # Encode all negatives
                        all_neg_hidden = []
                        all_neg_pooled = []
                        for np in neg_prompt:
                            ntokens = self.clip.tokenizer.tokenize_with_weights(np)
                            nh, npol, _ = self.clip.encode_token_weights(ntokens)
                            all_neg_hidden.append(nh)
                            if npol is not None:
                                all_neg_pooled.append(npol)
                        neg_hidden = torch.cat(all_neg_hidden, dim=0)
                        neg_pooled = torch.cat(all_neg_pooled, dim=0) if all_neg_pooled else None
                        neg_prompt = None  # Mark as processed

            if neg_prompt is not None:
                neg_tokens = self.clip.tokenizer.tokenize_with_weights(neg_prompt or "")
                neg_hidden, neg_pooled, neg_extra = self.clip.encode_token_weights(neg_tokens)

            # Embeddings are already padded to 512 tokens by the tokenizer

            # Format as conditioning
            # Note: ComfyUI does NOT pass attention_mask to diffusion model for Flux2
            # The zero-padded tokens don't contribute meaningfully to cross-attention
            cond_dict = {"pooled_output": pooled}
            positive = [[hidden_states, cond_dict]]

            neg_cond_dict = {"pooled_output": neg_pooled}
            negative = [[neg_hidden, neg_cond_dict]]

            return positive, negative

        except Exception as e:
            logger.exception(f"Prompt encoding failed: {e}")
            raise

    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
        latent_image: Optional[Any] = None,
        start_step: Optional[int] = None,
        last_step: Optional[int] = None,
        disable_noise: bool = False,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Generate latents using the Flux2 sampler.

        Args:
            ctx: Context with generation parameters
            positive: Positive conditioning
            negative: Negative conditioning (may be ignored)
            latent_image: Optional starting latent (e.g. for img2img)
            start_step: Optional first step to run
            last_step: Optional last step to run
            disable_noise: Skip adding initial noise to the latent
            callback: Optional per-step callback (falls back to ctx.callback)

        Returns:
            Dictionary with 'samples' key containing generated latents
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before generating")

        # Log recommendation if CFG is high for this distilled model
        if ctx.sampling.cfg > 2.0:
            logger.info(f"Tip: Flux2 Klein works best with CFG 1.0. "
                        f"You are currently using CFG {ctx.sampling.cfg}.")

        try:
            # Use provided latent or create empty one for Flux2
            if latent_image is not None:
                latent = latent_image
            else:
                latent = self._create_flux2_latent(
                    ctx.width,
                    ctx.height,
                    ctx.generation.batch,
                )

            # Add seeds for deterministic noise
            latent["seeds"] = ctx.seeds[:ctx.generation.batch] if ctx.seeds else [ctx.seed]

            # CRITICAL: Force-disable multi-scale for Flux2 models
            # Multi-scale is designed for UNet architectures (SD1.5/SDXL) and
            # causes significant performance overhead for Flux2's DiT architecture
            enable_multiscale = False  # Always disable for Flux2
            if ctx.sampling.enable_multiscale:
                logger.info("Multi-scale disabled: not compatible with Flux2 architecture")

            # Run sampling with flux=True AND flux2=True for resolution-aware scheduler
            ksampler = sampling.KSampler()
            result = ksampler.sample(
                seed=ctx.seed,
                steps=ctx.sampling.steps,
                cfg=ctx.sampling.cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=ctx.sampling.scheduler,
                denoise=ctx.sampling.denoise,
                pipeline=True,
                model=self.model,
                positive=positive,
                negative=negative,
                latent_image=latent,
                start_step=start_step,
                last_step=last_step,
                disable_noise=disable_noise,
                callback=callback or ctx.callback,
                flux=True,  # Enable Flux sampling mode
                flux2=True,  # Enable Flux2-specific resolution-aware scheduler (matches ComfyUI Flux2Scheduler)
                enable_multiscale=enable_multiscale,  # Force disabled for Flux2
                multiscale_factor=ctx.sampling.multiscale_factor,
                multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
                multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
                multiscale_intermittent_fullres=ctx.sampling.multiscale_intermittent_fullres,
                cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                batched_cfg=ctx.sampling.batched_cfg,
                dynamic_cfg_rescaling=ctx.sampling.dynamic_cfg_rescaling,
                dynamic_cfg_method=ctx.sampling.dynamic_cfg_method,
                dynamic_cfg_percentile=ctx.sampling.dynamic_cfg_percentile,
                dynamic_cfg_target_scale=ctx.sampling.dynamic_cfg_target_scale,
                adaptive_noise_enabled=ctx.sampling.adaptive_noise_enabled,
                adaptive_noise_method=ctx.sampling.adaptive_noise_method,
            )

            return result[0]

        except Exception as e:
            logger.exception(f"Generation failed: {e}")
            raise

    def _create_flux2_latent(self, width: int, height: int, batch_size: int) -> dict:
        """Create an empty latent tensor for Flux2.

        Flux2 uses 32-channel VAE-shaped latents in the pipeline.

        Args:
            width: Image width
            height: Image height
            batch_size: Batch size

        Returns:
            Dict with 'samples' key containing latent tensor
        """
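        # Worked example (assuming the 8x VAE downscale and 32 channels used below):
        # width=1024, height=1024, batch_size=2 -> latent of shape [2, 32, 128, 128].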
        # Flux VAE uses 8x downscaling
        latent_height = height // 8
        latent_width = width // 8

        latent = torch.zeros(
            batch_size,
            32,
            latent_height,
            latent_width,
            dtype=torch.float32,
        )
        return {"samples": latent}

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to pixel space using the VAE.

        Args:
            latents: Latent tensor or dict with 'samples' key

        Returns:
            Decoded image tensor in [0, 1] range
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before decoding")

        try:
            # Handle both raw tensor and dict input
            if isinstance(latents, dict):
                samples_tensor = latents["samples"]
            else:
                samples_tensor = latents

            # Use the Flux2 latent format
            # Undoing process_latent_out (scale/shift from sampling) is now handled by the KSampler

            # Decode with VAE
            decoder = VariationalAE.VAEDecode()
            result = decoder.decode(
                vae=self.vae,
                samples={"samples": samples_tensor},
            )
            return result[0]

        except Exception as e:
            logger.exception(f"Decoding failed: {e}")
            raise

    def get_model_object(self, name):
        """Get an attribute from the model or its patcher."""
        if name == "latent_format":
            return self._model_config.latent_format
        if self.model:
            return self.model.get_model_object(name)
        return None

    def apply_lora(
        self,
        lora_name: str,
        strength_model: float = 1.0,
        strength_clip: float = 1.0,
    ) -> "Flux2KleinModel":
        """Apply a LoRA to the Flux2 Klein model.

        Note: LoRA support for Flux2 may be limited.

        Args:
            lora_name: Name/path of the LoRA file
            strength_model: Strength to apply to the model
            strength_clip: Strength to apply to CLIP

        Returns:
            Self for method chaining
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before applying LoRA")

        try:
            loader = LoRas.LoraLoader()
            result = loader.load_lora(
                lora_name=lora_name,
                strength_model=strength_model,
                strength_clip=strength_clip,
                model=self.model,
                clip=self.clip,
            )
            self.model = result[0]
            self.clip = result[1]
            logger.info(f"Applied LoRA: {lora_name}")
        except Exception as e:
            logger.warning(f"Failed to apply LoRA {lora_name}: {e}")

        return self
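
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The Context object comes from the
# LightDiffusion-Next pipeline and its construction is not shown here, so
# generate()/decode() are only outlined in comments; the paths are assumptions
# about a local ./include/ layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Flux2KleinModel(quantization="fp8")  # FP8 weight-only quantization where supported
    model.load()  # auto-detects the diffusion model, Qwen3 text encoder and VAE under ./include/

    positive, negative = model.encode_prompt(
        "a watercolor painting of a lighthouse at dawn",
        negative_prompt="",
    )

    # Sampling needs a pipeline Context (resolution, steps, cfg, seed, ...):
    #   result = model.generate(ctx, positive, negative)
    #   images = model.decode(result)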