"""TIPSv2 model for HuggingFace — wraps vision and text encoders."""
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config
# Directory this file lives in; sibling encoder modules are resolved relative to it.
_this_dir = Path(__file__).parent
# Memoization table for _load_sibling: module name -> executed module object.
# Module-level so every model instance shares the same loaded modules.
_sibling_cache = {}
def _load_sibling(name, repo_id=None):
    """Import a sibling ``.py`` module from this directory, downloading it
    from the HF Hub when absent.

    When only this file was fetched (``trust_remote_code=True``) the sibling
    sources may not exist locally; in that case ``{name}.py`` is pulled from
    ``repo_id``. Results are memoized in ``_sibling_cache`` so each module is
    executed at most once per process.

    Args:
        name: Module stem, e.g. ``"image_encoder"``.
        repo_id: Optional Hub repo id to download the file from.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: file missing locally and no ``repo_id`` given.
        ImportError: the file could not be turned into an import spec.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists():
        if not repo_id:
            # Fail early with a clear message instead of letting exec_module
            # trip over the missing file later.
            raise FileNotFoundError(
                f"{path} does not exist and no repo_id was provided to download it"
            )
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    # Requires `import importlib.util` at module level; plain `import importlib`
    # does not reliably expose the `util` submodule.
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"could not build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
@dataclass
class TIPSv2ImageOutput:
    """Token groups emitted by the vision encoder for a batch of images."""

    cls_token: torch.Tensor        # (B, 1, D) — class token
    register_tokens: torch.Tensor  # (B, R, D) — register tokens
    patch_tokens: torch.Tensor     # (B, N, D) — per-patch spatial tokens
@dataclass
class TIPSv2Output:
    """Full-model output; fields for modalities that were not run stay None."""

    image_features: Optional[TIPSv2ImageOutput] = None  # vision-tower tokens
    text_embeds: Optional[torch.Tensor] = None          # (B, D) text embeddings
    temperature: Optional[float] = None                 # config.temperature passthrough
class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)

        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token  # (B, 1, D)
        spatial = out.patch_tokens  # (B, N, D)

        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are shared between the vision and text encoders.
        return {}

    def __init__(self, config: TIPSv2Config):
        super().__init__(config)
        # Encoder implementations ship as sibling .py files next to this one
        # (or on the Hub repo named by the config) and are loaded dynamically.
        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)
        # config.vision_fn names the builder function inside image_encoder.
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )
        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )
        self._tokenizer = None  # created lazily on first encode_text call
        self._te_mod = te  # kept so _load_tokenizer can reach te.Tokenizer

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer from disk or the Hub."""
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode images.

        Args:
            pixel_values: ``(B, 3, H, W)`` tensor with values in ``[0, 1]``.

        Returns:
            TIPSv2ImageOutput carrying cls, register, and patch tokens.
        """
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Args:
            texts: A string or list of strings (auto-tokenized), or a
                pre-tokenized id tensor.
            padding_mask: Required when ``texts`` is a tensor; ignored (and
                rebuilt by the tokenizer) for string input.

        Returns:
            ``(B, D)`` text embeddings.

        Raises:
            ValueError: if a tensor is passed without ``padding_mask``.
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            if padding_mask is None:
                # Fail loudly: previously this crashed with an opaque
                # AttributeError from None.to(self.device).
                raise ValueError(
                    "padding_mask must be provided when passing pre-tokenized ids"
                )
            ids = texts.to(self.device)
            padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality; unset inputs yield None fields."""
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )