# tipsv2-so400m14 / modeling_tips.py
# Uploaded by gberton via huggingface_hub (revision 3fef103, verified)
"""TIPSv2 model for HuggingFace — wraps vision and text encoders."""
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config
# Directory containing this file; sibling .py modules are resolved relative to it.
_this_dir = Path(__file__).parent
# Cache of dynamically imported sibling modules, keyed by module name, so each
# sibling file is executed at most once per process.
_sibling_cache = {}
def _load_sibling(name, repo_id=None):
    """Import a sibling ``<name>.py`` from this directory, downloading it
    from the Hugging Face Hub when it is not present locally.

    Args:
        name: Module name without the ``.py`` suffix.
        repo_id: Optional Hub repository to download ``<name>.py`` from when
            the file does not exist next to this module.

    Returns:
        The imported module object. Results are cached in ``_sibling_cache``,
        so each sibling file is executed at most once per process.

    Raises:
        FileNotFoundError: If the file is absent locally and no ``repo_id``
            was provided to download it.
        ImportError: If an import spec cannot be built for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists():
        if not repo_id:
            raise FileNotFoundError(
                f"{path} does not exist and no repo_id was provided to download it."
            )
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    # NOTE: requires `import importlib.util` — `import importlib` alone does
    # not guarantee the `util` submodule is bound.
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
@dataclass
class TIPSv2ImageOutput:
    """Output from the vision encoder.

    Shape legend (from the inline annotations): B = batch size, R = number
    of register tokens, N = number of patch tokens, D = embedding dimension.
    """
    cls_token: torch.Tensor  # (B, 1, D)
    register_tokens: torch.Tensor  # (B, R, D)
    patch_tokens: torch.Tensor  # (B, N, D)
@dataclass
class TIPSv2Output:
    """Output from the full model.

    A field is ``None`` when the corresponding modality was not passed to
    ``forward`` (e.g. no ``pixel_values`` -> ``image_features`` is ``None``).
    """
    image_features: Optional[TIPSv2ImageOutput] = None  # vision outputs, if pixel_values given
    text_embeds: Optional[torch.Tensor] = None  # text embeddings, if input_ids given
    temperature: Optional[float] = None  # contrastive temperature, copied from config
class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Wraps a vision encoder and a text encoder that are dynamically imported
    from sibling files (``image_encoder.py`` / ``text_encoder.py``), which
    are downloaded from the Hub when not present locally.

    Usage::
        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)
        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token       # (B, 1, D)
        spatial = out.patch_tokens  # (B, N, D)
        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are tied between the two encoders.
        return {}

    def __init__(self, config: TIPSv2Config):
        super().__init__(config)
        # The repo id lets _load_sibling download the encoder files from the
        # Hub when this file was fetched without its siblings.
        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)
        # The config names the vision-tower factory function (e.g. a ViT
        # builder) exported by image_encoder.py.
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )
        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )
        # Tokenizer is loaded lazily on the first encode_text(str) call.
        self._tokenizer = None
        self._te_mod = te

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer.

        Looks for ``tokenizer.model`` next to this file first, then falls
        back to downloading it from this model's Hub repo.
        """
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            # hf_hub_download returns a str path; str() below handles both.
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode images.

        Args:
            pixel_values: (B, 3, H, W) tensor with values in [0, 1]; moved
                to the model's device automatically.

        Returns:
            TIPSv2ImageOutput with cls, register, and patch tokens.
        """
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Args:
            texts: A string, a list of strings (auto-tokenized with the
                bundled SentencePiece tokenizer), or a pre-tokenized id
                tensor.
            padding_mask: Required when ``texts`` is a pre-tokenized tensor;
                ignored (recomputed) for string input.

        Returns:
            Text embeddings from the text encoder.

        Raises:
            ValueError: If ``texts`` is a tensor and ``padding_mask`` is None.
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            if padding_mask is None:
                # Without this guard the original code raised an opaque
                # AttributeError on None.to(...).
                raise ValueError(
                    "padding_mask is required when passing pre-tokenized input ids."
                )
            ids = texts.to(self.device)
            padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality.

        Args:
            pixel_values: Optional (B, 3, H, W) images in [0, 1].
            input_ids: Optional pre-tokenized text ids.
            padding_mask: Padding mask accompanying ``input_ids``.

        Returns:
            TIPSv2Output; fields for modalities not provided are None.
        """
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )