File size: 4,088 Bytes
e298226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 World Labs.
"""Qwen3-based text encoder for the FLUX-RGBD model family.

The encoder turns text prompts into a stack of hidden states drawn from
three intermediate layers of Qwen3-8B-FP8 (layers 9, 18, 27, mirroring
the FLUX.2 [klein] recipe). Output shape is

    (batch, MAX_LENGTH=512, 3 * hidden_size = 12288)

Pads to the full ``MAX_LENGTH`` (no stripping) because the FLUX.2 DiT was
trained against the padded text stream produced by Black Forest Labs.

Loads everything from HuggingFace Hub — no internal mirrors, no auth.
"""

from __future__ import annotations

import einops
import torch
from torch import Tensor, nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from flux_rgbd._flux2.constants import MAX_LENGTH, OUTPUT_LAYERS_QWEN3

# HuggingFace Hub checkpoint `Qwen3Embedder` loads when no `model_spec` is
# passed explicitly.
DEFAULT_MODEL_SPEC = "Qwen/Qwen3-8B-FP8"


class Qwen3Embedder(nn.Module):
    """Wrap a Qwen3 causal LM, surface multi-layer hidden states.

    Inference-only. Chat-template policy matches the upstream FLUX.2
    [klein] recipe: user role, generation-prompt appended, thinking
    disabled.

    Args:
        model_spec: HuggingFace Hub id (or local path) handed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        device: Device the model is loaded onto (forwarded as
            ``device_map``).
        max_length: Fixed token length every prompt is truncated to and
            right-padded to.
        output_layers: Indices into the model's ``hidden_states`` tuple
            whose activations are concatenated along the feature axis.
    """

    def __init__(
        self,
        model_spec: str = DEFAULT_MODEL_SPEC,
        *,
        device: str | torch.device = "cuda",
        max_length: int = MAX_LENGTH,
        output_layers: tuple[int, ...] = OUTPUT_LAYERS_QWEN3,
    ) -> None:
        super().__init__()
        self.model_spec = model_spec
        # Device requested at load time. Inputs are sent to the model's
        # *current* device (see `_tokenize`), so this value can go stale
        # after `.to(...)`; kept for callers that read it.
        self.device_ = torch.device(device)
        self.max_length = max_length
        self.output_layers = tuple(output_layers)

        self._tokenizer = AutoTokenizer.from_pretrained(model_spec)
        # Load straight onto the target device.
        # NOTE(review): depending on the transformers version, `dtype=None`
        # may load non-quantized tensors (embeddings, norms) in float32
        # rather than the checkpoint's advertised dtype -- following the
        # checkpoint is what `dtype="auto"` does. Left unchanged so encoder
        # numerics are not silently altered; confirm against the pinned
        # transformers version.
        self._model = AutoModelForCausalLM.from_pretrained(
            model_spec, dtype=None, device_map=str(self.device_)
        )
        self._model.eval()
        # Fall back to id 0 when the tokenizer defines no pad token.
        pad_id = self._tokenizer.pad_token_id
        self._pad_id = pad_id if pad_id is not None else 0

    @staticmethod
    def _apply_chat_template(tok, text: str) -> str:
        """Render ``text`` as a single user turn via ``tok``'s chat template.

        The generation prompt is appended and thinking is disabled; the
        rendered string is returned untokenized.
        """
        return tok.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

    def _tokenize(self, prompts: list[str]) -> tuple[Tensor, Tensor]:
        """Chat-format, tokenize and right-pad ``prompts`` to ``max_length``.

        Returns:
            ``(input_ids, attention_mask)``, both ``torch.long`` with shape
            ``(len(prompts), max_length)`` on the model's current device.
            Each row holds its real tokens first (mask 1) followed by the
            pad id (mask 0).
        """
        formatted = [self._apply_chat_template(self._tokenizer, p) for p in prompts]
        # Tokenize without padding: the tokenizer handles truncation, but we
        # pad ourselves to `max_length` so the output sequence length is
        # fixed (and right-padded) regardless of the prompt batch or the
        # tokenizer's configured padding side.
        encoded = self._tokenizer(
            formatted,
            padding=False,
            truncation=True,
            max_length=self.max_length,
        )
        rows = encoded["input_ids"]
        b = len(rows)
        input_ids = torch.full(
            (b, self.max_length), fill_value=self._pad_id, dtype=torch.long
        )
        attn = torch.zeros((b, self.max_length), dtype=torch.long)
        for i, row in enumerate(rows):
            # `truncation=True` should already bound the length; the clamp
            # is a cheap guard against a tokenizer ignoring `max_length`.
            n = min(len(row), self.max_length)
            input_ids[i, :n] = torch.tensor(row[:n], dtype=torch.long)
            attn[i, :n] = 1
        # Follow the model's current device rather than the load-time
        # `device_`: the model is a registered submodule, so `.to(...)` on
        # this embedder moves it, and inputs must follow.
        device = self._model.device
        return input_ids.to(device), attn.to(device)

    @torch.inference_mode()
    def forward(self, prompts: str | list[str]) -> Tensor:
        """Encode a batch of prompts.

        Args:
            prompts: A list (or other iterable) of prompt strings. A single
                bare ``str`` is treated as a batch of one instead of being
                iterated character by character.

        Returns:
            Tensor of shape ``(batch, max_length, len(output_layers) * hidden_size)``.
        """
        if isinstance(prompts, str):
            # `list("a cat")` would yield one prompt per character.
            prompts = [prompts]
        input_ids, attention_mask = self._tokenize(list(prompts))
        out = self._model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        # `hidden_states` is a tuple of length num_hidden_layers + 1.
        layers = torch.stack(
            [out.hidden_states[k] for k in self.output_layers], dim=1
        )  # (b, len(output_layers), seq, hidden)
        return einops.rearrange(layers, "b c l d -> b l (c d)")