primerz committed
Commit fc6a138 · verified · 1 Parent(s): b548850

Update cappella.py

Files changed (1)
  1. cappella.py +131 -41
cappella.py CHANGED
@@ -32,65 +32,155 @@ class Cappella:
         self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
         self.device = device

+
+    # In cappella.py
     @torch.no_grad()
     def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
         """
         Encodes the positive and negative prompts.
+        Ensures both embedding tensors have the same sequence length.
         """
         # Encode the positive prompt
         pos_embeds, pos_pooled = self._encode_one(prompt)

         # Encode the negative prompt
         neg_embeds, neg_pooled = self._encode_one(negative_prompt)
-
+
+        # --- START FIX: Pad shorter embeds ---
+        # Ensure embeds and negative_embeds have the same sequence length
+        seq_len_pos = pos_embeds.shape[1]
+        seq_len_neg = neg_embeds.shape[1]
+
+        if seq_len_pos > seq_len_neg:
+            # Pad negative embeds
+            pad_len = seq_len_pos - seq_len_neg
+            padding = torch.zeros(
+                (neg_embeds.shape[0], pad_len, neg_embeds.shape[2]),
+                device=self.device, dtype=neg_embeds.dtype
+            )
+            neg_embeds = torch.cat([neg_embeds, padding], dim=1)
+
+        elif seq_len_neg > seq_len_pos:
+            # Pad positive embeds
+            pad_len = seq_len_neg - seq_len_pos
+            padding = torch.zeros(
+                (pos_embeds.shape[0], pad_len, pos_embeds.shape[2]),
+                device=self.device, dtype=pos_embeds.dtype
+            )
+            pos_embeds = torch.cat([pos_embeds, padding], dim=1)
+
+        # Now seq_len_pos and seq_len_neg are guaranteed to be equal
+        # --- END FIX ---
+
         return CappellaResult(
             embeds=pos_embeds,
             pooled_embeds=pos_pooled,
             negative_embeds=neg_embeds,
             negative_pooled_embeds=neg_pooled
         )
-
+
     def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Runs a single prompt string through both text encoders,
-        ensuring truncation and padding to 77 tokens.
+        Runs a single prompt string through both text encoders.
+        Handles prompts longer than 77 tokens by chunking.
         """
-        # --- Tokenizer 1 (CLIP-L) ---
-        tok_1_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt"
-        )

-        # --- Tokenizer 2 (OpenCLIP-G) ---
-        tok_2_inputs = self.tokenizer_2(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer_2.model_max_length,
-            truncation=True,
-            return_tensors="pt"
-        )
-
-        # --- Text Encoder 1 (CLIP-L) ---
-        # Gets last_hidden_state. Pooled output is not used.
-        embeds_1 = self.text_encoder(
-            tok_1_inputs.input_ids.to(self.device)
-        ).last_hidden_state
-
-        # --- Text Encoder 2 (OpenCLIP-G) ---
-        # Gets hidden_states[-2] and the pooled output.
-        output_2 = self.text_encoder_2(
-            tok_2_inputs.input_ids.to(self.device),
-            output_hidden_states=True
-        )
-        embeds_2 = output_2.hidden_states[-2]
-        pooled_embeds = output_2.text_embeds
-
-        # --- Concatenate ---
-        # The final embeddings are a concatenation of both.
+        # --- Get Tokenizers and Encoders ---
+        tokenizers = [self.tokenizer, self.tokenizer_2]
+        text_encoders = [self.text_encoder, self.text_encoder_2]
+
+        prompt_embeds_list = []
+        pooled_prompt_embeds = None
+
+        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            # --- Tokenize ---
+            # Tokenize without padding or truncation first
+            text_inputs = tokenizer(
+                prompt,
+                padding=False,
+                truncation=False,
+                return_tensors="pt"
+            )
+            input_ids = text_inputs.input_ids.to(self.device)
+
+            # --- Chunking ---
+            # Manually chunk the input_ids
+            max_length = tokenizer.model_max_length
+            bos = tokenizer.bos_token_id
+            eos = tokenizer.eos_token_id
+
+            # We subtract 2 for BOS and EOS
+            chunk_length = max_length - 2
+
+            # Get all token IDs *except* BOS and EOS
+            clean_input_ids = input_ids[0, 1:-1]
+
+            # Split into chunks
+            chunks = [clean_input_ids[i:i + chunk_length] for i in range(0, len(clean_input_ids), chunk_length)]
+
+            # --- Prepare Batches ---
+            batch_input_ids = []
+            for chunk in chunks:
+                # Add BOS and EOS
+                chunk_with_bos_eos = torch.cat([
+                    torch.tensor([bos], dtype=torch.long, device=self.device),
+                    chunk.to(torch.long),
+                    torch.tensor([eos], dtype=torch.long, device=self.device)
+                ])
+
+                # Pad to max_length
+                pad_len = max_length - len(chunk_with_bos_eos)
+                if pad_len > 0:
+                    padding = torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long, device=self.device)
+                    chunk_with_bos_eos = torch.cat([chunk_with_bos_eos, padding])
+
+                batch_input_ids.append(chunk_with_bos_eos)
+
+            if not batch_input_ids:
+                # Handle empty prompt
+                batch_input_ids.append(
+                    torch.full((max_length,), tokenizer.pad_token_id, dtype=torch.long, device=self.device)
+                )
+
+            batch_input_ids = torch.stack(batch_input_ids)
+
+            # --- Encode ---
+            if text_encoder == self.text_encoder:
+                # Text Encoder 1 (CLIP-L)
+                # We only need the last_hidden_state
+                encoder_output = text_encoder(
+                    batch_input_ids,
+                    output_hidden_states=False
+                )
+                # [num_chunks, 77, 768]
+                prompt_embeds = encoder_output.last_hidden_state
+                prompt_embeds_list.append(prompt_embeds)
+
+            elif text_encoder == self.text_encoder_2:
+                # Text Encoder 2 (OpenCLIP-G)
+                # We need hidden_states[-2] and the pooled output from the FIRST chunk
+                encoder_output = text_encoder(
+                    batch_input_ids,
+                    output_hidden_states=True
+                )
+                # [num_chunks, 77, 1280]
+                prompt_embeds = encoder_output.hidden_states[-2]
+                prompt_embeds_list.append(prompt_embeds)
+
+                # Pooled output comes from the FIRST chunk
+                # We use .text_embeds which is the pooled output
+                # [num_chunks, 1280]
+                all_pooled = encoder_output.text_embeds
+                pooled_prompt_embeds = all_pooled[0:1]  # Keep as [1, 1280]
+
+        # --- Concatenate Chunks ---
+        # Reshape from [num_chunks, 77, dim] to [1, num_chunks*77, dim]
+        # and then concatenate along the dim=-1
+
+        embeds_1 = prompt_embeds_list[0].reshape(1, -1, prompt_embeds_list[0].shape[-1])
+        embeds_2 = prompt_embeds_list[1].reshape(1, -1, prompt_embeds_list[1].shape[-1])
+
         prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)
-
-        return prompt_embeds, pooled_embeds
-
+
+        # pooled_prompt_embeds is already [1, 1280] from Encoder 2's first chunk
+        return prompt_embeds, pooled_prompt_embeds
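For reference, a minimal usage sketch of the updated encoder (not part of the commit). It assumes Cappella is importable from cappella.py and constructed as Cappella(pipe, device), and that CappellaResult exposes as attributes the four fields passed in the diff; the constructor signature, the import, and the diffusers pipeline setup are assumptions, not confirmed by this repository.

    # Usage sketch under the assumptions stated above.
    import torch
    from diffusers import StableDiffusionXLPipeline
    from cappella import Cappella  # assumes cappella.py is on the import path

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    # Hypothetical constructor call; the diff only shows the tail of __init__.
    cappella = Cappella(pipe, device="cuda")

    # Well past 77 CLIP tokens, to exercise the new chunking path.
    long_prompt = ", ".join(["a detailed oil painting of a red fox in deep snow"] * 10)
    result = cappella(long_prompt, negative_prompt="blurry, low quality")

    # embeds: [1, num_chunks * 77, 2048] (768 from CLIP-L + 1280 from OpenCLIP-G).
    # The padding fix guarantees embeds and negative_embeds share a sequence length.
    print(result.embeds.shape, result.negative_embeds.shape)
    print(result.pooled_embeds.shape)  # [1, 1280], pooled from the first chunk

Because the positive and negative prompts can tokenize to different chunk counts, embeds and negative_embeds could previously disagree in sequence length; the new padding block in __call__ zero-pads the shorter tensor so the two can be stacked for classifier-free guidance downstream.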