Encoder / dd.py
lea97338's picture
Update dd.py
ea08658 verified
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Union, Optional
def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,  # last layer
):
    """Encode one or more prompts into hidden states of the text encoder.

    The prompts are tokenized directly (no chat template is applied),
    padded to the longest prompt in the batch and truncated to
    ``max_sequence_length`` tokens.

    Args:
        text_encoder: Model called with ``output_hidden_states=True``; its
            ``device`` attribute decides where the token tensors are moved.
        tokenizer: Callable tokenizer producing PyTorch tensors.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Truncation length passed to the tokenizer.
        hidden_layer: Index into ``outputs.hidden_states``; ``-1`` selects
            the last entry.

    Returns:
        The selected hidden-state tensor, returned as-is (no concatenation,
        no reshape). Expected layout is ``[B, L, D]``.
    """
    # Normalise a bare string into a one-element batch.
    if isinstance(prompt, str):
        prompt = [prompt]

    # Plain tokenization (no chat template).
    tokenizer_kwargs = {
        "return_tensors": "pt",
        "padding": True,
        "truncation": True,
        "max_length": max_sequence_length,
    }
    encoded = tokenizer(prompt, **tokenizer_kwargs)
    encoded = encoded.to(text_encoder.device)

    # Forward pass only: no autograd tracking and no KV cache.
    with torch.inference_mode():
        model_out = text_encoder(
            **encoded,
            output_hidden_states=True,
            use_cache=False,
        )

    # hidden_states[-1] is the last layer.
    return model_out.hidden_states[hidden_layer]
def prepare_text_ids(x: torch.Tensor):
    """Build 4-D position ids ``(t, h, w, l)`` for a batch of text tokens.

    Only the shape of ``x`` is used. For every token the first three
    coordinates are 0 and the last one is the token index ``0..L-1``.

    Args:
        x: A 3-D tensor whose first two dimensions are batch size ``B`` and
            sequence length ``L``.

    Returns:
        An integer (``torch.int64``) tensor of shape ``(B, L, 4)``, created
        on torch's default device regardless of where ``x`` lives. An empty
        batch (``B == 0``) yields an empty ``(0, L, 4)`` tensor.
    """
    batch_size, seq_len, _ = x.shape

    # The coordinate grid is the same for every batch element, so build it
    # once instead of once per sample.
    single = torch.arange(1)
    token_index = torch.arange(seq_len)
    coords = torch.cartesian_prod(single, single, single, token_index)  # (L, 4)

    # repeat() materialises an independent copy per batch element (matching
    # what stacking separate tensors produced) and, unlike torch.stack on an
    # empty list, also works when batch_size is 0.
    return coords.unsqueeze(0).repeat(batch_size, 1, 1)
def encode_prompt(
    text_encoder: Optional[AutoModelForCausalLM],
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    """Return prompt embeddings and matching text position ids.

    Args:
        text_encoder: Model used to compute the embeddings when
            ``prompt_embeds`` is not given. May be ``None`` when
            ``prompt_embeds`` is supplied.
        tokenizer: Tokenizer forwarded to ``get_qwen_prompt_embeds``.
        prompt: A single prompt or a list of prompts; only used when
            ``prompt_embeds`` is ``None``.
        num_images_per_prompt: How many times each prompt's embedding is
            duplicated along the batch dimension.
        prompt_embeds: Optional precomputed 3-D embeddings ``[B, L, D]``;
            when given, the text encoder is not run.
        max_sequence_length: Truncation length used when encoding ``prompt``.

    Returns:
        A tuple ``(prompt_embeds, text_ids)`` where ``prompt_embeds`` has
        shape ``[B * num_images_per_prompt, L, D]`` and ``text_ids`` is the
        result of ``prepare_text_ids`` for that tensor, moved to the text
        encoder's device (or to the embeddings' device when no encoder is
        given).
    """
    if prompt_embeds is None:
        prompt_embeds = get_qwen_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    batch_size, seq_len, dim = prompt_embeds.shape

    # Repeat for multiple images: tiling along the sequence axis and then
    # reshaping keeps the copies of each prompt adjacent in the batch.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        batch_size * num_images_per_prompt, seq_len, dim
    )

    text_ids = prepare_text_ids(prompt_embeds)

    # With precomputed embeddings the encoder is never run, so do not
    # require one just to read its device.
    if text_encoder is not None:
        target_device = text_encoder.device
    else:
        target_device = prompt_embeds.device
    text_ids = text_ids.to(target_device)

    return prompt_embeds, text_ids