import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import Tuple


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored to our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes the "78 vs 77" error).
    3. Pads prompts (via padding="max_length") so they are all 77 tokens.
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)
        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders,
        ensuring truncation and padding to 77 tokens.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Text Encoder 1 (CLIP-L) ---
        # Gets last_hidden_state. Pooled output is not used.
        embeds_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device)
        ).last_hidden_state

        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Gets hidden_states[-2] and the pooled, projected output.
        # Note: CLIPTextModelWithProjection exposes the pooled embedding as
        # `text_embeds`; its output has no `pooler_output` attribute.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True
        )
        embeds_2 = output_2.hidden_states[-2]
        pooled_embeds = output_2.text_embeds

        # --- Concatenate ---
        # The final embeddings are a concatenation of both encoders' outputs
        # along the feature dimension (CLIP-L 768 + OpenCLIP-G 1280 = 2048).
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        return prompt_embeds, pooled_embeds
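

# --- Example usage (illustrative sketch) ---
# A minimal sketch of how Cappella's output could be fed into a diffusers
# StableDiffusionXLPipeline. The checkpoint name, prompts, and the `pipe` /
# `device` variables below are assumptions for illustration; they are not part
# of the original module.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    cappella = Cappella(pipe, device)
    result = cappella(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
    )

    # The four tensors map onto the pipeline's prompt-embedding kwargs,
    # bypassing the pipeline's own text encoding entirely.
    image = pipe(
        prompt_embeds=result.embeds,
        pooled_prompt_embeds=result.pooled_embeds,
        negative_prompt_embeds=result.negative_embeds,
        negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    ).images[0]
    image.save("out.png")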