File size: 2,128 Bytes
e4b7689 ea08658 e4b7689 ea08658 eb8610d e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 eb8610d e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 ea08658 e4b7689 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Union, Optional
def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,  # -1 selects the last entry of hidden_states
):
    """Encode one or more prompts and return a single layer of hidden states.

    The prompts are tokenized directly (no chat template is applied), padded
    to the longest prompt in the batch and truncated to
    ``max_sequence_length`` tokens. The tokenized inputs are moved to
    ``text_encoder.device`` before the forward pass.

    Args:
        text_encoder: Causal LM called with ``output_hidden_states=True``.
        tokenizer: Tokenizer paired with ``text_encoder``.
        prompt: A single prompt or a list of prompts. A bare string is
            treated as a batch of one.
        max_sequence_length: Maximum number of tokens kept per prompt.
        hidden_layer: Python index into ``outputs.hidden_states``. The
            default ``-1`` selects the last entry.

    Returns:
        The selected hidden-state tensor, shaped ``(B, L, D)``: B is the
        number of prompts, L is the padded token length (at most
        ``max_sequence_length``), D is the hidden size. It is returned as-is,
        with no concatenation across layers and no reshape.

    Note:
        No attention mask is returned, so positions that are only padding
        are included in the output.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt

    # Plain tokenization: no chat template is applied.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
    ).to(text_encoder.device)

    # no_grad rather than inference_mode. Tensors created under
    # inference_mode cannot be used in autograd-recorded computations after
    # leaving that mode, which would break callers that feed these
    # embeddings into a module being trained.
    with torch.no_grad():
        outputs = text_encoder(
            **inputs,
            output_hidden_states=True,
            use_cache=False,  # single forward pass; no KV cache needed
        )

    return outputs.hidden_states[hidden_layer]
def prepare_text_ids(x: torch.Tensor):
    """Build 4-component position ids for every token of a text batch.

    Only the shape of ``x`` is used. For an input of shape ``(B, L, D)``, the
    result has shape ``(B, L, 4)`` and dtype ``torch.long`` on the default
    device, with ``ids[b, j] == [0, 0, 0, j]`` for every batch index ``b``.
    The four components are the ``(t, h, w, l)`` coordinates; for text only
    the last one (the token index) varies. Presumably these ids feed a
    multi-axis rotary position embedding downstream -- confirm with the
    consumer.

    Args:
        x: A 3-D tensor, e.g. prompt embeddings ``(B, L, D)``.

    Returns:
        A freshly allocated ``(B, L, 4)`` int64 tensor. Each batch row has its
        own storage. ``B == 0`` yields an empty ``(0, L, 4)`` tensor.

    Raises:
        ValueError: If ``x`` is not 3-D (from unpacking its shape).
    """
    batch_size, seq_len, _ = x.shape

    # Every batch element gets identical ids, so fill them directly instead
    # of re-running cartesian_prod per element and stacking. This is also
    # well-defined for an empty batch, where torch.stack([]) would raise.
    ids = torch.zeros(batch_size, seq_len, 4, dtype=torch.long)
    ids[..., 3] = torch.arange(seq_len)  # t, h, w stay 0; l = token index
    return ids
def encode_prompt(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    """Return prompt embeddings and matching text position ids.

    If ``prompt_embeds`` is None, the prompts are encoded with
    ``get_qwen_prompt_embeds``. Otherwise the given embeddings are used and
    ``text_encoder``, ``tokenizer`` and ``prompt`` are not touched.

    Args:
        text_encoder: Causal LM used to encode ``prompt``. Only used when
            ``prompt_embeds`` is None.
        tokenizer: Tokenizer paired with ``text_encoder``. Only used when
            ``prompt_embeds`` is None.
        prompt: A prompt or list of prompts. Ignored when ``prompt_embeds``
            is given.
        num_images_per_prompt: Number of copies of each prompt's embeddings.
        prompt_embeds: Optional precomputed embeddings of shape ``(B, L, D)``.
        max_sequence_length: Token limit forwarded to the tokenizer.

    Returns:
        A tuple ``(prompt_embeds, text_ids)``:

        * ``prompt_embeds`` has shape ``(B * num_images_per_prompt, L, D)``.
          The copies of each prompt are adjacent, e.g. ``[p0, p0, p1, p1]``
          for two copies.
        * ``text_ids`` has shape ``(B * num_images_per_prompt, L, 4)`` and
          lives on the same device as ``prompt_embeds``.
    """
    if prompt_embeds is None:
        prompt_embeds = get_qwen_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    # Repeat each prompt for multiple images, keeping copies of the same
    # prompt adjacent. This is equivalent to repeat(1, n, 1) followed by
    # view(B * n, L, D).
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # Put the ids on the embeddings' own device rather than
    # text_encoder.device. Precomputed embeddings then work without a text
    # encoder, and the ids never end up on a different device than the
    # embeddings they describe.
    text_ids = prepare_text_ids(prompt_embeds).to(prompt_embeds.device)
    return prompt_embeds, text_ids
|