import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection

@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor

class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.
    
    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes "78 vs 77" error).
    3. Pads prompts that are too short (fixes "93 vs 77" error).
    4. Returns all 4 required embedding tensors.
    """
    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)
        
        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled
        )

    def _encode_one(self, prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        
        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt"
        )

        # --- Text Encoder 1 (CLIP-L) ---
        # SDXL uses the penultimate hidden state (clip-skip), not
        # last_hidden_state, which has already passed through the final
        # layer norm. Pooled output from this encoder is not used.
        output_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device),
            output_hidden_states=True
        )
        embeds_1 = output_1.hidden_states[-2]

        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Gets the penultimate hidden state and the projected pooled output.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True
        )
        embeds_2 = output_2.hidden_states[-2]
        # CLIPTextModelWithProjection exposes its pooled, projected embedding
        # as `text_embeds`; it has no `pooler_output` attribute.
        pooled_embeds = output_2.text_embeds

        # --- Concatenate ---
        # The final embeddings are a concatenation of both.
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        return prompt_embeds, pooled_embeds
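
# --- Usage sketch (illustrative only, not executed on import) ---
# Assumes an already-loaded SDXL pipeline; the checkpoint name and the
# variable names `pipe`, `cappella`, and `result` are placeholders.
# The pipeline keyword arguments below are the standard
# StableDiffusionXLPipeline embedding parameters.
#
# from diffusers import StableDiffusionXLPipeline
#
# pipe = StableDiffusionXLPipeline.from_pretrained(
#     "stabilityai/stable-diffusion-xl-base-1.0"
# ).to("cuda")
# cappella = Cappella(pipe, device="cuda")
# result = cappella(
#     "a photo of an astronaut riding a horse",
#     negative_prompt="blurry, low quality"
# )
# image = pipe(
#     prompt_embeds=result.embeds,
#     pooled_prompt_embeds=result.pooled_embeds,
#     negative_prompt_embeds=result.negative_embeds,
#     negative_pooled_prompt_embeds=result.negative_pooled_embeds,
# ).images[0]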