File size: 2,128 Bytes
e4b7689
ea08658
 
e4b7689
 
ea08658
eb8610d
 
e4b7689
 
ea08658
e4b7689
 
 
ea08658
 
 
e4b7689
ea08658
 
e4b7689
ea08658
e4b7689
 
ea08658
 
e4b7689
 
 
 
ea08658
 
e4b7689
ea08658
e4b7689
 
ea08658
e4b7689
 
 
 
ea08658
e4b7689
 
 
 
 
 
 
 
 
 
 
eb8610d
 
e4b7689
 
 
 
 
 
ea08658
e4b7689
 
 
 
 
 
ea08658
 
 
e4b7689
ea08658
e4b7689
 
 
ea08658
e4b7689
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Union, Optional


def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,  # last layer
):
    """Encode prompt text into hidden states taken from the text encoder.

    Args:
        text_encoder: Model called with the tokenized prompts; its ``device``
            attribute decides where the tokenizer output is moved.
        tokenizer: Callable that turns the prompts into model inputs.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Passed to the tokenizer as ``max_length``
            (truncation is enabled).
        hidden_layer: Index into ``outputs.hidden_states``; the default
            ``-1`` selects the last entry.

    Returns:
        The selected hidden-state tensor, returned as-is (no concatenation,
        no reshape). Expected to be ``[B, L, D]`` -- TODO confirm against the
        model in use.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    # Plain tokenization: the prompts are not wrapped in a chat template.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
    )
    encoded = encoded.to(text_encoder.device)

    # Forward pass only; inference_mode disables autograd tracking.
    with torch.inference_mode():
        model_out = text_encoder(
            **encoded,
            output_hidden_states=True,
            use_cache=False,
        )

    return model_out.hidden_states[hidden_layer]


def prepare_text_ids(x: torch.Tensor):
    """Build 4-column integer position ids for every token of a text batch.

    Each id row is a ``(t, h, w, l)`` coordinate: the first three columns are
    always 0 and the last column is the token index ``0 .. L-1``.

    Args:
        x: Tensor with exactly three dimensions; only the first two sizes
            (batch ``B`` and sequence length ``L``) are read.

    Returns:
        An integer tensor of shape ``[B, L, 4]``. No device is specified, so
        it is created on torch's default device. An empty batch (``B == 0``)
        yields an empty ``[0, L, 4]`` tensor instead of raising.
    """
    batch_size, seq_len, _ = x.shape

    # The coordinate grid is the same for every sample, so build it once
    # rather than once per batch element.
    singleton = torch.arange(1)
    coords = torch.cartesian_prod(
        singleton, singleton, singleton, torch.arange(seq_len)
    )

    # repeat() makes one copy of the grid per sample and, unlike torch.stack
    # on an empty list, also works when batch_size is 0.
    return coords.unsqueeze(0).repeat(batch_size, 1, 1)


def encode_prompt(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    """Return prompt embeddings and matching text position ids.

    Args:
        text_encoder: Model used to compute the embeddings. Only used when
            ``prompt_embeds`` is ``None``.
        tokenizer: Tokenizer used to compute the embeddings. Only used when
            ``prompt_embeds`` is ``None``.
        prompt: A single prompt or a list of prompts. Ignored when
            ``prompt_embeds`` is supplied.
        num_images_per_prompt: How many times each prompt's embedding is
            duplicated along the batch dimension.
        prompt_embeds: Optional precomputed ``[B, L, D]`` embeddings; when
            given, the text encoder is not called.
        max_sequence_length: Forwarded to ``get_qwen_prompt_embeds``.

    Returns:
        A tuple ``(prompt_embeds, text_ids)`` where ``prompt_embeds`` has
        shape ``[B * num_images_per_prompt, L, D]`` and ``text_ids`` is the
        result of ``prepare_text_ids`` for that tensor, placed on the same
        device as ``prompt_embeds``.
    """
    if prompt_embeds is None:
        prompt_embeds = get_qwen_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
        )

    batch_size, seq_len, dim = prompt_embeds.shape

    # Duplicate each prompt for every requested image: tiling along the
    # sequence axis and then re-viewing keeps the copies of one prompt
    # adjacent in the batch (p0, p0, ..., p1, p1, ...).
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(
        batch_size * num_images_per_prompt, seq_len, dim
    )

    # Keep the ids next to the embeddings they describe. Using the
    # embeddings' own device (rather than text_encoder.device) also works
    # when precomputed embeddings are passed and the encoder is None or
    # lives on another device.
    text_ids = prepare_text_ids(prompt_embeds).to(prompt_embeds.device)

    return prompt_embeds, text_ids