| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from typing import List, Union, Optional |
|
|
|
|
def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,
):
    """Encode ``prompt`` with ``text_encoder`` and return one hidden-state layer.

    Args:
        text_encoder: causal LM used purely as an encoder. Its ``device`` attribute
            decides where the tokenized inputs are moved. It is called with
            ``output_hidden_states=True`` and ``use_cache=False``.
        tokenizer: tokenizer that turns the prompts into PyTorch tensors
            (``return_tensors="pt"``).
        prompt: a single prompt or a list of prompts. A bare string is wrapped
            into a one-element list, so the output always has a batch dimension.
        max_sequence_length: token limit handed to the tokenizer; longer prompts
            are truncated.
        hidden_layer: index into ``outputs.hidden_states``. The default ``-1``
            selects the last entry.

    Returns:
        The selected entry of ``outputs.hidden_states``. Prompts are padded to
        the longest prompt in the batch (``padding=True``), and no attention
        mask is returned. NOTE(review): padded positions are therefore part of
        the output; confirm that downstream consumers tolerate this.
    """
    prompts = prompt if not isinstance(prompt, str) else [prompt]

    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
    )
    encoded = encoded.to(text_encoder.device)

    # Pure forward pass: inference_mode skips autograd bookkeeping entirely.
    with torch.inference_mode():
        all_hidden_states = text_encoder(
            **encoded,
            use_cache=False,
            output_hidden_states=True,
        ).hidden_states

    return all_hidden_states[hidden_layer]
|
|
|
|
def prepare_text_ids(x: torch.Tensor) -> torch.Tensor:
    """Build per-token position ids for a batch of text embeddings.

    Only ``x.shape`` is read; the values, dtype and device of ``x`` are ignored.

    Args:
        x: 3-D tensor whose first two dimensions are taken as
            (batch size ``B``, sequence length ``L``). Any other rank raises
            ``ValueError`` from the shape unpacking.

    Returns:
        An int64 CPU tensor of shape ``(B, L, 4)``. Every row is
        ``[0, 0, 0, token_index]``, where ``token_index`` runs over ``0..L-1``.
        The first three columns are fixed at 0; presumably they are the t/h/w
        axes shared with image position ids (TODO confirm). All batch entries
        are identical. ``B == 0`` yields an empty ``(0, L, 4)`` tensor.
    """
    batch_size, seq_len, _ = x.shape

    # The ids do not depend on the batch index, so build the (L, 4) table once
    # and tile it. Building it per sample and calling torch.stack costs B Python
    # iterations, and torch.stack raises on an empty list when B == 0.
    ids = torch.zeros(seq_len, 4, dtype=torch.long)
    ids[:, 3] = torch.arange(seq_len)

    # Use repeat rather than expand so the result owns contiguous memory,
    # which callers may rely on (e.g. for in-place edits or .view).
    return ids.unsqueeze(0).repeat(batch_size, 1, 1)
|
|
|
|
def encode_prompt(
    text_encoder: Optional[AutoModelForCausalLM],
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
    hidden_layer: int = -1,
):
    """Return prompt embeddings and their position ids, duplicated per image.

    Args:
        text_encoder: model passed to ``get_qwen_prompt_embeds``. It is only used
            when ``prompt_embeds`` is ``None``, so it may be ``None`` when
            precomputed embeddings are supplied.
        tokenizer: tokenizer passed to ``get_qwen_prompt_embeds``. It is only
            used when ``prompt_embeds`` is ``None``.
        prompt: prompt or list of prompts. It is ignored when ``prompt_embeds``
            is given.
        num_images_per_prompt: how many times each prompt's embeddings are
            repeated along dim 0.
        prompt_embeds: precomputed 3-D embeddings. When given, the text encoder
            is not run.
        max_sequence_length: token limit forwarded to the encoder helper.
        hidden_layer: hidden-state index forwarded to the encoder helper. The
            default ``-1`` matches the previous behaviour.

    Returns:
        ``(prompt_embeds, text_ids)``. ``prompt_embeds`` has
        ``B * num_images_per_prompt`` rows, and the copies of each prompt are
        adjacent: ``[p0, p0, ..., p1, p1, ...]``. ``text_ids`` comes from
        ``prepare_text_ids`` and lives on the same device as ``prompt_embeds``.

    Raises:
        ValueError: if both ``prompt`` and ``prompt_embeds`` are ``None``, or
            if ``prompt_embeds`` is not 3-D (raised by ``prepare_text_ids``).
    """
    if prompt_embeds is None:
        if prompt is None:
            raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")
        prompt_embeds = get_qwen_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
            hidden_layer=hidden_layer,
        )

    # Same result as repeat(1, n, 1).view(B * n, L, D): each prompt's copies are
    # placed next to each other along the batch dimension.
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # Place the ids next to the embeddings they describe. The previous code used
    # text_encoder.device, which crashes when text_encoder is None (embeddings
    # supplied by the caller) and may differ from where the embeddings actually
    # live.
    text_ids = prepare_text_ids(prompt_embeds).to(prompt_embeds.device)

    return prompt_embeds, text_ids
|
|