Spaces:
Running on Zero
Running on Zero
| # SPDX-License-Identifier: Apache-2.0 | |
| # Copyright (c) 2026 World Labs. | |
| """Qwen3-based text encoder for the FLUX-RGBD model family. | |
| The encoder turns text prompts into a stack of hidden states drawn from | |
| three intermediate layers of Qwen3-8B-FP8 (layers 9, 18, 27, mirroring | |
| the FLUX.2 [klein] recipe). Output shape is | |
| (batch, MAX_LENGTH=512, 3 * hidden_size = 12288) | |
| Pads to the full ``MAX_LENGTH`` (no stripping) because the FLUX.2 DiT was | |
| trained against the padded text stream produced by Black Forest Labs. | |
| Loads everything from HuggingFace Hub — no internal mirrors, no auth. | |
| """ | |
| from __future__ import annotations | |
| import einops | |
| import torch | |
| from torch import Tensor, nn | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from flux_rgbd._flux2.constants import MAX_LENGTH, OUTPUT_LAYERS_QWEN3 | |
| DEFAULT_MODEL_SPEC = "Qwen/Qwen3-8B-FP8" | |
class Qwen3Embedder(nn.Module):
    """Wrap a Qwen3 causal LM and surface multi-layer hidden states.

    Inference-only: ``forward`` runs with autograd disabled. Chat-template
    policy matches the upstream FLUX.2 [klein] recipe: user role,
    generation-prompt appended, thinking disabled.

    Args:
        model_spec: HuggingFace model id (or local path) passed to both
            ``AutoTokenizer`` and ``AutoModelForCausalLM``.
        device: Device the model is loaded onto and the tokenized inputs
            are moved to.
        max_length: Fixed sequence length. Prompts are truncated to it and
            right-padded up to it, so every output has this length.
        output_layers: Indices into the model's ``hidden_states`` tuple
            whose activations are concatenated along the feature axis.
    """

    def __init__(
        self,
        model_spec: str = DEFAULT_MODEL_SPEC,
        *,
        device: str | torch.device = "cuda",
        max_length: int = MAX_LENGTH,
        output_layers: tuple[int, ...] = OUTPUT_LAYERS_QWEN3,
    ) -> None:
        super().__init__()
        self.model_spec = model_spec
        self.device_ = torch.device(device)
        self.max_length = max_length
        self.output_layers = tuple(output_layers)
        self._tokenizer = AutoTokenizer.from_pretrained(model_spec)
        # Load straight onto the target device. `dtype=None` lets
        # transformers honor whatever the checkpoint advertises (FP8 for
        # the Qwen3-8B-FP8 we use by default).
        self._model = AutoModelForCausalLM.from_pretrained(
            model_spec, dtype=None, device_map=str(self.device_)
        )
        self._model.eval()
        # Fall back to token id 0 when the tokenizer defines no pad token.
        self._pad_id = self._tokenizer.pad_token_id or 0

    @staticmethod
    def _apply_chat_template(tok, text: str) -> str:
        """Render ``text`` as a single user turn using ``tok``'s chat template.

        Declared static because it takes the tokenizer explicitly and is
        invoked as ``self._apply_chat_template(tokenizer, prompt)``; without
        the decorator that call would bind ``self`` as an extra argument.

        Returns:
            The formatted (untokenized) prompt string with the generation
            prompt appended and thinking disabled.
        """
        return tok.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

    def _tokenize(self, prompts: list[str]) -> tuple[Tensor, Tensor]:
        """Tokenize ``prompts`` into fixed-length id and mask tensors.

        Returns:
            ``(input_ids, attention_mask)``, both ``torch.long`` tensors of
            shape ``(len(prompts), max_length)`` on ``self.device_``. Real
            tokens occupy the left of each row with mask 1; the remainder is
            ``self._pad_id`` with mask 0.
        """
        formatted = [self._apply_chat_template(self._tokenizer, p) for p in prompts]
        # First tokenize without padding to control truncation; we'll pad
        # ourselves to `max_length` so the output sequence length is fixed
        # regardless of the prompt batch.
        encoded = self._tokenizer(
            formatted,
            padding=False,
            truncation=True,
            max_length=self.max_length,
        )
        rows = encoded["input_ids"]
        b = len(rows)
        input_ids = torch.full(
            (b, self.max_length), fill_value=self._pad_id, dtype=torch.long
        )
        attn = torch.zeros((b, self.max_length), dtype=torch.long)
        for i, row in enumerate(rows):
            # Clamp defensively even though truncation was requested above.
            n = min(len(row), self.max_length)
            input_ids[i, :n] = torch.tensor(row[:n], dtype=torch.long)
            attn[i, :n] = 1
        return input_ids.to(self.device_), attn.to(self.device_)

    def forward(self, prompts: list[str]) -> Tensor:
        """Encode a batch of prompts.

        Args:
            prompts: Prompt strings. A single bare ``str`` is treated as a
                one-element batch rather than being split into characters.

        Returns:
            Tensor of shape ``(batch, max_length, len(output_layers) * hidden_size)``.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        input_ids, attention_mask = self._tokenize(list(prompts))
        # The encoder is inference-only: skip autograd bookkeeping so the
        # per-layer hidden states are not kept alive by a backward graph.
        with torch.no_grad():
            out = self._model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
            )
            # `hidden_states` is a tuple of length num_hidden_layers + 1.
            layers = torch.stack(
                [out.hidden_states[k] for k in self.output_layers], dim=1
            )  # (b, len(output_layers), seq, hidden)
            # Fold the selected layers into the feature axis.
            return einops.rearrange(layers, "b c l d -> b l (c d)")