instruct-particulate / instruct_particulate /utils /text_embedding_utils.py
rayli's picture
Cleanup demo code paths
2f3ab6d verified
Raw
History Blame Contribute Delete
2.77 kB
from __future__ import annotations
from typing import Any, Sequence
import torch
def load_clip_text_encoder(
    clip_model_name: str,
    *,
    device: torch.device,
    expected_embedding_dim: int | None,
    cache_dir: str | None,
    local_files_only: bool = False,
):
    """Return a ``(tokenizer, text_model)`` pair for a CLIP checkpoint.

    The tokenizer and the model configuration are fetched first. The text
    configuration is taken from ``config.text_config`` when that attribute
    exists and from the configuration object itself otherwise. Its
    ``hidden_size`` is validated before any model weights are loaded.

    Args:
        clip_model_name: Identifier forwarded to every ``from_pretrained`` call.
        device: Device the text model is moved to before it is returned.
        expected_embedding_dim: Required ``hidden_size`` of the text config;
            ``None`` disables the check.
        cache_dir: Forwarded as ``cache_dir`` to every ``from_pretrained`` call.
        local_files_only: Forwarded as ``local_files_only`` to every
            ``from_pretrained`` call.

    Returns:
        The tokenizer and the text model. The model is in eval mode, has
        gradients disabled on its parameters, and lives on ``device``.

    Raises:
        ValueError: If ``expected_embedding_dim`` is given and differs from
            the text config's ``hidden_size``.
    """
    from transformers import AutoConfig, AutoTokenizer, CLIPTextModel

    # Options common to the tokenizer, config and weight loading calls.
    hub_kwargs = {"cache_dir": cache_dir, "local_files_only": local_files_only}

    tokenizer = AutoTokenizer.from_pretrained(clip_model_name, **hub_kwargs)
    full_config = AutoConfig.from_pretrained(clip_model_name, **hub_kwargs)

    # Prefer the nested text config; fall back to the config object itself.
    text_config = getattr(full_config, "text_config", full_config)
    hidden_size = int(text_config.hidden_size)

    # Fail before loading weights when the embedding width is not the one wanted.
    size_ok = expected_embedding_dim is None or hidden_size == expected_embedding_dim
    if not size_ok:
        raise ValueError(
            f"CLIP hidden size mismatch for {clip_model_name}: "
            f"expected {expected_embedding_dim}, got {hidden_size}"
        )

    encoder = CLIPTextModel.from_pretrained(
        clip_model_name,
        config=text_config,
        use_safetensors=True,
        **hub_kwargs,
    )

    # Inference-only setup: eval mode, no gradients, then place on the target device.
    encoder.eval()
    encoder.requires_grad_(False)
    encoder.to(device=device)
    return tokenizer, encoder
@torch.no_grad()
def encode_clip_text_prompts(
    prompts: Sequence[str],
    *,
    tokenizer: Any,
    text_model: torch.nn.Module,
    batch_size: int,
    output_device: torch.device | None = None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Encode text prompts the same way as dataset-side CLIP embedding caching.

    Prompts are processed in chunks of at most ``batch_size``. Each chunk is
    tokenized with ``padding=True``, ``truncation=True`` and
    ``max_length=tokenizer.model_max_length``, moved to the device of the
    model's first parameter, and passed through ``text_model``. The
    ``pooler_output`` of every chunk is collected and concatenated along
    dimension 0.

    Args:
        prompts: Prompts to encode. A single bare string is treated as one
            prompt rather than as a sequence of characters.
        tokenizer: Callable tokenizer exposing ``model_max_length``; its
            result must support ``.to(device)`` and ``**`` unpacking.
        text_model: Module whose output has a ``pooler_output`` attribute and
            whose ``config.hidden_size`` gives the embedding width.
        batch_size: Maximum number of prompts per forward pass.
        output_device: Device of the returned tensor; defaults to the device
            of the model's first parameter.
        output_dtype: Dtype of the returned tensor.

    Returns:
        The concatenated pooled outputs on ``output_device`` with dtype
        ``output_dtype``. For empty ``prompts`` this is a
        ``(0, hidden_size)`` tensor of zeros.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # A lone ``str`` satisfies ``Sequence[str]``; slicing it below would turn
    # it into a batch of single characters, so wrap it as one prompt instead.
    if isinstance(prompts, str):
        prompts = [prompts]

    model_device = next(text_model.parameters()).device
    if output_device is None:
        output_device = model_device

    hidden_size = int(text_model.config.hidden_size)
    if len(prompts) == 0:
        # Nothing to encode: return a correctly shaped empty result.
        return torch.zeros((0, hidden_size), device=output_device, dtype=output_dtype)

    encoded_batches: list[torch.Tensor] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = list(prompts[start : start + batch_size])
        encoded = tokenizer(
            batch_prompts,
            max_length=tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model_device)
        text_outputs = text_model(**encoded)
        # Move each chunk to the output device/dtype right away so only the
        # pooled vectors are kept around between iterations.
        encoded_batches.append(
            text_outputs.pooler_output.detach().to(device=output_device, dtype=output_dtype)
        )
    return torch.cat(encoded_batches, dim=0)