Spaces:
Running on Zero
Running on Zero
File size: 2,773 Bytes
2dd4628 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | from __future__ import annotations
from typing import Any, Sequence
import torch
def load_clip_text_encoder(
    clip_model_name: str,
    *,
    device: torch.device,
    expected_embedding_dim: int | None,
    cache_dir: str | None,
    local_files_only: bool = False,
):
    """Load a CLIP tokenizer and a frozen ``CLIPTextModel`` for ``clip_model_name``.

    The tokenizer and the config are fetched first. If the loaded config
    exposes a ``text_config`` attribute, that nested config is used to build
    the text model; otherwise the config itself is used. The text width is
    validated before any model weights are loaded.

    Args:
        clip_model_name: Identifier or path handed to every ``from_pretrained``.
        device: Device the text model is moved to.
        expected_embedding_dim: When not None, the text config's
            ``hidden_size`` must equal this value.
        cache_dir: Forwarded to every ``from_pretrained`` call.
        local_files_only: Forwarded to every ``from_pretrained`` call.

    Returns:
        ``(tokenizer, text_model)``. The model is in eval mode, all of its
        parameters have ``requires_grad`` disabled, and it lives on ``device``.

    Raises:
        ValueError: If ``expected_embedding_dim`` is given and differs from the
            text config's ``hidden_size``.
    """
    from transformers import AutoConfig, AutoTokenizer, CLIPTextModel

    # Every hub lookup shares the same cache/offline settings.
    hub_kwargs = {"cache_dir": cache_dir, "local_files_only": local_files_only}

    tokenizer = AutoTokenizer.from_pretrained(clip_model_name, **hub_kwargs)
    full_config = AutoConfig.from_pretrained(clip_model_name, **hub_kwargs)
    # Prefer the nested text sub-config when present; otherwise use the config as-is.
    text_config = getattr(full_config, "text_config", full_config)

    width = int(text_config.hidden_size)
    if not (expected_embedding_dim is None or width == expected_embedding_dim):
        raise ValueError(
            f"CLIP hidden size mismatch for {clip_model_name}: "
            f"expected {expected_embedding_dim}, got {width}"
        )

    text_model = CLIPTextModel.from_pretrained(
        clip_model_name,
        config=text_config,
        use_safetensors=True,
        **hub_kwargs,
    )
    # Inference only: eval mode, gradients off, then move to the target device.
    # Each call returns the same module, so chaining keeps one object.
    text_model = text_model.eval().requires_grad_(False).to(device=device)
    return tokenizer, text_model
@torch.no_grad()
def encode_clip_text_prompts(
    prompts: Sequence[str],
    *,
    tokenizer: Any,
    text_model: torch.nn.Module,
    batch_size: int,
    output_device: torch.device | None = None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Encode text prompts the same way as dataset-side CLIP embedding caching.

    Prompts are processed in chunks of ``batch_size``. Each chunk is tokenized
    (padded to its longest prompt and truncated to ``max_length``), run
    through ``text_model``, and its ``pooler_output`` is cast to
    ``output_dtype`` on ``output_device``. The chunks are concatenated along
    dim 0.

    Args:
        prompts: Prompts to encode. A bare ``str`` is rejected, because it
            would otherwise be iterated character by character.
        tokenizer: Callable tokenizer exposing ``model_max_length``, whose
            result supports ``.to(device)`` and ``**`` unpacking.
        text_model: Module with at least one parameter and a ``config``
            exposing ``hidden_size`` (and optionally
            ``max_position_embeddings``). Its output must expose
            ``pooler_output``.
        batch_size: Prompts per forward pass; must be positive.
        output_device: Device of the result. Defaults to the device of the
            model's first parameter.
        output_dtype: Dtype of the result.

    Returns:
        The concatenated pooled outputs. This is presumably
        ``(len(prompts), hidden_size)`` for CLIP; an empty input yields a
        ``(0, hidden_size)`` zero tensor.

    Raises:
        TypeError: If ``prompts`` is a single ``str``.
        ValueError: If ``batch_size`` is not positive.
    """
    if isinstance(prompts, str):
        raise TypeError(
            "prompts must be a sequence of strings, not a single str; "
            "wrap it in a list to encode one prompt"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    model_device = next(text_model.parameters()).device
    if output_device is None:
        output_device = model_device
    hidden_size = int(text_model.config.hidden_size)
    if len(prompts) == 0:
        return torch.zeros((0, hidden_size), device=output_device, dtype=output_dtype)

    # A tokenizer without a configured limit may report a huge sentinel
    # model_max_length, in which case truncation never triggers. Clamp to the
    # encoder's position-embedding table (when the config declares one) so
    # every tokenized sequence fits the model.
    max_length = tokenizer.model_max_length
    position_limit = getattr(text_model.config, "max_position_embeddings", None)
    if position_limit is not None and (max_length is None or max_length > position_limit):
        max_length = int(position_limit)

    encoded_batches: list[torch.Tensor] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = list(prompts[start : start + batch_size])
        encoded = tokenizer(
            batch_prompts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model_device)
        text_outputs = text_model(**encoded)
        encoded_batches.append(
            text_outputs.pooler_output.detach().to(device=output_device, dtype=output_dtype)
        )
    return torch.cat(encoded_batches, dim=0)
|