# pixagram-dev / cappella.py
import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import Tuple


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor
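

# Shape reference (for the stock SDXL text encoders, single prompt per call):
#   embeds / negative_embeds:               [1, 77, 2048]
#   pooled_embeds / negative_pooled_embeds: [1, 1280]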


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored to our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Truncates prompts that are too long (fixes the "78 vs 77" error).
    3. Pads prompts (by using max_length) to ensure they are all 77 tokens.
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)

        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled
        )

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders,
        ensuring truncation and padding to 77 tokens.
        """
        # --- Tokenizer 1 (CLIP-L) ---
        tok_1_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )

        # --- Tokenizer 2 (OpenCLIP-G) ---
        tok_2_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
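
        # Both tokenizers share model_max_length == 77 for the stock SDXL
        # checkpoints, so the two token sequences stay aligned for the
        # feature concatenation below.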

        # --- Text Encoder 1 (CLIP-L) ---
        # SDXL conditions on the penultimate hidden state of this encoder;
        # its pooled output is not used.
        output_1 = self.text_encoder(
            tok_1_inputs.input_ids.to(self.device),
            output_hidden_states=True
        )
        embeds_1 = output_1.hidden_states[-2]

        # --- Text Encoder 2 (OpenCLIP-G) ---
        # Gets hidden_states[-2] and the projected pooled output.
        output_2 = self.text_encoder_2(
            tok_2_inputs.input_ids.to(self.device),
            output_hidden_states=True
        )
        embeds_2 = output_2.hidden_states[-2]
        # CLIPTextModelWithProjection exposes the pooled embedding as
        # `text_embeds` rather than `pooler_output`.
        pooled_embeds = output_2.text_embeds

        # --- Concatenate ---
        # The final embeddings are a concatenation of both encoders' features
        # along the channel dimension.
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)
        return prompt_embeds, pooled_embeds
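

# A minimal usage sketch, assuming a diffusers SDXL pipeline; the checkpoint
# name, prompts, and output path below are illustrative, not part of this repo.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    cappella = Cappella(pipe, device="cuda")
    result = cappella(
        prompt="a watercolor painting of a fox in a birch forest",
        negative_prompt="blurry, low quality, watermark",
    )

    # The four tensors plug directly into the pipeline's *_embeds arguments.
    image = pipe(
        prompt_embeds=result.embeds,
        pooled_prompt_embeds=result.pooled_embeds,
        negative_prompt_embeds=result.negative_embeds,
        negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    ).images[0]
    image.save("cappella_demo.png")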