"""Prompt-aware helper to encode text using a Qwen3 causal LM.""" from typing import List, Optional, Tuple import torch from transformers import PreTrainedTokenizerBase MAX_SEQUENCE_LENGTH = 1024 DROP_IDX = 38 SYSTEM_PROMPT = "Describe the image, focusing on its content, artistic style, composition, lighting, color, texture, and the spatial relationships between objects and the background:" PROMPT_TEMPLATE = ( "<|im_start|>system\n{system_prompt}<|im_end|>\n" "<|im_start|>user\n{user_prompt}<|im_end|>\n" "<|im_start|>assistant\n" ) def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor) -> List[torch.Tensor]: bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] return list(torch.split(selected, valid_lengths.tolist(), dim=0)) def _trim_sequence(sequence: torch.Tensor) -> torch.Tensor: if sequence.size(0) <= DROP_IDX: return sequence.new_zeros((0, sequence.size(1))) end = DROP_IDX + MAX_SEQUENCE_LENGTH return sequence[DROP_IDX:end] def _build_prompt(text: str) -> str: return PROMPT_TEMPLATE.format(system_prompt=SYSTEM_PROMPT, user_prompt=text) def encode_text( texts: List[str], model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, pooling: bool, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Encode captions with the Qwen3 chat template for DiT conditioning. Returns: embeddings: [batch, seq, hidden] attention_mask: [batch, seq] pooled: [batch, hidden] when pooling is True """ if not texts: raise ValueError("texts must contain at least one caption.") prompts = [_build_prompt(text) for text in texts] inputs = tokenizer( prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQUENCE_LENGTH + DROP_IDX, ).to(model.device) with torch.no_grad(): outputs = model.model( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, output_hidden_states=False, ) hidden = outputs.last_hidden_state sequences = _extract_masked_hidden(hidden, inputs.attention_mask) trimmed = [_trim_sequence(seq) for seq in sequences] max_seq_len = max((seq.size(0) for seq in trimmed), default=0) if max_seq_len == 0: max_seq_len = 1 batch_embeddings = [] batch_masks = [] for seq in trimmed: seq_len = seq.size(0) pad_len = max_seq_len - seq_len if pad_len > 0: pad = seq.new_zeros((pad_len, seq.size(1))) seq_padded = torch.cat([seq, pad], dim=0) else: seq_padded = seq batch_embeddings.append(seq_padded) mask = seq.new_zeros(max_seq_len, dtype=torch.long) mask[:seq_len] = 1 batch_masks.append(mask) embeddings = torch.stack(batch_embeddings).to(model.dtype) attention_mask = torch.stack(batch_masks).to(embeddings.device) pooled = None if pooling: weight = attention_mask.unsqueeze(-1).to(embeddings.dtype) denom = weight.sum(dim=1).clamp_min(1.0) pooled = (embeddings * weight).sum(dim=1) / denom return embeddings, attention_mask, pooled if __name__ == "__main__": from transformers import AutoModelForCausalLM, AutoTokenizer MODEL_ID = "Qwen/Qwen3-0.6B" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16, device_map="cuda:0" ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) texts = [ "Impressionism landscape by Claude Monet", "romanticism marina by Van Gogh", ] * 2 embedding, mask, pooled = encode_text(texts, model, tokenizer, True) sample_prompt = _build_prompt(texts[0]) token_info = tokenizer( sample_prompt, return_tensors="pt", padding=False, truncation=False, add_special_tokens=False, ) ids = token_info.input_ids[0] tokens = tokenizer.convert_ids_to_tokens(ids) sentinel = "__DROP_BOUNDARY__" sentinel_prompt = _build_prompt(sentinel) sentinel_ids = tokenizer( sentinel_prompt, return_tensors="pt", padding=False, truncation=False, add_special_tokens=False, ).input_ids[0] sentinel_token_ids = tokenizer( sentinel, return_tensors="pt", padding=False, truncation=False, add_special_tokens=False, ).input_ids[0] detected_drop_idx = None for i in range(0, sentinel_ids.shape[0] - sentinel_token_ids.shape[0] + 1): if torch.equal(sentinel_ids[i : i + sentinel_token_ids.shape[0]], sentinel_token_ids): detected_drop_idx = i break print(f"Configured DROP_IDX={DROP_IDX}, detected drop boundary={detected_drop_idx}") if detected_drop_idx != DROP_IDX: print("WARNING: DROP_IDX does not match detected boundary index!") print(f"Embedding shape: {embedding.shape}") print(f"Mask shape: {mask.shape}") print(f"Pooled shape: {pooled.shape}") print("\nToken inspection (first prompt):") sample_embeddings = embedding[0] for idx, (tok_id, token) in enumerate(zip(ids.tolist(), tokens)): status = "keep" if idx >= DROP_IDX else "drop" if status == "keep": trimmed_idx = idx - DROP_IDX if trimmed_idx < sample_embeddings.size(0): emb_vec = sample_embeddings[trimmed_idx] emb_preview = ", ".join(f"{v:.4f}" for v in emb_vec[:4]) else: emb_preview = "" else: emb_preview = "-" word = tokenizer.decode([tok_id]).strip() or token print( f"[{idx:03d}] id={tok_id:>6} token={token:<12} word={word:<12} status={status:>4} emb={emb_preview}" )