File size: 2,773 Bytes
2dd4628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

from typing import Any, Sequence

import torch


def load_clip_text_encoder(
    clip_model_name: str,
    *,
    device: torch.device,
    expected_embedding_dim: int | None,
    cache_dir: str | None,
    local_files_only: bool = False,
):
    """Load the CLIP tokenizer and a frozen CLIP text encoder.

    The checkpoint config is read first. If it has a nested ``text_config``
    attribute (a combined CLIP config), that nested config is used to build
    the text encoder; otherwise the top-level config is used unchanged.

    Args:
        clip_model_name: Model identifier/path passed to every
            ``from_pretrained`` call.
        device: Device the text encoder is moved to.
        expected_embedding_dim: When not ``None``, must equal the text
            config's ``hidden_size``.
        cache_dir: Forwarded as ``cache_dir`` to every ``from_pretrained``
            call.
        local_files_only: Forwarded as ``local_files_only`` to every
            ``from_pretrained`` call.

    Returns:
        ``(tokenizer, text_model)``; ``text_model`` is a ``CLIPTextModel``
        switched to eval mode with ``requires_grad`` disabled on its
        parameters.

    Raises:
        ValueError: If ``expected_embedding_dim`` is given and differs from
            the text config's ``hidden_size``.
    """
    from transformers import AutoConfig, AutoTokenizer, CLIPTextModel

    # Options shared by every artifact fetched below.
    hub_options = {"cache_dir": cache_dir, "local_files_only": local_files_only}

    tokenizer = AutoTokenizer.from_pretrained(clip_model_name, **hub_options)
    full_config = AutoConfig.from_pretrained(clip_model_name, **hub_options)
    # Prefer the nested text-tower config when the checkpoint config has one.
    text_config = getattr(full_config, "text_config", full_config)

    hidden_size = int(text_config.hidden_size)
    dim_mismatch = (
        expected_embedding_dim is not None and hidden_size != expected_embedding_dim
    )
    if dim_mismatch:
        raise ValueError(
            f"CLIP hidden size mismatch for {clip_model_name}: "
            f"expected {expected_embedding_dim}, got {hidden_size}"
        )

    text_model = CLIPTextModel.from_pretrained(
        clip_model_name,
        config=text_config,
        use_safetensors=True,
        **hub_options,
    )
    # eval(), requires_grad_() and to() each return the module itself,
    # so they can be chained; the same object is returned below.
    text_model.eval().requires_grad_(False).to(device=device)
    return tokenizer, text_model


@torch.no_grad()
def encode_clip_text_prompts(
    prompts: Sequence[str],
    *,
    tokenizer: Any,
    text_model: torch.nn.Module,
    batch_size: int,
    output_device: torch.device | None = None,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Encode text prompts the same way as dataset-side CLIP embedding caching.

    Prompts are processed in chunks of at most ``batch_size``. Each chunk is
    tokenized with ``padding=True``, ``truncation=True`` and
    ``max_length=tokenizer.model_max_length``, moved to the model's device and
    run through ``text_model``. The ``pooler_output`` of each chunk is then
    moved to ``output_device``/``output_dtype``, and the chunks are
    concatenated along dim 0 in input order.

    Args:
        prompts: Prompt strings. A single bare ``str`` is rejected: it is
            itself a ``Sequence[str]``, and iterating it would encode one
            prompt per character.
        tokenizer: Tokenizer callable returning an object with ``.to(device)``
            that can be unpacked as keyword arguments to ``text_model``.
        text_model: Encoder whose outputs expose ``pooler_output``. Its
            ``config.hidden_size`` is read only to size the empty-input
            result.
        batch_size: Maximum number of prompts per forward pass; must be > 0.
        output_device: Device of the result; defaults to the device of the
            model's first parameter.
        output_dtype: Dtype of the result.

    Returns:
        The concatenated ``pooler_output`` rows, one per prompt. For empty
        ``prompts`` this is a ``(0, config.hidden_size)`` zeros tensor.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        TypeError: If ``prompts`` is a single ``str``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if isinstance(prompts, str):
        raise TypeError(
            "prompts must be a sequence of strings, not a single str; "
            "wrap a single prompt in a list"
        )

    model_device = next(text_model.parameters()).device
    if output_device is None:
        output_device = model_device

    if len(prompts) == 0:
        # Only needed here: nothing was encoded to infer the width from.
        hidden_size = int(text_model.config.hidden_size)
        return torch.zeros((0, hidden_size), device=output_device, dtype=output_dtype)

    encoded_batches: list[torch.Tensor] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = list(prompts[start : start + batch_size])
        encoded = tokenizer(
            batch_prompts,
            max_length=tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model_device)
        text_outputs = text_model(**encoded)
        encoded_batches.append(
            text_outputs.pooler_output.detach().to(device=output_device, dtype=output_dtype)
        )
    return torch.cat(encoded_batches, dim=0)