"""TIPSv2 model for HuggingFace — wraps vision and text encoders."""

# NOTE: `importlib.util` must be imported explicitly; a bare `import importlib`
# does not reliably bind the `util` submodule.
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config

_this_dir = Path(__file__).parent
# Memoizes dynamically loaded sibling modules so each is executed only once.
_sibling_cache = {}


def _load_sibling(name, repo_id=None):
    """Import a sibling ``.py`` from the same dir, downloading from HF if needed.

    Args:
        name: Module name without the ``.py`` suffix (e.g. ``"image_encoder"``).
        repo_id: Hub repo to download from when the file is not present locally
            (the case when this module itself was fetched via remote code).

    Returns:
        The executed module object (cached in ``_sibling_cache``).
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists() and repo_id:
        # Running from a Hub snapshot: fetch the sibling source file.
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod


@dataclass
class TIPSv2ImageOutput:
    """Output from the vision encoder."""

    cls_token: torch.Tensor  # (B, 1, D)
    register_tokens: torch.Tensor  # (B, R, D)
    patch_tokens: torch.Tensor  # (B, N, D)


@dataclass
class TIPSv2Output:
    """Output from the full model."""

    image_features: Optional[TIPSv2ImageOutput] = None
    text_embeds: Optional[torch.Tensor] = None
    temperature: Optional[float] = None


class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)

        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token       # (B, 1, D)
        spatial = out.patch_tokens  # (B, N, D)

        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weight tying between the vision and text towers.
        return {}

    def __init__(self, config: TIPSv2Config):
        super().__init__(config)
        # The encoder implementations live in sibling .py files shipped with
        # the checkpoint; load them dynamically (downloading if necessary).
        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)

        # config.vision_fn names the builder function inside image_encoder.py.
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )
        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )
        # Tokenizer is loaded lazily on first encode_text() call with strings.
        self._tokenizer = None
        self._te_mod = te

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer (local file or Hub download)."""
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode images. pixel_values: (B, 3, H, W) in [0, 1]."""
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Pass strings (auto-tokenized) or pre-tokenized tensors. When passing
        a pre-tokenized tensor, ``padding_mask`` is required.

        Raises:
            ValueError: if ``texts`` is a tensor and ``padding_mask`` is None.
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            if padding_mask is None:
                # Previously this crashed with an opaque AttributeError on
                # padding_mask.to(...); fail fast with a clear message instead.
                raise ValueError(
                    "padding_mask is required when passing pre-tokenized ids"
                )
            ids = texts.to(self.device)
            padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality.

        Either input may be omitted; the corresponding output field is None.
        """
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )