| """Prompt-aware helper to encode text using a Qwen3 causal LM.""" | |
| from typing import List, Optional, Tuple | |
| import torch | |
| from transformers import PreTrainedTokenizerBase | |
| MAX_SEQUENCE_LENGTH = 1024 | |
| DROP_IDX = 38 | |
| SYSTEM_PROMPT = "Describe the image, focusing on its content, artistic style, composition, lighting, color, texture, and the spatial relationships between objects and the background:" | |
| PROMPT_TEMPLATE = ( | |
| "<|im_start|>system\n{system_prompt}<|im_end|>\n" | |
| "<|im_start|>user\n{user_prompt}<|im_end|>\n" | |
| "<|im_start|>assistant\n" | |
| ) | |
| def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor) -> List[torch.Tensor]: | |
| bool_mask = mask.bool() | |
| valid_lengths = bool_mask.sum(dim=1) | |
| selected = hidden_states[bool_mask] | |
| return list(torch.split(selected, valid_lengths.tolist(), dim=0)) | |
def _trim_sequence(sequence: torch.Tensor) -> torch.Tensor:
    """Drop the chat-template prefix tokens and cap the rest at MAX_SEQUENCE_LENGTH.

    Args:
        sequence: ``[seq, hidden]`` hidden states for one sample (no padding).

    Returns:
        ``sequence[DROP_IDX : DROP_IDX + MAX_SEQUENCE_LENGTH]``, or an empty
        ``[0, hidden]`` tensor when nothing remains past the prefix.
    """
    if sequence.size(0) > DROP_IDX:
        return sequence[DROP_IDX : DROP_IDX + MAX_SEQUENCE_LENGTH]
    # Sequence is entirely template prefix (or shorter): keep hidden dim, no rows.
    return sequence.new_zeros((0, sequence.size(1)))
def _build_prompt(text: str) -> str:
    """Wrap a raw caption in the Qwen chat template with the fixed system prompt."""
    fields = {
        "system_prompt": SYSTEM_PROMPT,
        "user_prompt": text,
    }
    return PROMPT_TEMPLATE.format(**fields)
def encode_text(
    texts: List[str],
    model: torch.nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    pooling: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Encode captions with the Qwen3 chat template for DiT conditioning.
    Returns:
        embeddings: [batch, seq, hidden]
        attention_mask: [batch, seq]
        pooled: [batch, hidden] when pooling is True
    """
    if not texts:
        raise ValueError("texts must contain at least one caption.")

    # Tokenize the full chat prompts; budget covers the template prefix plus
    # MAX_SEQUENCE_LENGTH caption tokens so trimming never loses caption budget.
    batch = tokenizer(
        [_build_prompt(caption) for caption in texts],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH + DROP_IDX,
    ).to(model.device)

    with torch.no_grad():
        # Run only the base transformer (model.model) — logits are not needed.
        hidden = model.model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            output_hidden_states=False,
        ).last_hidden_state

    # Per-sample valid tokens, then strip the template prefix from each.
    per_sample = [
        _trim_sequence(states)
        for states in _extract_masked_hidden(hidden, batch.attention_mask)
    ]
    # Re-pad to a common length; floor of 1 keeps the batch non-degenerate even
    # if every sequence trimmed to nothing.
    target_len = max([states.size(0) for states in per_sample] + [1])

    padded_rows: List[torch.Tensor] = []
    mask_rows: List[torch.Tensor] = []
    for states in per_sample:
        n_valid = states.size(0)
        row_mask = states.new_zeros(target_len, dtype=torch.long)
        row_mask[:n_valid] = 1
        mask_rows.append(row_mask)
        if n_valid < target_len:
            filler = states.new_zeros((target_len - n_valid, states.size(1)))
            states = torch.cat([states, filler], dim=0)
        padded_rows.append(states)

    embeddings = torch.stack(padded_rows).to(model.dtype)
    attention_mask = torch.stack(mask_rows).to(embeddings.device)

    if not pooling:
        return embeddings, attention_mask, None

    # Masked mean over valid positions; clamp guards all-padding rows.
    weight = attention_mask.unsqueeze(-1).to(embeddings.dtype)
    pooled = (embeddings * weight).sum(dim=1) / weight.sum(dim=1).clamp_min(1.0)
    return embeddings, attention_mask, pooled
| if __name__ == "__main__": | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| MODEL_ID = "Qwen/Qwen3-0.6B" | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, dtype=torch.bfloat16, device_map="cuda:0" | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| texts = [ | |
| "Impressionism landscape by Claude Monet", | |
| "romanticism marina by Van Gogh", | |
| ] * 2 | |
| embedding, mask, pooled = encode_text(texts, model, tokenizer, True) | |
| sample_prompt = _build_prompt(texts[0]) | |
| token_info = tokenizer( | |
| sample_prompt, | |
| return_tensors="pt", | |
| padding=False, | |
| truncation=False, | |
| add_special_tokens=False, | |
| ) | |
| ids = token_info.input_ids[0] | |
| tokens = tokenizer.convert_ids_to_tokens(ids) | |
| sentinel = "__DROP_BOUNDARY__" | |
| sentinel_prompt = _build_prompt(sentinel) | |
| sentinel_ids = tokenizer( | |
| sentinel_prompt, | |
| return_tensors="pt", | |
| padding=False, | |
| truncation=False, | |
| add_special_tokens=False, | |
| ).input_ids[0] | |
| sentinel_token_ids = tokenizer( | |
| sentinel, | |
| return_tensors="pt", | |
| padding=False, | |
| truncation=False, | |
| add_special_tokens=False, | |
| ).input_ids[0] | |
| detected_drop_idx = None | |
| for i in range(0, sentinel_ids.shape[0] - sentinel_token_ids.shape[0] + 1): | |
| if torch.equal(sentinel_ids[i : i + sentinel_token_ids.shape[0]], sentinel_token_ids): | |
| detected_drop_idx = i | |
| break | |
| print(f"Configured DROP_IDX={DROP_IDX}, detected drop boundary={detected_drop_idx}") | |
| if detected_drop_idx != DROP_IDX: | |
| print("WARNING: DROP_IDX does not match detected boundary index!") | |
| print(f"Embedding shape: {embedding.shape}") | |
| print(f"Mask shape: {mask.shape}") | |
| print(f"Pooled shape: {pooled.shape}") | |
| print("\nToken inspection (first prompt):") | |
| sample_embeddings = embedding[0] | |
| for idx, (tok_id, token) in enumerate(zip(ids.tolist(), tokens)): | |
| status = "keep" if idx >= DROP_IDX else "drop" | |
| if status == "keep": | |
| trimmed_idx = idx - DROP_IDX | |
| if trimmed_idx < sample_embeddings.size(0): | |
| emb_vec = sample_embeddings[trimmed_idx] | |
| emb_preview = ", ".join(f"{v:.4f}" for v in emb_vec[:4]) | |
| else: | |
| emb_preview = "<truncated>" | |
| else: | |
| emb_preview = "-" | |
| word = tokenizer.decode([tok_id]).strip() or token | |
| print( | |
| f"[{idx:03d}] id={tok_id:>6} token={token:<12} word={word:<12} status={status:>4} emb={emb_preview}" | |
| ) | |