from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,  # index into hidden_states; -1 is the last layer
) -> torch.Tensor:
    """Encode ``prompt`` and return one hidden-state layer of ``text_encoder``.

    The prompt is tokenized as plain text (no chat template), padded to the
    longest item in the batch and truncated to ``max_sequence_length``.

    Args:
        text_encoder: Model called with ``output_hidden_states=True``; its
            ``device`` attribute decides where the token tensors are placed.
        tokenizer: Tokenizer used to turn ``prompt`` into model inputs.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Upper bound on the tokenized length.
        hidden_layer: Index into ``outputs.hidden_states`` to return.

    Returns:
        The selected hidden states, expected to be ``[B, L, D]``. Padded
        positions are returned as-is; no attention-mask filtering is applied.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # Plain tokenization (no chat template).
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
    ).to(text_encoder.device)

    with torch.inference_mode():
        outputs = text_encoder(
            **inputs,
            output_hidden_states=True,
            use_cache=False,
        )

    # No concatenation of layers and no reshape: a single layer is returned.
    return outputs.hidden_states[hidden_layer]  # [B, L, D]


def prepare_text_ids(x: torch.Tensor) -> torch.Tensor:
    """Build ``(t, h, w, l)`` position ids for every token of ``x``.

    Args:
        x: Tensor of shape ``[B, L, D]``; only ``B`` and ``L`` are used.

    Returns:
        An int64 CPU tensor of shape ``[B, L, 4]``. The first three
        coordinates are always 0 and the last one is the token index
        ``0..L-1``. Every batch item gets the same ids.
    """
    batch_size, seq_len, _ = x.shape

    singleton = torch.arange(1)
    positions = torch.arange(seq_len)
    # The grid does not depend on the batch item, so build it once ([L, 4])
    # and replicate it; this also handles an empty batch without raising.
    coords = torch.cartesian_prod(singleton, singleton, singleton, positions)
    return coords.unsqueeze(0).repeat(batch_size, 1, 1)


def encode_prompt(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return prompt embeddings and matching text position ids.

    Args:
        text_encoder: Model used to compute embeddings when ``prompt_embeds``
            is not given. May be ``None`` when ``prompt_embeds`` is supplied.
        tokenizer: Tokenizer paired with ``text_encoder``.
        prompt: Prompt(s) to encode; ignored when ``prompt_embeds`` is given.
        num_images_per_prompt: How many times each prompt's embedding is
            duplicated along the batch dimension.
        prompt_embeds: Optional precomputed embeddings of shape ``[B, L, D]``.
        max_sequence_length: Upper bound on the tokenized prompt length.

    Returns:
        ``(prompt_embeds, text_ids)`` with shapes ``[B * n, L, D]`` and
        ``[B * n, L, 4]`` where ``n`` is ``num_images_per_prompt``. Copies of
        the same prompt are adjacent along the batch dimension.
    """
    if prompt_embeds is None:
        prompt_embeds = get_qwen_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    # Duplicate each prompt for the requested number of images; equivalent to
    # repeat(1, n, 1) followed by view(B * n, L, D).
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # Follow the encoder's device when there is one; otherwise (precomputed
    # embeddings without an encoder) keep the ids next to the embeddings.
    device = text_encoder.device if text_encoder is not None else prompt_embeds.device
    text_ids = prepare_text_ids(prompt_embeds).to(device)

    return prompt_embeds, text_ids