import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor
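
# With a standard diffusers StableDiffusionXLPipeline, these four fields map
# one-to-one onto the pipeline call's prompt_embeds, pooled_prompt_embeds,
# negative_prompt_embeds and negative_pooled_prompt_embeds keyword arguments
# (see the usage sketch at the bottom of this file).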

class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes "78 vs 77" error).
    3. Pads prompts that are too short (fixes "93 vs 77" error).
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)
        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)
        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _encode_one(self, prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
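        # Note: model_max_length is 77 for both SDXL tokenizers, so the
        # padding="max_length" + truncation=True combination guarantees a
        # fixed (1, 77) input_ids shape for any prompt string.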
        # --- Text Encoder 1 (CLIP-L) ---
        # Uses the penultimate hidden state (hidden_states[-2]), matching the
        # standard SDXL pipeline; this encoder's pooled output is not used.
        embeds_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        ).hidden_states[-2]
        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Gets hidden_states[-2] and the pooled output.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True,
        )
        embeds_2 = output_2.hidden_states[-2]
        # CLIPTextModelWithProjection exposes the projected pooled output as
        # .text_embeds; its output has no .pooler_output attribute.
        pooled_embeds = output_2.text_embeds
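        # Expected shapes for a single prompt: embeds_1 is (1, 77, 768) from
        # CLIP-L, embeds_2 is (1, 77, 1280) from OpenCLIP-G, and pooled_embeds
        # is (1, 1280).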
        # --- Concatenate ---
        # The final embeddings are a concatenation of both, along the
        # feature dimension (768 + 1280 = 2048 channels).
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)
        return prompt_embeds, pooled_embeds
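

# ----------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): how Cappella's four
# tensors feed a standard diffusers StableDiffusionXLPipeline. The model id
# and prompt strings below are placeholders, not project settings.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed base model
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)

    cappella = Cappella(pipe, device)
    result = cappella("a watercolor fox, highly detailed", "blurry, low quality")

    # The four tensors plug directly into the pipeline call.
    image = pipe(
        prompt_embeds=result.embeds,
        pooled_prompt_embeds=result.pooled_embeds,
        negative_prompt_embeds=result.negative_embeds,
        negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    ).images[0]
    image.save("out.png")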