This is an optimized version of the text encoder used in Flux 2 Klein 9B. It has the same weights and architecture (Qwen3), just stripped-down code that, under `torch.compile`, is about 1.3x faster and uses less peak VRAM (it should save a couple of GB). Calling `qwen_model(input_ids, attention_mask)` returns the outputs of decoder layers 9, 18 and 27 concatenated along the last dimension, i.e. a tensor of shape `[batch, seq_len, 3 * hidden_size]`; `seq_len` must not exceed 512.

```python
qwen_model = FluxQwen3TorchEmbedder.from_pretrained("fancyfeast/flux2klein-optimized-text-embedder-9B", torch_dtype=torch.bfloat16)
```

```python
from __future__ import annotations

import json
import math
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers import PreTrainedModel


class FluxQwen3TorchEmbedder(PreTrainedModel):
    """Stripped down and optimized Qwen3 specifically for Flux 2 Klein models.

    In my testing this is about 1.3x faster than using the original HF implementation,
    and saves ~3GB of peak memory on the 8B model.

    The output_hidden_state_indices are (9, 18, 27) for both Klein 4B and Klein 9B.
    Only the decoder stack is implemented (no final norm, LM head or KV cache);
    ``forward`` returns the captured layer outputs concatenated along the last dim.
    """

    config_class = Qwen3Config
    base_model_prefix = "flux_qwen3"

    def __init__(
        self,
        config: Qwen3Config,
        *,
        output_hidden_state_indices: tuple[int, ...] = (9, 18, 27),
        max_sequence_length: int = 512,
    ):
        """Build the embedder.

        Args:
            config: Qwen3 config; one decoder layer is built per ``config.num_hidden_layers``.
            output_hidden_state_indices: 1-based decoder layer numbers whose outputs are
                captured, in this order, into the tensor returned by ``forward``.
            max_sequence_length: length of the cached RoPE tables and causal mask, and
                the longest sequence ``forward`` accepts.

        Raises:
            ValueError: if the indices are empty, duplicated, < 1 or larger than
                ``config.num_hidden_layers``, or if ``max_sequence_length`` < 1.
        """
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = int(getattr(config, "head_dim", None) or self.hidden_size // self.num_attention_heads)
        self.rope_theta = _config_rope_theta(config)

        self.output_hidden_state_indices = tuple(int(i) for i in output_hidden_state_indices)
        if not self.output_hidden_state_indices:
            raise ValueError("output_hidden_state_indices must not be empty")
        if min(self.output_hidden_state_indices) < 1:
            raise ValueError("output hidden state indices must be >= 1 for decoder layer outputs")
        if len(set(self.output_hidden_state_indices)) != len(self.output_hidden_state_indices):
            # A repeated layer would map to a single slot in capture_slot_by_layer, leaving
            # the other slot of the torch.empty output buffer uninitialized.
            raise ValueError(
                f"output_hidden_state_indices must not contain duplicates, got {self.output_hidden_state_indices}"
            )
        max_layer_needed = max(self.output_hidden_state_indices)
        if max_layer_needed > int(config.num_hidden_layers):
            raise ValueError(
                f"requested hidden state after layer {max_layer_needed}, but config.num_hidden_layers={config.num_hidden_layers}"
            )
        self.capture_slot_by_layer = {
            layer_idx: slot for slot, layer_idx in enumerate(self.output_hidden_state_indices)
        }
        # Layers after the deepest captured one cannot affect the output; forward stops there.
        self.last_capture_layer = max_layer_needed

        self.max_sequence_length = int(max_sequence_length)
        if self.max_sequence_length < 1:
            raise ValueError(f"max_sequence_length must be >= 1, got {max_sequence_length}")

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=getattr(config, "pad_token_id", None)
        )
        self.layers = nn.ModuleList(
            FluxQwen3TorchLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )

        # Built lazily/refreshed in forward so dtype/device tracks the model.
        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)
        self.register_buffer("causal_mask", torch.empty(0, dtype=torch.bool), persistent=False)

        self.post_init()

    def _maybe_refresh_caches(self, *, device: torch.device, dtype: torch.dtype) -> None:
        """(Re)build the RoPE tables and causal mask when missing, too short, or on the wrong device/dtype."""
        need_refresh = (
            self.cos_cached.numel() == 0
            or self.cos_cached.shape[0] < self.max_sequence_length
            or self.cos_cached.device != device
            or self.cos_cached.dtype != dtype
        )
        if not need_refresh:
            return
        cos, sin = _rotary_cache(
            self.max_sequence_length,
            self.head_dim,
            self.rope_theta,
            device=device,
            dtype=dtype,
        )
        pos = torch.arange(self.max_sequence_length, device=device)
        # causal[q, k] is True when key position k <= query position q.
        causal = pos[None, :] <= pos[:, None]
        self.cos_cached = cos
        self.sin_cached = sin
        self.causal_mask = causal[None, None, :, :]

    @classmethod
    def _from_original_hf_checkpoint(cls, checkpoint_path: str, subfolder: str | None) -> "FluxQwen3TorchEmbedder":
        """Build an embedder from an original (unfused) Qwen3 safetensors checkpoint.

        Only the token embedding and the first 27 decoder layers are kept; q/k/v and
        gate/up projection weights are concatenated into the fused parameters used here.

        Args:
            checkpoint_path: Hugging Face Hub repo id, or a local directory containing
                the checkpoint files.
            subfolder: optional subfolder inside ``checkpoint_path`` holding the config,
                ``model.safetensors.index.json`` and the shards.

        Raises:
            TypeError: if the checkpoint config is not a ``Qwen3Config``.
            KeyError: if an expected weight is missing from the checkpoint.
        """
        from huggingface_hub import hf_hub_download
        from safetensors import safe_open
        from transformers import AutoConfig

        # Deepest layer captured by the default output_hidden_state_indices.
        num_layers_kept = 27

        cfg = AutoConfig.from_pretrained(checkpoint_path, subfolder=subfolder)
        if not isinstance(cfg, Qwen3Config):
            raise TypeError(f"expected Qwen3Config, got {type(cfg)}")
        cfg.num_hidden_layers = num_layers_kept
        if getattr(cfg, "layer_types", None) is not None:
            cfg.layer_types = cfg.layer_types[:num_layers_kept]
        cfg.max_window_layers = num_layers_kept
        model = cls(cfg)

        def resolve(filename: str) -> str:
            # Prefer a local copy; otherwise treat checkpoint_path as a Hub repo id.
            local_file = Path(checkpoint_path, subfolder or "", filename)
            if local_file.is_file():
                return str(local_file)
            return hf_hub_download(checkpoint_path, filename=filename, subfolder=subfolder)

        # Only fetch the tensors this model uses; shards holding nothing but the dropped
        # layers / final norm / LM head are never downloaded or read.
        index = json.loads(Path(resolve("model.safetensors.index.json")).read_text())
        kept_prefixes = ("model.embed_tokens.",) + tuple(f"model.layers.{i}." for i in range(num_layers_kept))
        keys_by_shard: dict[str, list[str]] = {}
        for key, shard_name in index["weight_map"].items():
            if key.startswith(kept_prefixes):
                keys_by_shard.setdefault(shard_name, []).append(key)

        original_checkpoint: dict[str, torch.Tensor] = {}
        for shard_name, keys in keys_by_shard.items():
            with safe_open(resolve(shard_name), framework="pt") as shard:
                for key in keys:
                    original_checkpoint[key] = shard.get_tensor(key)

        # Copy weights from the original checkpoint into our model
        with torch.no_grad():
            model.embed_tokens.weight.copy_(original_checkpoint["model.embed_tokens.weight"])
            for layer_idx, layer in enumerate(model.layers):
                layer_base = f"model.layers.{layer_idx}."
                layer.input_layernorm_weight.copy_(original_checkpoint[layer_base + "input_layernorm.weight"])
                layer.post_attention_layernorm_weight.copy_(
                    original_checkpoint[layer_base + "post_attention_layernorm.weight"]
                )

                # Fused QKV: rows ordered [q | k | v] to match the offsets in FluxQwen3TorchLayer.
                q = original_checkpoint[layer_base + "self_attn.q_proj.weight"]
                k = original_checkpoint[layer_base + "self_attn.k_proj.weight"]
                v = original_checkpoint[layer_base + "self_attn.v_proj.weight"]
                layer.qkv_proj_weight.copy_(torch.cat((q, k, v), dim=0))
                layer.o_proj_weight.copy_(original_checkpoint[layer_base + "self_attn.o_proj.weight"])
                layer.q_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.q_norm.weight"])
                layer.k_norm_weight.copy_(original_checkpoint[layer_base + "self_attn.k_norm.weight"])

                # Fused MLP input projection: rows ordered [gate | up].
                gate = original_checkpoint[layer_base + "mlp.gate_proj.weight"]
                up = original_checkpoint[layer_base + "mlp.up_proj.weight"]
                layer.gate_up_proj_weight.copy_(torch.cat((gate, up), dim=0))
                layer.down_proj_weight.copy_(original_checkpoint[layer_base + "mlp.down_proj.weight"])

        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the captured hidden states.

        Args:
            input_ids: token ids of shape [batch, seq], seq <= ``max_sequence_length``.
            attention_mask: [batch, seq] mask, nonzero/True for real tokens. ``None``
                treats every position as a real token.

        Returns:
            Tensor of shape [batch, seq, len(output_hidden_state_indices) * hidden_size]:
            the output of each requested decoder layer, concatenated along the last dim
            in the order given by ``output_hidden_state_indices``.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or is longer than ``max_sequence_length``.
        """
        if input_ids.ndim != 2:
            raise ValueError(f"expected input_ids [batch, seq], got {tuple(input_ids.shape)}")
        batch, seq_len = input_ids.shape
        if seq_len > self.max_sequence_length:
            raise ValueError(f"sequence length {seq_len} exceeds max_sequence_length {self.max_sequence_length}")

        dtype = self.embed_tokens.weight.dtype
        device = input_ids.device
        self._maybe_refresh_caches(device=device, dtype=dtype)

        causal = self.causal_mask[:, :, :seq_len, :seq_len]
        if attention_mask is None:
            sdpa_mask = causal
        else:
            key_mask = attention_mask.to(device=device, dtype=torch.bool).reshape(batch, 1, 1, seq_len)
            sdpa_mask = causal & key_mask
            # A query row that can see no key at all (e.g. a pad token at the start of a
            # left-padded sequence) makes the softmax undefined; depending on the SDPA
            # backend this can yield NaN, which then leaks into real tokens through K/V in
            # later layers. Let such rows attend to themselves. Rows that already see at
            # least one key (the right-padded case) are unchanged.
            orphan_rows = ~sdpa_mask.any(dim=-1, keepdim=True)
            diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
            sdpa_mask = sdpa_mask | (orphan_rows & diagonal)

        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        hidden_states = self.embed_tokens(input_ids)
        prompt_embeds = torch.empty(
            batch,
            seq_len,
            len(self.output_hidden_state_indices) * self.hidden_size,
            device=device,
            dtype=dtype,
        )

        for layer_number, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, cos, sin, sdpa_mask)
            slot = self.capture_slot_by_layer.get(layer_number)
            if slot is not None:
                start = slot * self.hidden_size
                prompt_embeds[:, :, start : start + self.hidden_size].copy_(hidden_states)
            if layer_number >= self.last_capture_layer:
                break

        return prompt_embeds


class FluxQwen3TorchLayer(nn.Module):
    """One Qwen3 decoder layer with fused QKV and gate/up projection weights.

    Pre-norm self-attention (per-head RMSNorm on q/k, then RoPE) and a pre-norm
    SiLU-gated MLP, each wrapped in a residual connection. No biases, no KV cache.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        """Allocate (uninitialized) parameters for one layer.

        Raises:
            ValueError: if the config needs features this layer does not implement.
        """
        super().__init__()
        if config.num_key_value_heads is None:
            raise ValueError("num_key_value_heads must be specified in config for FluxQwen3TorchLayer")
        hidden_act = getattr(config, "hidden_act", "silu")
        if hidden_act != "silu":
            raise ValueError(f"FluxQwen3TorchLayer only implements hidden_act='silu', got {hidden_act!r}")
        if getattr(config, "attention_bias", False):
            raise ValueError("FluxQwen3TorchLayer does not support attention_bias=True")

        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = int(getattr(config, "head_dim", None) or self.hidden_size // self.num_attention_heads)
        self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6))
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # Column layout of the fused QKV projection output: [q | k | v].
        self.q_width = self.num_attention_heads * self.head_dim
        self.kv_width = self.num_key_value_heads * self.head_dim
        self.k_offset = self.q_width
        self.v_offset = self.q_width + self.kv_width

        self.input_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))
        self.post_attention_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size))
        self.qkv_proj_weight = nn.Parameter(torch.empty(self.q_width + 2 * self.kv_width, self.hidden_size))
        self.o_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.q_width))
        self.q_norm_weight = nn.Parameter(torch.empty(self.head_dim))
        self.k_norm_weight = nn.Parameter(torch.empty(self.head_dim))
        self.gate_up_proj_weight = nn.Parameter(torch.empty(self.intermediate_size * 2, self.hidden_size))
        self.down_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.intermediate_size))

    def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """RMSNorm over the last dim, computed in float32 and cast back before scaling."""
        dtype = x.dtype
        x_float = x.float()
        variance = x_float.pow(2).mean(dim=-1, keepdim=True)
        return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight

    def _head_rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """Per-head RMSNorm for q/k; identical math to ``_rms_norm`` over head_dim."""
        return self._rms_norm(x, weight)

    @staticmethod
    def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply rotate-half RoPE to x [batch, seq, heads, head_dim] with cos/sin [seq, head_dim]."""
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        rotated = torch.cat((-x2, x1), dim=-1)
        # cos/sin gain a heads axis so they broadcast over [batch, seq, heads, head_dim].
        return x * cos[:, None, :] + rotated * sin[:, None, :]

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run one decoder layer.

        Args:
            hidden_states: [batch, seq, hidden_size].
            cos, sin: RoPE tables of shape [seq, head_dim].
            attention_mask: boolean SDPA mask broadcastable to [batch, heads, seq, seq];
                True means "may attend".

        Returns:
            Updated hidden states, same shape as the input.
        """
        residual = hidden_states
        x = self._rms_norm(hidden_states, self.input_layernorm_weight)
        batch, seq_len, _ = x.shape

        qkv = F.linear(x, self.qkv_proj_weight)
        q_raw = qkv[:, :, : self.q_width].view(batch, seq_len, self.num_attention_heads, self.head_dim)
        k_raw = qkv[:, :, self.k_offset : self.v_offset].view(
            batch, seq_len, self.num_key_value_heads, self.head_dim
        )
        v = qkv[:, :, self.v_offset :].view(batch, seq_len, self.num_key_value_heads, self.head_dim)

        q = self._apply_rope(self._head_rms_norm(q_raw, self.q_norm_weight), cos, sin).transpose(1, 2)
        k = self._apply_rope(self._head_rms_norm(k_raw, self.k_norm_weight), cos, sin).transpose(1, 2)
        v = v.transpose(1, 2)

        # enable_gqa lets SDPA broadcast the smaller number of KV heads across query heads.
        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            scale=self.scale,
            is_causal=False,
            enable_gqa=True,
        )
        attn = attn.transpose(1, 2).contiguous().view(batch, seq_len, self.q_width)
        hidden_states = residual + F.linear(attn, self.o_proj_weight)

        residual = hidden_states
        x = self._rms_norm(hidden_states, self.post_attention_layernorm_weight)
        gate_up = F.linear(x, self.gate_up_proj_weight)
        gate, up = gate_up.split(self.intermediate_size, dim=-1)
        x = F.silu(gate) * up
        hidden_states = residual + F.linear(x, self.down_proj_weight)
        return hidden_states


def _config_rope_theta(config: Qwen3Config, default: float = 1000000.0) -> float:
    """Return the RoPE base frequency from ``config``.

    Reads ``config.rope_theta`` when present; otherwise looks for a ``rope_theta`` entry
    in a ``config.rope_parameters`` dict (the layout some newer transformers releases
    use); otherwise falls back to ``default``.
    """
    theta = getattr(config, "rope_theta", None)
    if theta is None:
        rope_parameters = getattr(config, "rope_parameters", None)
        if isinstance(rope_parameters, dict):
            theta = rope_parameters.get("rope_theta")
    return float(theta if theta is not None else default)


def _rotary_cache(
    seq_len: int,
    head_dim: int,
    rope_theta: float,
    *,
    device: torch.device | str,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute plain (unscaled) RoPE cos/sin tables of shape [seq_len, head_dim].

    Angles are computed in float32 and only the final tables are cast to ``dtype``.
    NOTE: any rope_scaling in the model config is not applied here.
    """
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    pos = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(pos, inv_freq)
    # Duplicate the frequencies to cover both halves used by the rotate-half formulation.
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype=dtype).contiguous(), emb.sin().to(dtype=dtype).contiguous()
```