import torch
from dataclasses import dataclass
from typing import Tuple

from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline, all guaranteed
    to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored to our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes the "78 vs 77" error).
    3. Pads prompts (via padding="max_length") so they are all 77 tokens.
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt.
        pos_embeds, pos_pooled = self._encode_one(prompt)

        # Encode the negative prompt.
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders,
        ensuring truncation and padding to 77 tokens.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # --- Text Encoder 1 (CLIP-L) ---
        # SDXL conditions on the penultimate hidden layer (hidden_states[-2]),
        # not last_hidden_state. The pooled output is not used here.
        output_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        )
        embeds_1 = output_1.hidden_states[-2]

        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Gets hidden_states[-2] plus the projected pooled output.
        # CLIPTextModelWithProjection exposes the pooled vector as
        # .text_embeds; its output has no .pooler_output attribute.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        )
        embeds_2 = output_2.hidden_states[-2]
        pooled_embeds = output_2.text_embeds

        # --- Concatenate ---
        # The final embeddings concatenate both encoders along the feature
        # axis: 768 (CLIP-L) + 1280 (OpenCLIP-G) = 2048 channels.
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        return prompt_embeds, pooled_embeds
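

# --- Usage sketch ---
# A minimal, illustrative example of wiring Cappella into a diffusers SDXL
# pipeline. The checkpoint name, prompts, and output path below are
# assumptions for illustration, not part of this module; swap in whatever
# pipeline and device you actually use.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed checkpoint
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    cappella = Cappella(pipe, device)
    result = cappella(
        prompt="a photo of an astronaut riding a horse on the moon",
        negative_prompt="blurry, low quality, watermark",
    )

    # StableDiffusionXLPipeline accepts the four precomputed tensors in
    # place of the raw prompt strings.
    image = pipe(
        prompt_embeds=result.embeds,
        pooled_prompt_embeds=result.pooled_embeds,
        negative_prompt_embeds=result.negative_embeds,
        negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    ).images[0]
    image.save("cappella_test.png")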