# File size: 5,908 Bytes
# eb52c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Prompt-aware helper to encode text using a Qwen3 causal LM."""

from typing import List, Optional, Tuple

import torch
from transformers import PreTrainedTokenizerBase


# Maximum number of positions kept per sample after the leading DROP_IDX
# positions are removed; the tokenizer truncates at MAX_SEQUENCE_LENGTH + DROP_IDX.
MAX_SEQUENCE_LENGTH = 1024
# Number of leading (non-padding) positions sliced off every encoded sequence.
# The __main__ self-check compares this value against the token index at which
# the user text starts inside the rendered template.
DROP_IDX = 38
# System instruction substituted into PROMPT_TEMPLATE for every caption.
SYSTEM_PROMPT = "Describe the image, focusing on its content, artistic style, composition, lighting, color, texture, and the spatial relationships between objects and the background:"
# Chat layout: a system turn, a user turn carrying the caption, and an opened
# (unterminated) assistant turn.
PROMPT_TEMPLATE = (
    "<|im_start|>system\n{system_prompt}<|im_end|>\n"
    "<|im_start|>user\n{user_prompt}<|im_end|>\n"
    "<|im_start|>assistant\n"
)


def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor) -> List[torch.Tensor]:
    """Split a padded batch into one tensor of unmasked rows per sample.

    Args:
        hidden_states: Batched activations indexed as [batch, seq, ...].
        mask: [batch, seq] tensor; non-zero entries mark positions to keep.

    Returns:
        A list with one tensor per batch row, holding that row's kept
        positions in their original order.
    """
    keep = mask.to(torch.bool)
    # Number of kept positions in each row, used as the split sizes below.
    rows_per_sample = keep.sum(dim=1).tolist()
    # Boolean indexing flattens all kept positions across the batch, row-major.
    flat_kept = hidden_states[keep]
    return list(flat_kept.split(rows_per_sample, dim=0))


def _trim_sequence(sequence: torch.Tensor) -> torch.Tensor:
    """Drop the first DROP_IDX rows and cap the remainder at MAX_SEQUENCE_LENGTH.

    Args:
        sequence: Per-sample tensor indexed as [seq, hidden].

    Returns:
        ``sequence[DROP_IDX : DROP_IDX + MAX_SEQUENCE_LENGTH]`` when the input
        is longer than DROP_IDX rows, otherwise a fresh empty tensor of shape
        ``(0, hidden)`` with the input's dtype and device.
    """
    if sequence.size(0) > DROP_IDX:
        return sequence[DROP_IDX : DROP_IDX + MAX_SEQUENCE_LENGTH]
    # Nothing survives the drop: return an empty tensor that keeps the width.
    hidden_dim = sequence.size(1)
    return sequence.new_zeros((0, hidden_dim))


def _build_prompt(text: str) -> str:
    """Render the chat-formatted prompt for a single caption.

    The caption fills the user turn of PROMPT_TEMPLATE and SYSTEM_PROMPT fills
    the system turn.
    """
    fields = {"system_prompt": SYSTEM_PROMPT, "user_prompt": text}
    return PROMPT_TEMPLATE.format(**fields)


def encode_text(
    texts: List[str],
    model: torch.nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    pooling: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Encode captions with the Qwen3 chat template for DiT conditioning.

    Each caption is rendered with the chat template, tokenized as a padded
    batch and run through ``model.model`` without gradients. For every sample
    the padding positions are removed, the first ``DROP_IDX`` positions are
    dropped and at most ``MAX_SEQUENCE_LENGTH`` positions are kept. The kept
    sequences are then zero-padded to the longest one in the batch.

    Args:
        texts: Captions to encode. A bare string is treated as one caption.
        model: Module exposing ``.model`` (called with ``input_ids`` and
            ``attention_mask``), ``.device`` and ``.dtype``.
        tokenizer: Tokenizer used to batch-encode the rendered prompts.
        pooling: When True, also return the masked mean over kept positions.

    Returns:
        embeddings: [batch, seq, hidden], cast to ``model.dtype``
        attention_mask: [batch, seq], 1 for kept positions and 0 for padding
        pooled: [batch, hidden] when pooling is True, otherwise None

    Raises:
        ValueError: If ``texts`` is empty.
    """

    if not texts:
        raise ValueError("texts must contain at least one caption.")
    if isinstance(texts, str):
        # A bare string would otherwise be iterated character by character,
        # silently producing one prompt per character.
        texts = [texts]

    prompts = [_build_prompt(text) for text in texts]

    # Truncate at MAX_SEQUENCE_LENGTH + DROP_IDX so that up to
    # MAX_SEQUENCE_LENGTH positions remain once the prefix is dropped.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH + DROP_IDX,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            output_hidden_states=False,
        )
        hidden = outputs.last_hidden_state

    # Strip padding per sample, then drop the prefix and cap the length.
    sequences = _extract_masked_hidden(hidden, inputs.attention_mask)
    trimmed = [_trim_sequence(seq) for seq in sequences]

    # Keep at least one position so the stacked outputs never have a
    # zero-length sequence axis, even if every sample trimmed to nothing.
    max_seq_len = max((seq.size(0) for seq in trimmed), default=0)
    if max_seq_len == 0:
        max_seq_len = 1

    batch_embeddings = []
    batch_masks = []
    for seq in trimmed:
        seq_len = seq.size(0)
        pad_len = max_seq_len - seq_len
        if pad_len > 0:
            # Right-pad with zeros up to the batch-wide kept length.
            pad = seq.new_zeros((pad_len, seq.size(1)))
            seq_padded = torch.cat([seq, pad], dim=0)
        else:
            seq_padded = seq
        batch_embeddings.append(seq_padded)

        mask = seq.new_zeros(max_seq_len, dtype=torch.long)
        mask[:seq_len] = 1
        batch_masks.append(mask)

    embeddings = torch.stack(batch_embeddings).to(model.dtype)
    attention_mask = torch.stack(batch_masks).to(embeddings.device)

    pooled = None
    if pooling:
        # Pool in at least float32: summing up to MAX_SEQUENCE_LENGTH
        # half-precision vectors can overflow float16, and bfloat16 cannot
        # represent every token count exactly. Cast back afterwards so the
        # returned dtype is unchanged.
        acc_dtype = torch.promote_types(embeddings.dtype, torch.float32)
        weight = attention_mask.unsqueeze(-1).to(acc_dtype)
        denom = weight.sum(dim=1).clamp_min(1.0)
        summed = (embeddings.to(acc_dtype) * weight).sum(dim=1)
        pooled = (summed / denom).to(embeddings.dtype)

    return embeddings, attention_mask, pooled


if __name__ == "__main__":
    # Manual smoke test: encode a few captions, check DROP_IDX against the
    # position where the user text starts in the tokenized template, and print
    # a per-token keep/drop table for the first prompt.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_ID = "Qwen/Qwen3-0.6B"

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16, device_map="cuda:0"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    captions = 2 * [
        "Impressionism landscape by Claude Monet",
        "romanticism marina by Van Gogh",
    ]

    embedding, mask, pooled = encode_text(captions, model, tokenizer, True)

    # Options shared by every inspection call: plain ids for a single string,
    # with no padding, truncation or extra special tokens.
    raw_kwargs = dict(
        return_tensors="pt",
        padding=False,
        truncation=False,
        add_special_tokens=False,
    )

    ids = tokenizer(_build_prompt(captions[0]), **raw_kwargs).input_ids[0]
    tokens = tokenizer.convert_ids_to_tokens(ids)

    # Locate the user text inside the rendered template by searching for the
    # token ids of a sentinel caption.
    sentinel = "__DROP_BOUNDARY__"
    prompt_ids = tokenizer(_build_prompt(sentinel), **raw_kwargs).input_ids[0]
    needle_ids = tokenizer(sentinel, **raw_kwargs).input_ids[0]
    needle_len = needle_ids.shape[0]

    # First offset at which the sentinel ids occur, or None if they never do.
    detected_drop_idx = next(
        (
            start
            for start in range(prompt_ids.shape[0] - needle_len + 1)
            if torch.equal(prompt_ids[start : start + needle_len], needle_ids)
        ),
        None,
    )

    print(f"Configured DROP_IDX={DROP_IDX}, detected drop boundary={detected_drop_idx}")
    if detected_drop_idx != DROP_IDX:
        print("WARNING: DROP_IDX does not match detected boundary index!")

    print(f"Embedding shape: {embedding.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Pooled shape: {pooled.shape}")
    print("\nToken inspection (first prompt):")

    first_embeddings = embedding[0]
    kept_rows = first_embeddings.size(0)
    for idx, (tok_id, token) in enumerate(zip(ids.tolist(), tokens)):
        # Position of this token in the trimmed output; negative means dropped.
        offset = idx - DROP_IDX
        if offset < 0:
            status, emb_preview = "drop", "-"
        elif offset < kept_rows:
            status = "keep"
            emb_preview = ", ".join(f"{v:.4f}" for v in first_embeddings[offset][:4])
        else:
            status, emb_preview = "keep", "<truncated>"

        word = tokenizer.decode([tok_id]).strip() or token
        print(
            f"[{idx:03d}] id={tok_id:>6} token={token:<12} word={word:<12} status={status:>4} emb={emb_preview}"
        )