| This is an optimized version of the text encoder used by Flux 2 Klein 9B. It has the same weights and architecture (Qwen3), but the code is stripped down so that, under `torch.compile`, it runs about 1.3x faster and uses less peak VRAM (it should save a couple of GB). |
|
|
|
|
| ``` |
# Load the optimized embedder checkpoint from the Hugging Face Hub in bfloat16.
# Requires the FluxQwen3TorchEmbedder class defined in the code below and `import torch`.
qwen_model = FluxQwen3TorchEmbedder.from_pretrained("fancyfeast/flux2klein-optimized-text-embedder-9B", torch_dtype=torch.bfloat16)
| ``` |
|
|
|
|
| ``` |
| from __future__ import annotations |
| |
| import json |
| import math |
| from pathlib import Path |
| |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| from transformers import PreTrainedModel |
| |
| |
class FluxQwen3TorchEmbedder(PreTrainedModel):
    """Stripped down and optimized Qwen3 specifically for Flux 2 Klein models.

    In my testing this is about 1.3x faster than using the original HF implementation,
    and saves ~3GB of peak memory on the 8B model.

    Only the token embedding and the decoder layers are kept (there is no final norm and
    no LM head). ``forward`` returns the outputs of the decoder layers listed in
    ``output_hidden_state_indices`` (1-based), concatenated along the feature axis.

    The output_hidden_state_indices is 9, 18, 27 for both Klein 4B and Klein 9B.
    """

    config_class = Qwen3Config
    base_model_prefix = "flux_qwen3"

    def __init__(
        self,
        config: Qwen3Config,
        *,
        output_hidden_state_indices: tuple[int, ...] = (9, 18, 27),
        max_sequence_length: int = 512,
    ):
        """Build the embedding, the decoder stack, and (empty) lazily-filled caches.

        Args:
            config: Qwen3 config; ``config.num_hidden_layers`` decoder layers are built.
            output_hidden_state_indices: 1-based decoder layer numbers whose outputs are
                captured, in output order. Must be non-empty, unique, and within
                ``[1, config.num_hidden_layers]``.
            max_sequence_length: number of positions covered by the RoPE / causal-mask
                caches, and the maximum sequence length accepted by ``forward``.

        Raises:
            ValueError: if ``output_hidden_state_indices`` is empty, contains duplicates,
                or references a layer outside ``[1, config.num_hidden_layers]``.
        """
        super().__init__(config)

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        # Fall back to hidden_size // heads when the config has no (or a None) head_dim.
        self.head_dim = int(
            getattr(config, "head_dim", None) or self.hidden_size // self.num_attention_heads
        )
        self.rope_theta = float(getattr(config, "rope_theta", 1000000.0))

        self.output_hidden_state_indices = tuple(int(i) for i in output_hidden_state_indices)
        if not self.output_hidden_state_indices:
            raise ValueError("output_hidden_state_indices must not be empty")
        if min(self.output_hidden_state_indices) < 1:
            raise ValueError("output hidden state indices must be >= 1 for decoder layer outputs")
        if len(set(self.output_hidden_state_indices)) != len(self.output_hidden_state_indices):
            # Each layer maps to exactly one output slot; a repeated index would leave a
            # slot of the uninitialized output buffer in forward() never written.
            raise ValueError(
                f"output_hidden_state_indices must be unique, got {self.output_hidden_state_indices}"
            )

        max_layer_needed = max(self.output_hidden_state_indices)
        if max_layer_needed > int(config.num_hidden_layers):
            raise ValueError(f"requested hidden state after layer {max_layer_needed}, but config.num_hidden_layers={config.num_hidden_layers}")
        # forward() stops after this layer: deeper layers cannot affect the output.
        self.last_capture_layer = max_layer_needed

        # 1-based layer number -> position of its hidden state in the concatenated output.
        self.capture_slot_by_layer = {
            layer_idx: slot for slot, layer_idx in enumerate(self.output_hidden_state_indices)
        }
        self.max_sequence_length = int(max_sequence_length)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=getattr(config, "pad_token_id", None))
        self.layers = nn.ModuleList(
            FluxQwen3TorchLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )

        # Built lazily/refreshed in forward so dtype/device tracks the model.
        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)
        self.register_buffer("causal_mask", torch.empty(0, dtype=torch.bool), persistent=False)

        self.post_init()

    def _maybe_refresh_caches(self, *, device: torch.device, dtype: torch.dtype):
        """(Re)build the RoPE cos/sin tables and the causal mask when missing or stale.

        The caches cover ``max_sequence_length`` positions. They are rebuilt when empty,
        shorter than ``max_sequence_length``, or on a different device/dtype than requested.
        """
        need_refresh = (
            self.cos_cached.numel() == 0
            or self.cos_cached.shape[0] < self.max_sequence_length
            or self.cos_cached.device != device
            or self.cos_cached.dtype != dtype
        )
        if not need_refresh:
            return

        cos, sin = _rotary_cache(
            self.max_sequence_length,
            self.head_dim,
            self.rope_theta,
            device=device,
            dtype=dtype,
        )

        pos = torch.arange(self.max_sequence_length, device=device)
        # causal[i, j] is True when query i may attend to key j, i.e. when j <= i.
        causal = pos[None, :] <= pos[:, None]

        self.cos_cached = cos
        self.sin_cached = sin
        # Leading [1, 1] dims broadcast over batch and heads.
        self.causal_mask = causal[None, None, :, :]

    @classmethod
    def _from_original_hf_checkpoint(cls, checkpoint_path: str, subfolder: str | None) -> "FluxQwen3TorchEmbedder":
        """Build this model from an original, sharded-safetensors Qwen3 HF checkpoint.

        The config is truncated to the first 27 decoder layers (the deepest layer captured
        by the default ``output_hidden_state_indices``). The q/k/v and gate/up projections
        are fused into this module's parameter layout. Only the tensors that are copied
        (embedding + kept layers) are read, so shards holding only unused tensors (e.g. the
        LM head or dropped layers) are never downloaded or loaded.

        Args:
            checkpoint_path: repo id passed to ``AutoConfig.from_pretrained`` and
                ``hf_hub_download``.
            subfolder: optional subfolder in the repo holding the config, index and shards.

        Returns:
            A new model with default ``output_hidden_state_indices`` and copied weights.

        Raises:
            TypeError: if the checkpoint's config is not a ``Qwen3Config``.
            ValueError: if the checkpoint has fewer than 27 decoder layers.
            KeyError: if an expected tensor is missing from the checkpoint.
        """
        from huggingface_hub import hf_hub_download
        from safetensors import safe_open
        from transformers import AutoConfig

        keep_layers = 27

        cfg = AutoConfig.from_pretrained(checkpoint_path, subfolder=subfolder)
        if not isinstance(cfg, Qwen3Config):
            raise TypeError(f"expected Qwen3Config, got {type(cfg)}")
        if cfg.num_hidden_layers < keep_layers:
            raise ValueError(
                f"checkpoint has {cfg.num_hidden_layers} decoder layers, need at least {keep_layers}"
            )
        cfg.num_hidden_layers = keep_layers
        # NOTE(review): layer_types may be absent on older transformers versions.
        layer_types = getattr(cfg, "layer_types", None)
        if layer_types is not None:
            cfg.layer_types = layer_types[:keep_layers]
        cfg.max_window_layers = keep_layers
        model = cls(cfg)

        # Map every tensor name we will copy to the shard that stores it.
        index_path = hf_hub_download(checkpoint_path, filename="model.safetensors.index.json", subfolder=subfolder)
        weight_map: dict[str, str] = json.loads(Path(index_path).read_text(encoding="utf-8"))["weight_map"]
        # Trailing dots keep "model.layers.1." from matching "model.layers.10.".
        wanted_prefixes = ("model.embed_tokens.",) + tuple(
            f"model.layers.{i}." for i in range(keep_layers)
        )
        keys_by_shard: dict[str, list[str]] = {}
        for key, shard_name in weight_map.items():
            if key.startswith(wanted_prefixes):
                keys_by_shard.setdefault(shard_name, []).append(key)

        original_checkpoint: dict[str, torch.Tensor] = {}
        for shard_name, keys in keys_by_shard.items():
            path = hf_hub_download(checkpoint_path, filename=shard_name, subfolder=subfolder)
            with safe_open(path, framework="pt", device="cpu") as shard:
                for key in keys:
                    original_checkpoint[key] = shard.get_tensor(key)

        # pop() releases each source tensor right after it is copied, lowering peak memory.
        take = original_checkpoint.pop
        with torch.no_grad():
            model.embed_tokens.weight.copy_(take("model.embed_tokens.weight"))

            for layer_idx, layer in enumerate(model.layers):
                base = f"model.layers.{layer_idx}."

                layer.input_layernorm_weight.copy_(take(base + "input_layernorm.weight"))
                layer.post_attention_layernorm_weight.copy_(take(base + "post_attention_layernorm.weight"))
                # Fused row layout [q | k | v] matches FluxQwen3TorchLayer's split order.
                layer.qkv_proj_weight.copy_(
                    torch.cat(
                        (
                            take(base + "self_attn.q_proj.weight"),
                            take(base + "self_attn.k_proj.weight"),
                            take(base + "self_attn.v_proj.weight"),
                        ),
                        dim=0,
                    )
                )
                layer.o_proj_weight.copy_(take(base + "self_attn.o_proj.weight"))
                layer.q_norm_weight.copy_(take(base + "self_attn.q_norm.weight"))
                layer.k_norm_weight.copy_(take(base + "self_attn.k_norm.weight"))
                # Fused row layout [gate | up] matches the MLP split order.
                layer.gate_up_proj_weight.copy_(
                    torch.cat(
                        (take(base + "mlp.gate_proj.weight"), take(base + "mlp.up_proj.weight")),
                        dim=0,
                    )
                )
                layer.down_proj_weight.copy_(take(base + "mlp.down_proj.weight"))

        return model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decoder stack and return the captured hidden states.

        Args:
            input_ids: token ids of shape ``[batch, seq]`` with
                ``seq <= max_sequence_length``.
            attention_mask: key padding mask with ``batch * seq`` elements (normally
                ``[batch, seq]``); nonzero entries mark tokens that may be attended to.

        Returns:
            Tensor of shape ``[batch, seq, len(output_hidden_state_indices) * hidden_size]``
            in the embedding dtype. Feature slot ``k`` holds the output of decoder layer
            ``output_hidden_state_indices[k]``.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or ``seq`` exceeds
                ``max_sequence_length``.
        """
        if input_ids.ndim != 2:
            raise ValueError(f"expected input_ids [batch, seq], got {tuple(input_ids.shape)}")

        batch, seq_len = input_ids.shape

        # The caches cover max_sequence_length positions; any shorter length is a prefix.
        if seq_len > self.max_sequence_length:
            raise ValueError(f"sequence length {seq_len} exceeds cached max {self.max_sequence_length}")

        dtype = self.embed_tokens.weight.dtype
        device = input_ids.device
        self._maybe_refresh_caches(device=device, dtype=dtype)

        # [batch, 1, 1, seq]: broadcast the key padding mask over heads and query positions.
        key_mask = attention_mask.reshape(batch, 1, 1, seq_len).to(dtype=torch.bool)
        sdpa_mask = self.causal_mask[:, :, :seq_len, :seq_len] & key_mask
        # A query row with no attendable key (e.g. a left-padded position or an all-zero
        # mask row) can make SDPA's softmax produce NaN, which would then reach real
        # tokens through that position's values in later layers. Let such rows attend to
        # every key instead; rows with at least one valid key are unchanged.
        sdpa_mask = sdpa_mask | ~sdpa_mask.any(dim=-1, keepdim=True)

        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        hidden_states = self.embed_tokens(input_ids)

        # Uninitialized is safe: indices are unique and all <= last_capture_layer, so every
        # slot is written before the loop exits.
        prompt_embeds = torch.empty(
            batch,
            seq_len,
            len(self.output_hidden_state_indices) * self.hidden_size,
            device=device,
            dtype=dtype,
        )

        for layer_number, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, cos, sin, sdpa_mask)
            slot = self.capture_slot_by_layer.get(layer_number)
            if slot is not None:
                start = slot * self.hidden_size
                prompt_embeds[:, :, start : start + self.hidden_size].copy_(hidden_states)
            if layer_number >= self.last_capture_layer:
                # Deeper layers cannot change any captured slot.
                break

        return prompt_embeds
| |
| |
| class FluxQwen3TorchLayer(nn.Module): |
| def __init__(self, config: Qwen3Config, layer_idx: int): |
| super().__init__() |
| |
| self.layer_idx = layer_idx |
| self.hidden_size = config.hidden_size |
| self.intermediate_size = config.intermediate_size |
| self.num_attention_heads = config.num_attention_heads |
| assert config.num_key_value_heads is not None, "num_key_value_heads must be specified in config for FluxQwen3TorchLayer" |
| self.num_key_value_heads = config.num_key_value_heads |
| self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_attention_heads)) |
| self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) |
| self.scale = 1.0 / math.sqrt(self.head_dim) |
| |
| self.q_width = self.num_attention_heads * self.head_dim |
| self.kv_width = self.num_key_value_heads * self.head_dim |
| self.k_offset = self.q_width |
| self.v_offset = self.q_width + self.kv_width |
| |
| self.input_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size)) |
| self.post_attention_layernorm_weight = nn.Parameter(torch.empty(self.hidden_size)) |
| |
| self.qkv_proj_weight = nn.Parameter(torch.empty(self.q_width + 2 * self.kv_width, self.hidden_size)) |
| self.o_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.q_width)) |
| |
| self.q_norm_weight = nn.Parameter(torch.empty(self.head_dim)) |
| self.k_norm_weight = nn.Parameter(torch.empty(self.head_dim)) |
| |
| self.gate_up_proj_weight = nn.Parameter(torch.empty(self.intermediate_size * 2, self.hidden_size)) |
| self.down_proj_weight = nn.Parameter(torch.empty(self.hidden_size, self.intermediate_size)) |
| |
| assert self.q_width == self.o_proj_weight.shape[1] |
| assert self.o_proj_weight.shape == (self.hidden_size, self.q_width) |
| assert self.qkv_proj_weight.shape == (self.q_width + 2 * self.kv_width, self.hidden_size) |
| |
| def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: |
| dtype = x.dtype |
| x_float = x.float() |
| variance = x_float.pow(2).mean(dim=-1, keepdim=True) |
| return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight |
| |
| def _head_rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: |
| dtype = x.dtype |
| x_float = x.float() |
| variance = x_float.pow(2).mean(dim=-1, keepdim=True) |
| return (x_float * torch.rsqrt(variance + self.rms_norm_eps)).to(dtype) * weight |
| |
| @staticmethod |
| def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| half = x.shape[-1] // 2 |
| x1 = x[..., :half] |
| x2 = x[..., half:] |
| rotated = torch.cat((-x2, x1), dim=-1) |
| return x * cos[:, None, :] + rotated * sin[:, None, :] |
| |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cos: torch.Tensor, |
| sin: torch.Tensor, |
| attention_mask: torch.Tensor, |
| ) -> torch.Tensor: |
| residual = hidden_states |
| x = self._rms_norm(hidden_states, self.input_layernorm_weight) |
| |
| batch, seq_len, _ = x.shape |
| qkv = F.linear(x, self.qkv_proj_weight) |
| q_raw = qkv[:, :, : self.q_width].view(batch, seq_len, self.num_attention_heads, self.head_dim) |
| k_raw = qkv[:, :, self.k_offset : self.v_offset].view( |
| batch, seq_len, self.num_key_value_heads, self.head_dim |
| ) |
| v = qkv[:, :, self.v_offset :].view(batch, seq_len, self.num_key_value_heads, self.head_dim) |
| |
| q = self._apply_rope(self._head_rms_norm(q_raw, self.q_norm_weight), cos, sin).transpose(1, 2) |
| k = self._apply_rope(self._head_rms_norm(k_raw, self.k_norm_weight), cos, sin).transpose(1, 2) |
| v = v.transpose(1, 2) |
| |
| attn = F.scaled_dot_product_attention( |
| q, |
| k, |
| v, |
| attn_mask=attention_mask, |
| dropout_p=0.0, |
| scale=self.scale, |
| is_causal=False, |
| enable_gqa=True, |
| ) |
| attn = attn.transpose(1, 2).contiguous().view(batch, seq_len, self.q_width) |
| hidden_states = residual + F.linear(attn, self.o_proj_weight) |
| |
| residual = hidden_states |
| x = self._rms_norm(hidden_states, self.post_attention_layernorm_weight) |
| |
| gate_up = F.linear(x, self.gate_up_proj_weight) |
| gate, up = gate_up.split(self.intermediate_size, dim=-1) |
| x = F.silu(gate) * up |
| |
| hidden_states = residual + F.linear(x, self.down_proj_weight) |
| return hidden_states |
| |
| |
| def _rotary_cache( |
| seq_len: int, |
| head_dim: int, |
| rope_theta: float, |
| *, |
| device: torch.device | str, |
| dtype: torch.dtype, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| inv_freq = 1.0 / ( |
| rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim) |
| ) |
| pos = torch.arange(seq_len, dtype=torch.float32, device=device) |
| freqs = torch.outer(pos, inv_freq) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| return emb.cos().to(dtype=dtype).contiguous(), emb.sin().to(dtype=dtype).contiguous() |
| |
| ``` |
|
|
|
|
|
|