File size: 5,355 Bytes
3fef103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""TIPSv2 model for HuggingFace — wraps vision and text encoders."""

import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config

_this_dir = Path(__file__).parent
_sibling_cache = {}


def _load_sibling(name, repo_id=None):
    """Import a sibling ``.py`` module from this directory, caching the result.

    Used by the remote-code loading path: the encoder implementations ship as
    loose ``.py`` files next to this one, and may need to be pulled from the
    Hub when only the config was downloaded.

    Args:
        name: Module stem, e.g. ``"image_encoder"`` loads ``image_encoder.py``.
        repo_id: Optional HF Hub repo id to download ``{name}.py`` from when
            the file is not present locally.

    Returns:
        The executed module object (cached across calls).

    Raises:
        FileNotFoundError: If the file is neither local nor downloadable.
        ImportError: If an import spec cannot be built for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists():
        # Fail early with a clear message instead of letting exec_module
        # raise a confusing error on a nonexistent path.
        if not repo_id:
            raise FileNotFoundError(
                f"{path} does not exist and no repo_id was given to download it"
            )
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    # Requires the explicit `import importlib.util` at the top of the file;
    # `importlib.util` is not guaranteed to exist after plain `import importlib`.
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod


@dataclass
class TIPSv2ImageOutput:
    """Token groups produced by the vision encoder for a batch of images.

    Attributes:
        cls_token: Global image summary token, shape ``(B, 1, D)``.
        register_tokens: Register tokens, shape ``(B, R, D)``.
        patch_tokens: Per-patch spatial tokens, shape ``(B, N, D)``.
    """

    cls_token: torch.Tensor
    register_tokens: torch.Tensor
    patch_tokens: torch.Tensor


@dataclass
class TIPSv2Output:
    """Combined forward-pass output.

    Attributes:
        image_features: Vision-encoder tokens, or ``None`` if no images were
            passed.
        text_embeds: Text embeddings, or ``None`` if no text was passed.
        temperature: Contrastive temperature from the model config.
    """

    image_features: Optional[TIPSv2ImageOutput] = None
    text_embeds: Optional[torch.Tensor] = None
    temperature: Optional[float] = None


class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)

        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token        # (B, 1, D)
        spatial = out.patch_tokens  # (B, N, D)

        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are shared between the vision and text towers.
        return {}

    def __init__(self, config: TIPSv2Config):
        """Build the vision and text encoders described by `config`."""
        super().__init__(config)

        # Encoder implementations live in sibling .py files; when running via
        # trust_remote_code they may need to be fetched from the source repo.
        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)

        # The config names which builder function in image_encoder.py to use.
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )

        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )

        # Tokenizer is loaded lazily on first string input to encode_text.
        self._tokenizer = None
        self._te_mod = te

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer (local file or Hub download)."""
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode images.

        Args:
            pixel_values: (B, 3, H, W) tensor with values in [0, 1].

        Returns:
            TIPSv2ImageOutput with cls, register and patch tokens.
        """
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Args:
            texts: A string, a list of strings (auto-tokenized with the bundled
                SentencePiece model), or a pre-tokenized id tensor.
            padding_mask: Required when `texts` is a tensor; ignored for
                string input (the tokenizer produces it).

        Returns:
            Text embeddings of shape (B, D).

        Raises:
            ValueError: If a pre-tokenized tensor is passed without
                `padding_mask`.
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            # Previously this crashed with AttributeError on None.to(...);
            # fail with an actionable message instead.
            if padding_mask is None:
                raise ValueError(
                    "padding_mask is required when passing pre-tokenized input_ids"
                )
            ids = texts.to(self.device)
            padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality.

        Args:
            pixel_values: Optional (B, 3, H, W) images in [0, 1].
            input_ids: Optional pre-tokenized text ids.
            padding_mask: Padding mask accompanying `input_ids`.

        Returns:
            TIPSv2Output; fields for modalities not provided are None.
        """
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )