from dataclasses import dataclass
from typing import Tuple

import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor
class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Chunks prompts longer than 77 tokens instead of silently truncating
       them (fixes the "78 vs 77" error).
    3. Pads every chunk to max_length so each is exactly 77 tokens.
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        Ensures both embedding tensors have the same sequence length.
        """
        pos_embeds, pos_pooled = self._encode_one(prompt)
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        # A long prompt can produce more 77-token chunks than a short one,
        # so the two embeddings may disagree on sequence length. Zero-pad
        # the shorter one along dim=1 until both match.
        seq_len_pos = pos_embeds.shape[1]
        seq_len_neg = neg_embeds.shape[1]

        if seq_len_pos > seq_len_neg:
            # Pad negative embeds
            pad_len = seq_len_pos - seq_len_neg
            padding = torch.zeros(
                (neg_embeds.shape[0], pad_len, neg_embeds.shape[2]),
                device=self.device, dtype=neg_embeds.dtype
            )
            neg_embeds = torch.cat([neg_embeds, padding], dim=1)
        elif seq_len_neg > seq_len_pos:
            # Pad positive embeds
            pad_len = seq_len_neg - seq_len_pos
            padding = torch.zeros(
                (pos_embeds.shape[0], pad_len, pos_embeds.shape[2]),
                device=self.device, dtype=pos_embeds.dtype
            )
            pos_embeds = torch.cat([pos_embeds, padding], dim=1)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )
    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        Handles prompts longer than 77 tokens by chunking.
        """
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        prompt_embeds_list = []
        pooled_prompt_embeds = None

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # --- Tokenize ---
            # Tokenize without padding or truncation so no tokens are lost.
            text_inputs = tokenizer(
                prompt,
                padding=False,
                truncation=False,
                return_tensors="pt",
            )
            input_ids = text_inputs.input_ids.to(self.device)

            # --- Chunking ---
            max_length = tokenizer.model_max_length
            bos = tokenizer.bos_token_id
            eos = tokenizer.eos_token_id
            # Reserve 2 slots per chunk for BOS and EOS.
            chunk_length = max_length - 2

            # Drop the BOS/EOS the tokenizer already added, then split the
            # remaining token IDs into chunks of at most chunk_length tokens.
            clean_input_ids = input_ids[0, 1:-1]
            chunks = [
                clean_input_ids[i:i + chunk_length]
                for i in range(0, len(clean_input_ids), chunk_length)
            ]
            if not chunks:
                # Empty prompt: keep a single empty chunk so the loop below
                # still produces a valid [BOS, EOS, PAD, ...] sequence.
                chunks = [clean_input_ids]

            # --- Prepare Batches ---
            batch_input_ids = []
            for chunk in chunks:
                # Re-add BOS and EOS around each chunk
                chunk_with_bos_eos = torch.cat([
                    torch.tensor([bos], dtype=torch.long, device=self.device),
                    chunk.to(torch.long),
                    torch.tensor([eos], dtype=torch.long, device=self.device),
                ])
                # Pad the chunk up to max_length (77 tokens)
                pad_len = max_length - len(chunk_with_bos_eos)
                if pad_len > 0:
                    padding = torch.full(
                        (pad_len,), tokenizer.pad_token_id,
                        dtype=torch.long, device=self.device,
                    )
                    chunk_with_bos_eos = torch.cat([chunk_with_bos_eos, padding])
                batch_input_ids.append(chunk_with_bos_eos)

            batch_input_ids = torch.stack(batch_input_ids)
            # --- Encode ---
            if text_encoder is self.text_encoder:
                # Text Encoder 1 (CLIP ViT-L).
                # SDXL conditions on the penultimate hidden state of both
                # encoders, so take hidden_states[-2] rather than the
                # final last_hidden_state.
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=True
                )
                # [num_chunks, 77, 768]
                prompt_embeds = encoder_output.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)
            else:
                # Text Encoder 2 (OpenCLIP ViT-bigG).
                # We need hidden_states[-2] and the pooled output.
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=True
                )
                # [num_chunks, 77, 1280]
                prompt_embeds = encoder_output.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)

                # .text_embeds is the pooled, projected output; the pooled
                # embedding comes from the FIRST chunk only.
                all_pooled = encoder_output.text_embeds  # [num_chunks, 1280]
                pooled_prompt_embeds = all_pooled[0:1]   # Keep as [1, 1280]
        # --- Concatenate Chunks ---
        # Both tokenizers use the same CLIP BPE vocabulary, so each encoder
        # produced the same number of chunks. Flatten [num_chunks, 77, dim]
        # to [1, num_chunks * 77, dim], then concatenate the two encoders'
        # features along the channel dimension (768 + 1280 = 2048).
        embeds_1 = prompt_embeds_list[0].reshape(1, -1, prompt_embeds_list[0].shape[-1])
        embeds_2 = prompt_embeds_list[1].reshape(1, -1, prompt_embeds_list[1].shape[-1])
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        # pooled_prompt_embeds is already [1, 1280] from Encoder 2's first chunk
        return prompt_embeds, pooled_prompt_embeds
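

# Minimal usage sketch (an illustration, not part of the original module):
# Cappella's four tensors map onto the standard diffusers SDXL keyword
# arguments for pre-computed embeddings. The checkpoint id, prompts, and
# output filename below are placeholders.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
    ).to("cuda")

    encoder = Cappella(pipe, device="cuda")
    result = encoder(
        prompt="a watercolor painting of a lighthouse at dawn, highly detailed",
        negative_prompt="blurry, low quality, watermark",
    )

    # Hand the pre-computed embeddings straight to the pipeline; no string
    # prompts are passed, so the pipeline skips its own encoding step.
    image = pipe(
        prompt_embeds=result.embeds,
        pooled_prompt_embeds=result.pooled_embeds,
        negative_prompt_embeds=result.negative_embeds,
        negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    ).images[0]
    image.save("cappella_demo.png")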