# pixeldit-diffusers/pixeldit/text_encoder_gemma.py
# Uploaded by madtune via huggingface_hub (commit e9200bf, 3.52 kB).
"""
Gemma-2-2B text encoder for PixelDiT.
Handles chi_prompt prefix + select_index to match training exactly.
Usage:
from pixeldit.text_encoder_gemma import GemmaEncoder
enc = GemmaEncoder()
cond = enc.encode(["a dragon at sunset"]) # [1, 300, 2304]
null = enc.encode_null(1) # [1, 300, 2304]
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
_GEMMA_ID = "Efficient-Large-Model/gemma-2-2b-it"
_GEMMA_DIM = 2304
_TXT_MAX = 300
_CHI_PROMPT = "\n".join([
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:',
'- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.',
'- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.',
'Here are examples of how to transform or refine prompts:',
'- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.',
'- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.',
'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:',
'User Prompt: ',
])
_SELECT_IDX = [0] + list(range(-(_TXT_MAX - 1), 0))
class GemmaEncoder:
    """Frozen Gemma-2 decoder used as the PixelDiT text encoder.

    The language model is loaded on CPU in float32 and stays there; only the
    final embeddings are moved to ``output_device`` and cast to
    ``output_dtype``.

    Args:
        model_id: Hugging Face model id (or local path) passed to
            ``from_pretrained`` for both tokenizer and model.
        output_device: Device the returned embeddings are moved to.
        output_dtype: Dtype the returned embeddings are cast to.
    """

    def __init__(
        self,
        model_id=_GEMMA_ID,
        output_device="cuda",
        output_dtype=torch.bfloat16,
    ):
        self.output_device = torch.device(output_device)
        self.output_dtype = output_dtype
        print(f"[GemmaEncoder] loading {model_id} (CPU)")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Right padding is required: _SELECT_IDX reads the *trailing*
        # positions of a fixed-length sequence, so pads must sit at the end.
        self.tokenizer.padding_side = "right"
        # Only the decoder stack is used; its last_hidden_state is the
        # embedding returned by encode()/encode_null().
        self._model = (
            AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
            .get_decoder()
            .eval()
        )
        # Token count of the instruction prefix, including any special tokens
        # the tokenizer adds by default (e.g. BOS).
        self._num_chi_tokens = len(self.tokenizer.encode(_CHI_PROMPT))
        # Lazily computed unconditional embedding, shape [1, _TXT_MAX, hidden],
        # kept on CPU in the model's float32 precision (see encode_null).
        self._null_emb = None
        print("[GemmaEncoder] ready")

    @torch.no_grad()
    def encode(self, texts: list[str]) -> torch.Tensor:
        """Encode prompts into PixelDiT conditioning embeddings.

        Each prompt is prefixed with ``_CHI_PROMPT``, tokenized to a fixed
        length, and run through the decoder. The first position plus the last
        ``_TXT_MAX - 1`` positions of the hidden states are kept
        (``_SELECT_IDX``), mirroring the training-time recipe.

        Args:
            texts: Prompts to encode. A single ``str`` is treated as a batch of
                one prompt (iterating a bare string would otherwise encode it
                character by character).

        Returns:
            Tensor of shape ``[len(texts), 300, hidden]`` (hidden = 2304 for
            Gemma-2-2B) on ``output_device`` with dtype ``output_dtype``. An
            empty batch yields a ``[0, 300, hidden]`` tensor.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        if not texts:
            # The tokenizer/model cannot process an empty batch; return an
            # empty tensor with the usual trailing dimensions instead.
            return torch.empty(
                0, _TXT_MAX, self._model.config.hidden_size,
                device=self.output_device, dtype=self.output_dtype,
            )

        texts_full = [_CHI_PROMPT + t for t in texts]
        # Sequence length used at training time: prefix tokens + _TXT_MAX - 2,
        # so that _SELECT_IDX (1 + 299 positions) covers the end of the prefix
        # and the user prompt. Do not change without retraining.
        max_len = self._num_chi_tokens + _TXT_MAX - 2
        tok = self.tokenizer(
            texts_full, max_length=max_len,
            padding="max_length", truncation=True, return_tensors="pt",
        )
        emb = self._model(
            input_ids=tok.input_ids,
            attention_mask=tok.attention_mask,
        ).last_hidden_state
        emb = emb[:, _SELECT_IDX, :]
        return emb.to(device=self.output_device, dtype=self.output_dtype)

    @torch.no_grad()
    def encode_null(self, batch_size: int) -> torch.Tensor:
        """Return the empty-string embedding used as the CFG unconditional input.

        Unlike :meth:`encode`, no instruction prefix or index selection is
        applied: the empty string is padded directly to ``_TXT_MAX`` tokens.
        The input is identical for every row, so the decoder runs once per
        encoder instance; the result is cached on CPU and repeated per call.

        Args:
            batch_size: Number of identical rows to return (must be >= 0).

        Returns:
            Tensor of shape ``[batch_size, 300, hidden]`` on ``output_device``
            with dtype ``output_dtype``. Each call returns a fresh tensor, so
            callers may modify it in place without affecting the cache.

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")
        if self._null_emb is None:
            tok = self.tokenizer(
                [""], max_length=_TXT_MAX,
                padding="max_length", truncation=True, return_tensors="pt",
            )
            self._null_emb = self._model(
                input_ids=tok.input_ids,
                attention_mask=tok.attention_mask,
            ).last_hidden_state
        # Convert per call (not at cache time) so later changes to
        # output_device / output_dtype are honoured.
        null = self._null_emb.to(device=self.output_device, dtype=self.output_dtype)
        # repeat() always allocates a new tensor, so the cache is never exposed.
        return null.repeat(batch_size, 1, 1)