| """TIPSv2 model for HuggingFace — wraps vision and text encoders.""" |
|
|
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config
|
|
# Directory containing this file; sibling encoder modules and the tokenizer
# model are resolved relative to it before falling back to the HF Hub.
_this_dir = Path(__file__).parent
# Cache of already-imported sibling modules, keyed by module name, so each
# sibling .py is downloaded/executed at most once per process.
_sibling_cache = {}
|
|
|
|
def _load_sibling(name, repo_id=None):
    """Import a sibling ``.py`` module from this directory, downloading it
    from the HF Hub when it is not present locally.

    Args:
        name: Module name without the ``.py`` suffix (e.g. ``"image_encoder"``).
        repo_id: Optional HF Hub repo id to download ``{name}.py`` from when
            the file is missing next to this one.

    Returns:
        The executed module object. Results are cached in ``_sibling_cache``,
        so repeated calls are cheap.

    Raises:
        FileNotFoundError: If the file is missing locally and no ``repo_id``
            was provided to download it from.
        ImportError: If an import spec cannot be created for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists():
        # Fail early with a clear message instead of letting exec_module
        # blow up on a nonexistent path.
        if not repo_id:
            raise FileNotFoundError(
                f"{name}.py not found in {_this_dir} and no repo_id was "
                "provided to download it from the HF Hub."
            )
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    # Cache only after successful execution so a failed import can be retried.
    _sibling_cache[name] = mod
    return mod
|
|
|
|
@dataclass
class TIPSv2ImageOutput:
    """Output from the vision encoder."""
    # Global image embedding; (B, 1, D) per the usage example on TIPSv2Model.
    cls_token: torch.Tensor
    # Register-token embeddings; shape not established in this file —
    # NOTE(review): presumably (B, num_registers, D), confirm against encoder.
    register_tokens: torch.Tensor
    # Per-patch spatial embeddings; (B, N, D) per the usage example.
    patch_tokens: torch.Tensor
|
|
|
@dataclass
class TIPSv2Output:
    """Output from the full model; fields for modalities not passed stay None."""
    # Vision-encoder output, set when pixel_values were given.
    image_features: Optional[TIPSv2ImageOutput] = None
    # Text embeddings, set when input_ids were given.
    text_embeds: Optional[torch.Tensor] = None
    # Copied from config.temperature — NOTE(review): presumably a contrastive
    # similarity scale; semantics not established in this file.
    temperature: Optional[float] = None
|
|
|
class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)

        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token          # (B, 1, D)
        spatial = out.patch_tokens   # (B, N, D)

        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are tied; returning an empty mapping keeps transformers'
        # weight-tying machinery from touching this model.
        return {}

    def __init__(self, config: TIPSv2Config):
        """Build vision and text encoders from dynamically loaded siblings.

        The encoder implementations live in ``image_encoder.py`` and
        ``text_encoder.py`` next to this file (downloaded from the repo in
        ``config._name_or_path`` when absent locally).
        """
        super().__init__(config)

        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)

        # The concrete ViT builder is selected by name from the config
        # (e.g. a size-specific factory function in image_encoder.py).
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )

        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )

        # Tokenizer is loaded lazily on the first string input to encode_text.
        self._tokenizer = None
        # Keep the text-encoder module handle so _load_tokenizer can reach
        # its Tokenizer class.
        self._te_mod = te

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer.

        Looks for ``tokenizer.model`` next to this file first, then falls back
        to downloading it from the model repo.
        """
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode images.

        Args:
            pixel_values: Float tensor of shape (B, 3, H, W) with values
                in [0, 1]; moved to the model's device automatically.

        Returns:
            TIPSv2ImageOutput with cls, register, and patch tokens.
        """
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Args:
            texts: A string, a list of strings (auto-tokenized), or a
                pre-tokenized id tensor.
            padding_mask: Required when ``texts`` is a tensor; ignored for
                string inputs (the tokenizer supplies it).

        Returns:
            Text embeddings from the text encoder.

        Raises:
            ValueError: If ``texts`` is a tensor but ``padding_mask`` is None.
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            # Fail with a clear message instead of an AttributeError on
            # padding_mask.to(...) when the mask is missing.
            if padding_mask is None:
                raise ValueError(
                    "padding_mask must be provided when passing pre-tokenized "
                    "ids as a tensor."
                )
            ids = texts.to(self.device)
            padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality.

        Encodes whichever of ``pixel_values`` / ``input_ids`` is provided;
        the corresponding output field stays None otherwise.
        """
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )
|
|