"""Flux2 Klein model adapter for LightDiffusion-Next.
Provides a clean interface to the Flux2 Klein 4B model that inherits from
AbstractModel and integrates with the LightDiffusion-Next model factory.
This implementation uses ONLY native LightDiffusion-Next components,
without any ComfyUI imports.
File structure expected:
- include/diffusion_model/flux-2-klein-4b.safetensors (or similar)
- include/text_encoder/qwen_3_4b.safetensors
- include/vae/ae.safetensors (Flux VAE)
"""
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from src.Core.AbstractModel import AbstractModel, ModelCapabilities
from src.Utilities import util
from src.Device import Device
# Import modules that were previously lazy-loaded inside methods
# This avoids KeyError: 'src' when running via uv run streamlit
from src.NeuralNetwork.flux2.model import Flux2, Flux2Params
from src.Model.ModelPatcher import ModelPatcher
from src.clip.KleinEncoder import KleinCLIP, Qwen3_4BModel
from src.AutoEncoders import VariationalAE
from src.sample import sampling
from src.Utilities import Latent
from src.Model import LoRas
if TYPE_CHECKING:
from src.Core.Context import Context
logger = logging.getLogger(__name__)
# Default paths for Flux2 Klein components
DEFAULT_DIFFUSION_MODEL_DIR = "./include/diffusion_model"
DEFAULT_TEXT_ENCODER_DIR = "./include/text_encoder"
DEFAULT_VAE_DIR = "./include/vae"
class Flux2KleinModel(AbstractModel):
"""Flux2 Klein 4B model implementation.
Wraps the Flux2 Klein model with the clean AbstractModel interface
for use with the LightDiffusion-Next pipeline system.
The Flux2 Klein model is a distilled version of the Flux2 architecture
using the Klein (Qwen3 4B) text encoder.
Unlike SD1.5/SDXL which use combined checkpoints, Flux2 Klein loads
components separately:
- Diffusion model from include/diffusion_model/
- Text encoder (Qwen3 4B) from include/text_encoder/
- VAE from include/vae/
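
    Minimal usage sketch (file names are illustrative; checkpoints placed in the
    expected directories are auto-detected when explicit paths are omitted):

        model = Flux2KleinModel(quantization="fp8").load()
        positive, negative = model.encode_prompt("a watercolor fox", "")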
"""
    def __init__(
        self,
        model_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        vae_path: Optional[str] = None,
        quantization: Optional[str] = None,  # "fp8", "nvfp4", or None
    ):
"""Initialize the Flux2 Klein model adapter.
Args:
model_path: Path to diffusion model (safetensors)
text_encoder_path: Path to Qwen3 text encoder (optional, auto-detected)
vae_path: Path to VAE (optional, auto-detected)
quantization: Quantization format to use ("fp8", "nvfp4", or None)
"""
super().__init__(model_path)
self._text_encoder = None
self._tokenizer = None
self._model_config = None
self._text_encoder_path = text_encoder_path
self._vae_path = vae_path
self._raw_model = None # The raw Flux2 nn.Module
self.quantization = quantization
# Device management
self.load_device = Device.get_torch_device()
self.offload_device = torch.device("cpu")
def _create_capabilities(self) -> ModelCapabilities:
"""Create capabilities for Flux2 Klein model."""
return ModelCapabilities(
min_resolution=256,
max_resolution=4096,
preferred_resolution=1024,
            requires_resolution_multiple=16,  # 8x VAE downscale with 2x2 latent packing -> 16-pixel alignment
supports_hires_fix=True,
supports_img2img=True,
supports_inpainting=False, # Not yet implemented for Flux2
supports_controlnet=False, # ControlNet support pending
supports_stable_fast=False, # May need special handling
supports_deepcache=False, # Architecture differs from UNet
supports_tome=False, # Token merging needs special implementation
supports_lora=False, # Flux2 LoRA format differs from SD
uses_dual_clip=False, # Uses single Klein (Qwen3) encoder
requires_size_conditioning=False,
is_flux=True,
is_flux2=True,
)
def _find_diffusion_model(self) -> Optional[str]:
"""Auto-detect Flux2 diffusion model in default directory."""
if os.path.exists(DEFAULT_DIFFUSION_MODEL_DIR):
for f in os.listdir(DEFAULT_DIFFUSION_MODEL_DIR):
f_lower = f.lower()
if ("flux" in f_lower or "klein" in f_lower) and f.endswith((".safetensors", ".pt", ".pth")):
return os.path.join(DEFAULT_DIFFUSION_MODEL_DIR, f)
return None
def _find_text_encoder(self) -> Optional[str]:
"""Auto-detect Qwen3 text encoder in default directory."""
if os.path.exists(DEFAULT_TEXT_ENCODER_DIR):
for f in os.listdir(DEFAULT_TEXT_ENCODER_DIR):
f_lower = f.lower()
if ("qwen" in f_lower or "klein" in f_lower) and f.endswith((".safetensors", ".pt", ".pth")):
return os.path.join(DEFAULT_TEXT_ENCODER_DIR, f)
return None
def _find_vae(self) -> Optional[str]:
"""Auto-detect VAE in default directory."""
if os.path.exists(DEFAULT_VAE_DIR):
# Look for Flux-compatible VAE (ae.safetensors)
for f in os.listdir(DEFAULT_VAE_DIR):
if f.endswith((".safetensors", ".pt", ".pth")):
return os.path.join(DEFAULT_VAE_DIR, f)
return None
    def load(self, model_path: Optional[str] = None) -> "Flux2KleinModel":
"""Load the Flux2 Klein model components from disk.
Components are loaded separately:
- Diffusion model (Flux2 transformer)
- Text encoder (Qwen3 4B via Klein tokenizer)
- VAE
Args:
model_path: Optional override for the diffusion model path
Returns:
Self for method chaining
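
        Example (sketch; the path is an assumption about where the checkpoint
        was placed):

            model = Flux2KleinModel()
            model.load("include/diffusion_model/flux-2-klein-4b.safetensors")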
"""
# Resolve paths
diffusion_path = model_path or self.model_path or self._find_diffusion_model()
# Guard: Don't reload if already loaded with same diffusion model
if self._loaded and self.model_path == diffusion_path:
logger.info("Flux2KleinModel: Already loaded, skipping redundant load")
return self
if diffusion_path is None:
raise ValueError(
"No Flux2 diffusion model found. Please place the model in "
f"{DEFAULT_DIFFUSION_MODEL_DIR}/ with 'flux' or 'klein' in the filename."
)
self.model_path = diffusion_path
# Resolve other paths only when loading is actually needed
text_encoder_path = self._text_encoder_path or self._find_text_encoder()
vae_path = self._vae_path or self._find_vae()
logger.info(f"Flux2KleinModel: Loading components...")
logger.info(f" Diffusion model: {diffusion_path}")
logger.info(f" Text encoder: {text_encoder_path}")
logger.info(f" VAE: {vae_path}")
        try:
            # Load the diffusion model with inline FP8-aware logic
            # (replaces the original self._load_diffusion_model(diffusion_path) call)
            # Check for FP8 hardware support; on cards with less than 12 GB of VRAM,
            # FP8 is strongly recommended for Flux2 Klein 4B to avoid weight swapping,
            # but it can only be enabled when the hardware supports it.
            use_fp8 = Device.is_fp8_supported(self.load_device)
            total_vram = Device.get_total_memory(self.load_device) / (1024**3)
            if total_vram < 12.0 and not use_fp8:
                logger.warning(
                    f"Flux2 Klein 4B on a {total_vram:.1f} GB card without FP8 support "
                    "may require weight offloading"
                )
            dtype = torch.bfloat16  # Base weight dtype
            # Load the checkpoint once and reuse it for config detection and weight loading
            sd = util.load_torch_file(diffusion_path, device=self.offload_device)
            # Create model with detected config
            config = self._detect_flux2_config(sd)
            params = Flux2Params(**config)
            self.model = Flux2(params=params, dtype=dtype, device=torch.device("cpu"))  # Create on CPU first
            self.model.eval()
            # Attach config for compatibility
            self._model_config = self._create_model_config()
# Sanitize NaN values in weights (some Flux2 checkpoints have NaN biases)
nan_keys = []
for key, value in sd.items():
if isinstance(value, torch.Tensor) and torch.isnan(value).any():
nan_keys.append(key)
sd[key] = torch.where(torch.isnan(value), torch.zeros_like(value), value)
if nan_keys:
logger.warning(f"Sanitized NaN values in {len(nan_keys)} keys: {nan_keys[:5]}...")
self.model.load_state_dict(sd, strict=False)
del sd
self._raw_model = self.model # Store raw model
# Create ModelPatcher
self.model = ModelPatcher(self.model, self.load_device, self.offload_device)
# Apply quantization if requested or needed
quant_format = self.quantization
if quant_format is None and use_fp8:
quant_format = "fp8"
if quant_format == "nvfp4":
logging.info("Flux2: Applying NVFP4 (4-bit) weight-only quantization")
self.model.weight_only_quantize("nvfp4")
self.model.model_dtype = lambda: torch.float16 # Compute in FP16 for dequantization
elif quant_format == "fp8":
logging.info("Flux2: Applying FP8 weight-only quantization")
self.model.weight_only_quantize(torch.float8_e4m3fn)
self.model.model_dtype = lambda: torch.float8_e4m3fn # Override
# Load text encoder
if text_encoder_path:
self.clip = self._load_klein_text_encoder(text_encoder_path, quantize=quant_format)
self._text_encoder = self.clip # For internal reference
self._tokenizer = self.clip.tokenizer
else:
logger.warning("No Qwen3 text encoder found - prompt encoding may fail")
self.clip = None
# Load VAE
if vae_path:
self.vae = self._load_vae(vae_path)
else:
logger.warning("No VAE found - image decoding may fail")
self.vae = None
            # Attach model_sampling for the sampler infrastructure
            # (self._model_config was already created above)
            self.model.model_sampling = sampling.model_sampling(self._model_config, "flux2", flux=True, flux2=True)
self._loaded = True
logger.info(f"Flux2KleinModel: Successfully loaded all components")
except Exception as e:
logger.exception(f"Flux2KleinModel: Failed to load: {e}")
raise
return self
def _load_diffusion_model(self, path: str):
"""Load the Flux2 diffusion model using native LightDiffusion-Next.
Args:
path: Path to diffusion model safetensors
Returns:
ModelPatcher wrapping the Flux2 model
"""
logger.info(f"Loading Flux2 diffusion model: {path}")
# Load state dict using native utility
sd = util.load_torch_file(path)
# Sanitize NaN values in weights (some Flux2 checkpoints have NaN biases)
nan_keys = []
for key, value in sd.items():
if isinstance(value, torch.Tensor) and torch.isnan(value).any():
nan_keys.append(key)
sd[key] = torch.where(torch.isnan(value), torch.zeros_like(value), value)
if nan_keys:
logger.warning(f"Sanitized NaN values in {len(nan_keys)} keys: {nan_keys[:5]}...")
# Detect model configuration from state dict
config = self._detect_flux2_config(sd)
# Determine dtype and device
load_device = Device.get_torch_device()
offload_device = Device.unet_offload_device()
# Infer dtype from weights
dtype = torch.bfloat16
for k, v in sd.items():
if isinstance(v, torch.Tensor) and v.dtype in (torch.float16, torch.bfloat16, torch.float32):
dtype = v.dtype
break
logger.info(f"Flux2 model dtype: {dtype}")
# Create model with detected config
params = Flux2Params(**config)
model = Flux2(params=params, dtype=dtype, device="cpu")
# Attach config for compatibility
model.model_config = self._create_model_config()
# Load weights
missing, unexpected = model.load_state_dict(sd, strict=False)
if missing:
logger.debug(f"Missing keys: {len(missing)}")
if unexpected:
logger.debug(f"Unexpected keys: {len(unexpected)}")
self._raw_model = model
# Wrap in ModelPatcher for compatibility with sampling infrastructure
        model_patcher = ModelPatcher(
model,
load_device=load_device,
offload_device=offload_device,
current_device=torch.device("cpu"),
)
return model_patcher
def _detect_flux2_config(self, sd: dict) -> dict:
"""Detect Flux2 model configuration from state dict.
Args:
sd: Model state dictionary
Returns:
Configuration dict for Flux2Params
"""
# Detect if this is Flux2 (has double_stream_modulation) or Flux1
is_flux2 = any("double_stream_modulation" in k for k in sd.keys())
if is_flux2:
# Flux2 / Klein defaults (patch_size=1 unlike Flux1!)
config = {
"patch_size": 1, # CRITICAL: Flux2 uses patch_size=1 (no spatial patchification)
"in_channels": 128, # Direct channel input (no patch_size division)
"out_channels": 128, # Direct channel output
"vec_in_dim": 768,
"context_in_dim": 7680, # Klein uses concatenated multi-layer output
"hidden_size": 3072,
"mlp_ratio": 3.0, # Klein uses 3.0 with gated MLP
"num_heads": 24, # Flux2: hidden_size/sum(axes_dim) = 3072/128 = 24
"depth": 19,
"depth_single_blocks": 38,
"axes_dim": [32, 32, 32, 32], # Flux2 specific - sum=128
"theta": 2000, # Flux2 uses lower theta
"qkv_bias": False,
"guidance_embed": False,
"gated_mlp": True, # Klein uses gated MLP (SwiGLU)
"global_modulation": True, # Flux2 feature
"mlp_silu_act": True, # Flux2 feature
"ops_bias": False, # Flux2 feature
"use_vector_in": False, # Flux2/Klein doesn't use pooled conditioning
}
logger.info("Detected Flux2 model (has double_stream_modulation)")
else:
# Flux1 defaults
config = {
"in_channels": 16,
"out_channels": 16,
"vec_in_dim": 768,
"context_in_dim": 7680,
"hidden_size": 3072,
"mlp_ratio": 4.0,
"num_heads": 24,
"depth": 19,
"depth_single_blocks": 38,
"axes_dim": [16, 56, 56], # Flux1 specific
"theta": 10000,
"qkv_bias": True,
"guidance_embed": True,
"gated_mlp": False,
}
logger.info("Detected Flux1 model")
# Detect depth from double_blocks
double_blocks = [k for k in sd.keys() if "double_blocks" in k]
if double_blocks:
max_block = max(
int(k.split("double_blocks.")[1].split(".")[0])
for k in double_blocks
if "double_blocks." in k
)
config["depth"] = max_block + 1
# Detect single blocks depth
single_blocks = [k for k in sd.keys() if "single_blocks" in k]
if single_blocks:
max_single = max(
int(k.split("single_blocks.")[1].split(".")[0])
for k in single_blocks
if "single_blocks." in k
)
config["depth_single_blocks"] = max_single + 1
# Detect hidden size and in_channels from img_in
if "img_in.weight" in sd:
config["hidden_size"] = sd["img_in.weight"].shape[0]
# img_in input dim = in_channels * patch_size^2
# For Flux2 with patch_size=1: in_channels = img_in_dim directly
img_in_dim = sd["img_in.weight"].shape[1]
patch_size = config.get("patch_size", 2)
config["in_channels"] = img_in_dim // (patch_size ** 2)
logger.info(f"Detected in_channels={config['in_channels']} from img_in (patch_size={patch_size})")
# Detect out_channels from final_layer
if "final_layer.linear.weight" in sd:
# final_layer.linear maps hidden -> patch_size * patch_size * out_channels
# For Flux2 with patch_size=1: out_channels = final.shape[0] directly
final_out = sd["final_layer.linear.weight"].shape[0]
patch_size = config.get("patch_size", 2)
config["out_channels"] = final_out // (patch_size ** 2)
logger.info(f"Detected out_channels={config['out_channels']} from final_layer")
# Detect mlp_ratio and gated_mlp from double_blocks MLP weights
# For gated MLP: img_mlp.0 maps hidden -> 2*intermediate (gate+up)
# img_mlp.2 maps intermediate -> hidden
# So: mlp_0_out = 2 * intermediate, intermediate = mlp_2_in
# mlp_ratio = intermediate / hidden
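        # Illustrative numbers under the Klein defaults above: hidden=3072 with
        # mlp_ratio=3.0 gives intermediate=9216, so img_mlp.0 maps 3072 -> 18432
        # (gate + up concatenated) and img_mlp.2 maps 9216 -> 3072.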
if "double_blocks.0.img_mlp.0.weight" in sd and "double_blocks.0.img_mlp.2.weight" in sd:
mlp_0_out = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
mlp_2_in = sd["double_blocks.0.img_mlp.2.weight"].shape[1]
hidden = config["hidden_size"]
# Check if it's gated MLP: mlp_0_out should be 2 * mlp_2_in
if abs(mlp_0_out - 2 * mlp_2_in) < 10: # Small tolerance
# Gated MLP detected
config["gated_mlp"] = True
intermediate = mlp_2_in
config["mlp_ratio"] = intermediate / hidden
logger.info(f"Detected gated MLP: intermediate={intermediate}, mlp_ratio={config['mlp_ratio']}")
else:
# Standard MLP: mlp_0_out = mlp_2_in = hidden * mlp_ratio
config["gated_mlp"] = False
config["mlp_ratio"] = mlp_0_out / hidden
# Calculate num_heads from hidden_size and axes_dim (ComfyUI approach)
# num_heads = hidden_size // sum(axes_dim)
axes_sum = sum(config["axes_dim"])
config["num_heads"] = config["hidden_size"] // axes_sum
logger.info(f"Calculated num_heads={config['num_heads']} from hidden_size={config['hidden_size']} / axes_sum={axes_sum}")
# Detect context_in_dim from txt_in
if "txt_in.weight" in sd:
config["context_in_dim"] = sd["txt_in.weight"].shape[1]
# Detect vec_in_dim from vector_in
if "vector_in.in_layer.weight" in sd:
config["vec_in_dim"] = sd["vector_in.in_layer.weight"].shape[1]
config["use_vector_in"] = True # Enable vector_in if weights exist
logger.info(f"Detected vector_in with dim {config['vec_in_dim']}")
# Detect guidance embedding
if any("guidance_in" in k for k in sd.keys()):
config["guidance_embed"] = True
# Detect txt_norm (critical for some Flux2 variants)
if any("txt_norm.scale" in k for k in sd.keys()):
config["txt_norm"] = True
logger.info("Detected txt_norm in model weights")
logger.info(f"Detected Flux2 config: depth={config['depth']}, "
f"single_blocks={config['depth_single_blocks']}, "
f"hidden={config['hidden_size']}, mlp_ratio={config['mlp_ratio']}, "
f"gated_mlp={config.get('gated_mlp', False)}")
return config
    def _load_klein_text_encoder(self, path: str, quantize: Optional[str] = None):
"""Load the Klein (Qwen3-4B) text encoder.
Args:
path: Path to text encoder safetensors
quantize: Quantization format ("fp8", "nvfp4", or None)
Returns:
KleinCLIP wrapper
"""
logger.info(f"Loading Text Encoder: {path}")
        # Only KleinTokenizer needs a local import; the other Klein components and
        # ModelPatcher are already imported at module level.
        from src.clip.KleinEncoder import KleinTokenizer
# Determine paths
sd_path = path
tokenizer_path = os.path.join(os.path.dirname(path), "qwen25_tokenizer")
if not os.path.exists(tokenizer_path):
tokenizer_path = None # Let KleinTokenizer find its default
# Load weights
sd = util.load_torch_file(sd_path, device=torch.device("cpu"))
# Create model structure
# Base dtype is BF16
dtype = torch.bfloat16
model = Qwen3_4BModel(dtype=dtype, device="cpu")
# Load state dict
model_sd = {}
for k, v in sd.items():
if k.startswith("model."):
model_sd[k[6:]] = v
else:
model_sd[k] = v
missing, unexpected = model.load_state_dict(model_sd, strict=False)
# Apply quantization BEFORE moving to offload device if requested
if quantize:
logger.info(f"Flux2KleinModel: Quantizing Klein (Qwen3-4B) to {quantize}")
# We must use ModelPatcher to correctly update comfy_cast_weights flags
te_patcher = ModelPatcher(model, self.load_device, self.offload_device)
if quantize == "nvfp4":
te_patcher.weight_only_quantize("nvfp4")
else:
te_patcher.weight_only_quantize(torch.float8_e4m3fn)
model = te_patcher.model
# IMPORTANT: Keep model on CPU to save VRAM for diffusion model
offload_device = Device.text_encoder_offload_device()
model = model.to(offload_device)
# Create wrapper
tokenizer = KleinTokenizer(tokenizer_path)
clip = KleinCLIP(tokenizer=tokenizer, model=model, dtype=dtype, device=self.load_device, offload_device=offload_device)
return clip
def _load_vae(self, path: str):
"""Load the VAE for decoding latents using native LightDiffusion-Next.
Following ComfyUI's VAE loading approach:
- Detects z_channels from decoder.conv_in.weight.shape[1]
- Uses post_quant_conv/quant_conv (flux=False) for standard VAE structure
Args:
path: Path to VAE safetensors
Returns:
VAE model
"""
logger.info(f"Loading VAE: {path}")
# Load state dict
sd = util.load_torch_file(path)
# Check for diffusers format and convert if needed (ComfyUI approach)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd:
logger.info("Converting diffusers VAE format to SD format")
sd = self._convert_diffusers_vae(sd)
# Log VAE structure
is_flux_vae = False
if 'decoder.conv_in.weight' in sd:
z_ch = sd['decoder.conv_in.weight'].shape[1]
logger.info(f"VAE z_channels: {z_ch}")
if 'post_quant_conv.weight' in sd:
embed_dim = sd['post_quant_conv.weight'].shape[1]
logger.info(f"VAE embed_dim: {embed_dim} (Standard VAE)")
is_flux_vae = False
else:
logger.info("VAE missing post_quant_conv (Flux VAE)")
is_flux_vae = True
# Create VAE using native implementation
# Set flux=True if it's a Flux VAE (skips post_quant_conv)
# Use bfloat16 for better precision/memory balance on modern GPUs
vae = VariationalAE.VAE(sd=sd, flux=is_flux_vae, dtype=torch.bfloat16)
return vae
def _convert_diffusers_vae(self, sd: dict) -> dict:
"""Convert diffusers VAE format to SD format (ComfyUI approach)."""
# VAE conversion map from ComfyUI's diffusers_convert.py
vae_conversion_map = [
("nin_shortcut", "conv_shortcut"),
("norm_out", "conv_norm_out"),
("mid.attn_1.", "mid_block.attentions.0."),
]
for i in range(4):
for j in range(2):
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
sd_down_prefix = f"encoder.down.{i}.block.{j}."
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
sd_downsample_prefix = f"down.{i}.downsample."
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
("norm.", "group_norm."),
("q.", "query."), ("k.", "key."), ("v.", "value."),
("q.", "to_q."), ("k.", "to_k."), ("v.", "to_v."),
("proj_out.", "to_out.0."), ("proj_out.", "proj_attn."),
]
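        # Example rename produced by the maps above (illustrative):
        #   "decoder.up_blocks.0.resnets.0.norm1.weight" -> "decoder.up.3.block.0.norm1.weight"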
mapping = {k: k for k in sd.keys()}
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
for sd_part, hf_part in vae_conversion_map_attn:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: sd[k] for k, v in mapping.items()}
# Reshape attention weights
weights_to_convert = ["q", "k", "v", "proj_out"]
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
new_state_dict[k] = v.reshape(*v.shape, 1, 1)
return new_state_dict
def _create_model_config(self):
"""Create a model config object for sampling."""
class Flux2KleinConfig:
"""Configuration for Flux2 Klein sampling."""
sampling_settings = {
"shift": 2.02, # Flux2 default shift (different from Flux1's 1.15)
}
latent_format = Latent.Flux2()
recommended_steps = 4
recommended_cfg = 1.0
return Flux2KleinConfig()
def encode_prompt(
self,
prompt: str | list[str],
negative_prompt: str | list[str] = "",
        clip_skip: Optional[int] = None,
) -> tuple[Any, Any]:
"""Encode text prompts into conditioning tensors.
For Flux2 Klein, this uses the Qwen3-based Klein text encoder
which does not use traditional CLIP skip.
CRITICAL: ComfyUI LEFT-PADS text embeddings to 512 tokens before passing
to the diffusion model. This is essential for matching image quality because:
1. The positional encoding (RoPE) depends on sequence length
2. The model was trained with fixed 512-token text sequences
Args:
prompt: Positive prompt(s) to encode
negative_prompt: Negative prompt(s) (may be ignored for Flux2)
clip_skip: Not used for Klein encoder
Returns:
Tuple of (positive_conditioning, negative_conditioning)
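
        Example (sketch; assumes a loaded model with a Klein text encoder):

            positive, negative = model.encode_prompt("a red fox in the snow")
            hidden_states, meta = positive[0]
            # hidden_states: [batch, 512, context_dim] left-padded embeddings
            # meta["pooled_output"] may be None for the Klein (Qwen3) encoder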
"""
if not self._loaded:
raise RuntimeError("Model must be loaded before encoding prompts")
if self.clip is None:
raise RuntimeError("No text encoder loaded")
try:
# Use Klein encoder directly
if isinstance(prompt, list):
# Encode each prompt in the batch
all_hidden = []
all_pooled = []
for p in prompt:
tokens = self.clip.tokenizer.tokenize_with_weights(p)
h, pol, _ = self.clip.encode_token_weights(tokens)
all_hidden.append(h)
# Handle cases where pooled output might be None (common in Klein/Qwen encoders)
if pol is not None:
all_pooled.append(pol)
hidden_states = torch.cat(all_hidden, dim=0)
pooled = torch.cat(all_pooled, dim=0) if all_pooled else None
else:
# Single prompt
tokens = self.clip.tokenizer.tokenize_with_weights(prompt)
hidden_states, pooled, extra = self.clip.encode_token_weights(tokens)
            # Encode the negative prompt (or an empty prompt when none is given)
            neg_prompt = negative_prompt
            if isinstance(neg_prompt, list) and len(neg_prompt) > 1:
                # Encode every negative in the batch
                all_neg_hidden = []
                all_neg_pooled = []
                for neg_p in neg_prompt:
                    ntokens = self.clip.tokenizer.tokenize_with_weights(neg_p)
                    nh, npol, _ = self.clip.encode_token_weights(ntokens)
                    all_neg_hidden.append(nh)
                    if npol is not None:
                        all_neg_pooled.append(npol)
                neg_hidden = torch.cat(all_neg_hidden, dim=0)
                neg_pooled = torch.cat(all_neg_pooled, dim=0) if all_neg_pooled else None
            else:
                if isinstance(neg_prompt, list):
                    # A single-item list shares one negative across the batch; an empty list falls back to ""
                    neg_prompt = neg_prompt[0] if neg_prompt else ""
                neg_tokens = self.clip.tokenizer.tokenize_with_weights(neg_prompt or "")
                neg_hidden, neg_pooled, _ = self.clip.encode_token_weights(neg_tokens)
# Embeddings are already padded to 512 tokens by the tokenizer
# Format as conditioning
# Note: ComfyUI does NOT pass attention_mask to diffusion model for Flux2
# The zero-padded tokens don't contribute meaningfully to cross-attention
cond_dict = {"pooled_output": pooled}
positive = [[hidden_states, cond_dict]]
neg_cond_dict = {"pooled_output": neg_pooled}
negative = [[neg_hidden, neg_cond_dict]]
return positive, negative
except Exception as e:
logger.exception(f"Prompt encoding failed: {e}")
raise
def generate(
self,
ctx: "Context",
positive: Any,
negative: Any,
latent_image: Optional[Any] = None,
start_step: Optional[int] = None,
last_step: Optional[int] = None,
disable_noise: bool = False,
callback: Optional[Callable] = None,
) -> dict:
"""Generate latents using the Flux2 sampler.
Args:
ctx: Context with generation parameters
positive: Positive conditioning
negative: Negative conditioning (may be ignored)
Returns:
Dictionary with 'samples' key containing generated latents
"""
if not self._loaded:
raise RuntimeError("Model must be loaded before generating")
# Log recommendation if CFG is high for this distilled model
if ctx.sampling.cfg > 2.0:
logger.info(f"Tip: Flux2 Klein works best with CFG 1.0. "
f"You are currently using CFG {ctx.sampling.cfg}.")
try:
# Use provided latent or create empty one for Flux2
if latent_image is not None:
latent = latent_image
else:
latent = self._create_flux2_latent(
ctx.width,
ctx.height,
ctx.generation.batch,
)
# Add seeds for deterministic noise
latent["seeds"] = ctx.seeds[:ctx.generation.batch] if ctx.seeds else [ctx.seed]
# CRITICAL: Force-disable multi-scale for Flux2 models
# Multi-scale is designed for UNet architectures (SD1.5/SDXL) and
# causes significant performance overhead for Flux2's DiT architecture
enable_multiscale = False # Always disable for Flux2
if ctx.sampling.enable_multiscale:
logger.info("Multi-scale disabled: not compatible with Flux2 architecture")
# Run sampling with flux=True AND flux2=True for resolution-aware scheduler
ksampler = sampling.KSampler()
result = ksampler.sample(
seed=ctx.seed,
steps=ctx.sampling.steps,
cfg=ctx.sampling.cfg,
sampler_name=ctx.sampling.sampler,
scheduler=ctx.sampling.scheduler,
denoise=ctx.sampling.denoise,
pipeline=True,
model=self.model,
positive=positive,
negative=negative,
latent_image=latent,
start_step=start_step,
last_step=last_step,
disable_noise=disable_noise,
callback=callback or ctx.callback,
flux=True, # Enable Flux sampling mode
flux2=True, # Enable Flux2-specific resolution-aware scheduler (matches ComfyUI Flux2Scheduler)
enable_multiscale=enable_multiscale, # Force disabled for Flux2
multiscale_factor=ctx.sampling.multiscale_factor,
multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
multiscale_intermittent_fullres=ctx.sampling.multiscale_intermittent_fullres,
cfg_free_enabled=ctx.sampling.cfg_free_enabled,
cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
batched_cfg=ctx.sampling.batched_cfg,
dynamic_cfg_rescaling=ctx.sampling.dynamic_cfg_rescaling,
dynamic_cfg_method=ctx.sampling.dynamic_cfg_method,
dynamic_cfg_percentile=ctx.sampling.dynamic_cfg_percentile,
dynamic_cfg_target_scale=ctx.sampling.dynamic_cfg_target_scale,
adaptive_noise_enabled=ctx.sampling.adaptive_noise_enabled,
adaptive_noise_method=ctx.sampling.adaptive_noise_method,
)
return result[0]
except Exception as e:
logger.exception(f"Generation failed: {e}")
raise
def _create_flux2_latent(self, width: int, height: int, batch_size: int) -> dict:
"""Create an empty latent tensor for Flux2.
Flux2 uses 32-channel VAE-shaped latents in the pipeline.
Args:
width: Image width
height: Image height
batch_size: Batch size
Returns:
Dict with 'samples' key containing latent tensor
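
        Example (sketch): a 1024x1024 request with batch_size=1 returns
        {"samples": tensor of shape [1, 32, 128, 128]} since 1024 // 8 = 128.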
"""
# Flux VAE uses 8x downscaling
latent_height = height // 8
latent_width = width // 8
latent = torch.zeros(
batch_size,
32,
latent_height,
latent_width,
dtype=torch.float32,
)
return {"samples": latent}
def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode latents to pixel space using the VAE.
Args:
latents: Latent tensor or dict with 'samples' key
Returns:
Decoded image tensor in [0, 1] range
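
        Example (sketch; assumes a loaded VAE):

            images = model.decode(result["samples"])  # also accepts the dict itself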
"""
if not self._loaded:
raise RuntimeError("Model must be loaded before decoding")
try:
# Handle both raw tensor and dict input
if isinstance(latents, dict):
samples_tensor = latents["samples"]
else:
samples_tensor = latents
            # process_latent_out (undoing the sampling scale/shift) is already handled
            # by the KSampler, so the latents can be decoded as-is
# Decode with VAE
decoder = VariationalAE.VAEDecode()
result = decoder.decode(
vae=self.vae,
samples={"samples": samples_tensor},
)
return result[0]
except Exception as e:
logger.exception(f"Decoding failed: {e}")
raise
def get_model_object(self, name):
"""Get an attribute from the model or its patcher."""
if name == "latent_format":
return self._model_config.latent_format
if self.model:
return self.model.get_model_object(name)
return None
def apply_lora(
self,
lora_name: str,
strength_model: float = 1.0,
strength_clip: float = 1.0,
) -> "Flux2KleinModel":
"""Apply a LoRA to the Flux2 Klein model.
Note: LoRA support for Flux2 may be limited.
Args:
lora_name: Name/path of the LoRA file
strength_model: Strength to apply to the model
strength_clip: Strength to apply to CLIP
Returns:
Self for method chaining
"""
if not self._loaded:
raise RuntimeError("Model must be loaded before applying LoRA")
try:
loader = LoRas.LoraLoader()
result = loader.load_lora(
lora_name=lora_name,
strength_model=strength_model,
strength_clip=strength_clip,
model=self.model,
clip=self.clip,
)
self.model = result[0]
self.clip = result[1]
logger.info(f"Applied LoRA: {lora_name}")
except Exception as e:
logger.warning(f"Failed to apply LoRA {lora_name}: {e}")
return self