primerz committed
Commit c951070 · verified · 1 Parent(s): 9a7f039

Upload cappella.py

Files changed (1):
  1. cappella.py +94 -0
cappella.py ADDED
import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


@dataclass
class CappellaResult:
    """
    Holds the four tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored to our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes the "78 vs 77" error).
    3. Pads prompts that are too short (fixes the "93 vs 77" error).
    4. Returns all four required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)

        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _encode_one(self, prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        # padding="max_length" pads short prompts and truncation=True clips long
        # ones, so the sequence length is always exactly model_max_length (77).
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Text Encoder 1 (CLIP-L) ---
        # Takes last_hidden_state. The pooled output is not used.
        embeds_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device)
        ).last_hidden_state

        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Takes hidden_states[-2] and the projected pooled output.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        )
        embeds_2 = output_2.hidden_states[-2]
        # CLIPTextModelWithProjection exposes its pooled output as `text_embeds`;
        # it has no `pooler_output` attribute.
        pooled_embeds = output_2.text_embeds

        # --- Concatenate ---
        # The final embeddings concatenate both encoders along the feature
        # dimension (768 + 1280 = 2048 channels for SDXL).
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        return prompt_embeds, pooled_embeds
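
For reference, a minimal usage sketch follows, assuming a stock diffusers StableDiffusionXLPipeline; the model ID and prompt strings are illustrative placeholders, not part of this commit.

import torch
from diffusers import StableDiffusionXLPipeline

from cappella import Cappella

# Load a standard SDXL pipeline (model ID is a placeholder assumption).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Encode both prompts through the custom encoder.
encoder = Cappella(pipe, device="cuda")
result = encoder("a watercolor fox in a misty forest", "blurry, low quality")

# The four tensors map one-to-one onto the pipeline's embedding kwargs.
image = pipe(
    prompt_embeds=result.embeds,
    pooled_prompt_embeds=result.pooled_embeds,
    negative_prompt_embeds=result.negative_embeds,
    negative_pooled_prompt_embeds=result.negative_pooled_embeds,
).images[0]

Because _encode_one always pads and truncates to the tokenizers' model_max_length, the positive and negative embeddings are guaranteed to share a sequence length of 77, which is exactly the mismatch behind the "78 vs 77" and "93 vs 77" errors the class docstring mentions.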