File size: 3,319 Bytes
c951070
 
 
44632cd
c951070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44632cd
c951070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44632cd
c951070
44632cd
 
c951070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import Tuple

@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    # Per-token embeddings for the positive prompt (concatenation of both
    # text encoders' outputs along the feature dim).
    embeds: torch.Tensor
    # Pooled embedding for the positive prompt (produced by text encoder 2).
    pooled_embeds: torch.Tensor
    # Per-token embeddings for the negative prompt; same sequence length
    # as `embeds` so the pipeline can batch them together.
    negative_embeds: torch.Tensor
    # Pooled embedding for the negative prompt.
    negative_pooled_embeds: torch.Tensor

class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes "78 vs 77" error).
    3. Pads prompts (by using max_length) to ensure they are all 77 tokens.
    4. Returns all 4 required embedding tensors.
    """
    def __init__(self, pipe, device):
        """
        Args:
            pipe: An SDXL pipeline exposing ``tokenizer``/``tokenizer_2`` and
                ``text_encoder``/``text_encoder_2`` (diffusers convention).
            device: Device the text encoders live on; token ids are moved
                there before encoding.
        """
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str = "") -> CappellaResult:
        """
        Encode the positive and negative prompts.

        Args:
            prompt: The positive prompt text.
            negative_prompt: The negative prompt text. Defaults to "" (an
                unconditional negative), a backward-compatible convenience.

        Returns:
            CappellaResult holding all 4 tensors with matching sequence length.
        """
        pos_embeds, pos_pooled = self._encode_one(prompt)
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _tokenize(self, tokenizer: CLIPTokenizer, prompt: str) -> torch.Tensor:
        """Tokenize with truncation + max_length padding; return input ids on device."""
        inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids.to(self.device)

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run a single prompt string through both text encoders,
        ensuring truncation and padding to 77 tokens.

        Returns:
            (prompt_embeds, pooled_embeds): the per-token embeddings are the
            feature-dim concatenation of both encoders' penultimate hidden
            states; the pooled embedding is encoder 2's projected output.
        """
        ids_1 = self._tokenize(self.tokenizer, prompt)    # CLIP-L
        ids_2 = self._tokenize(self.tokenizer_2, prompt)  # OpenCLIP-G

        # --- Text Encoder 1 (CLIP-L) ---
        # FIX: SDXL conditions on the PENULTIMATE hidden state
        # (hidden_states[-2]) for BOTH encoders — not last_hidden_state,
        # which has already passed through the final layer norm. This matches
        # diffusers' StableDiffusionXLPipeline.encode_prompt.
        out_1 = self.text_encoder(ids_1, output_hidden_states=True)
        embeds_1 = out_1.hidden_states[-2]

        # --- Text Encoder 2 (OpenCLIP-G) ---
        out_2 = self.text_encoder_2(ids_2, output_hidden_states=True)
        embeds_2 = out_2.hidden_states[-2]
        # BUGFIX: CLIPTextModelWithProjection returns a CLIPTextModelOutput,
        # which has NO `.pooler_output` attribute — accessing it raised
        # AttributeError. The projected pooled embedding is `.text_embeds`.
        pooled_embeds = out_2.text_embeds

        # --- Concatenate ---
        # The final per-token embeddings are the concatenation of both
        # encoders along the feature dimension.
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        return prompt_embeds, pooled_embeds