| |
| |
| |
| |
| |
| |
| |
| from argparse import Namespace |
| import string |
| from typing import List |
|
|
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from .adaptor_registry import adaptor_registry, dict_t, state_t |
|
|
| from .adaptor_generic import GenericAdaptor |
| from .utils import rank_gate |
|
|
|
|
# Short aliases accepted in ``adaptor_config['model']``, mapped to the
# Hugging Face Hub checkpoint ids that get loaded for them.
# NOTE(review): 'siglip2-so400m' resolves to the NaFlex variant while
# 'siglip2-g-384' resolves to a fixed-resolution (384) checkpoint — confirm
# this asymmetry is intended.
_VERSION_MAP = {
    # Giant ("opt") model, patch size 16, 384px.
    "siglip2-g-384": "google/siglip2-giant-opt-patch16-384",
    # So400m model, patch size 16, NaFlex variant.
    "siglip2-so400m": "google/siglip2-so400m-patch16-naflex",
}
|
|
|
|
class SigLIP2Adaptor(GenericAdaptor):
    """Adaptor that exposes the text tower of a pretrained SigLIP2 model.

    ``adaptor_config['model']`` must be one of the short aliases in
    ``_VERSION_MAP``. The alias is resolved to a Hugging Face Hub id, and the
    full model plus its processor are loaded. Only ``model.text_model`` and a
    tokenizer wrapper around the processor are kept on the instance.
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        super().__init__(main_config, adaptor_config, state)

        alias = adaptor_config['model']
        try:
            version = _VERSION_MAP[alias]
        except KeyError:
            # Keep the original exception type, but list the valid choices so a
            # typo in the config is diagnosable from the error alone.
            raise KeyError(
                f"Unsupported SigLIP2 model {alias!r}; "
                f"expected one of {sorted(_VERSION_MAP)}"
            ) from None

        # Imported here rather than at module level, so transformers is only
        # loaded when this adaptor is actually constructed.
        from transformers import AutoModel, AutoProcessor

        # NOTE(review): rank_gate presumably serializes hub downloads across
        # distributed ranks — confirm against .utils.
        # NOTE(review): trust_remote_code=True allows executing code shipped
        # in the hub repo; SigLIP2 is natively supported by recent
        # transformers, so this may be removable — confirm before changing.
        with rank_gate():
            model = AutoModel.from_pretrained(version, trust_remote_code=True)
            proc = AutoProcessor.from_pretrained(version, trust_remote_code=True)

        self.tokenizer = SigLIP2WrappedTokenizer(proc)
        self.text_model = model.text_model

        # Drop the local reference to the full model. The text tower stays
        # reachable through self.text_model; the remaining submodules can be
        # freed once nothing else references them.
        del model

    def encode_text(self, text, normalize: bool = False):
        """Embed tokenized text with the SigLIP2 text tower.

        Args:
            text: Mapping of model inputs (e.g. the output of
                ``self.tokenizer``); it is unpacked as keyword arguments into
                ``self.text_model``.
                NOTE(review): inputs are not moved to the model's device here —
                presumably the caller handles placement; confirm.
            normalize: If True, L2-normalize each embedding along the last dim.

        Returns:
            The text model's ``pooler_output``, optionally normalized.
        """
        output = self.text_model(**text, return_dict=True)
        token = output.pooler_output

        if normalize:
            token = F.normalize(token, dim=-1)

        return token
|
|
|
|
class SigLIP2WrappedTokenizer:
    """Callable that canonicalizes prompts, then tokenizes them with a processor.

    Args:
        proc: A Hugging Face processor (as returned by
            ``AutoProcessor.from_pretrained``), or any callable accepting the
            same keyword arguments.
        max_length: Fixed length every prompt is padded/truncated to. Defaults
            to 64, the value this wrapper previously hard-coded.
            NOTE(review): values above the text tower's position-embedding size
            will presumably fail in the model — confirm the limit for the
            checkpoint in use.
    """

    def __init__(self, proc, max_length: int = 64):
        self._proc = proc
        self._max_length = max_length

    def __call__(self, text: List[str]):
        """Canonicalize and tokenize a batch of prompts.

        Args:
            text: Prompts to tokenize. A bare ``str`` is treated as a batch of
                one. Iterating it directly would otherwise tokenize every
                character as a separate prompt.

        Returns:
            The processor's output for ``return_tensors='pt'`` with fixed-length
            ``'max_length'`` padding and truncation.
        """
        if isinstance(text, str):
            text = [text]
        text = [canonicalize_text(t) for t in text]
        return self._proc(
            text=text,
            return_tensors='pt',
            max_length=self._max_length,
            padding='max_length',
            truncation=True,
        )
|
|
|
|
def canonicalize_text(
    text: str,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Lowercase `text`, strip punctuation and collapse whitespace.

    Ported from big_vision:
    https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Underscores become spaces, then every character in `trans_punctuation`
    (by default all of ``string.punctuation``) is deleted. The result is
    lowercased, and runs of whitespace are collapsed to single spaces.

    Args:
      text: string to be canonicalized.
      keep_punctuation_exact_string: If provided (non-empty), exact occurrences
        of this string survive. For example, '{}' keeps every '{}' but still
        removes a lone '{' or '}'.
      trans_punctuation: translation table used to delete punctuation.

    Returns:
      The canonicalized string.
    """
    spaced = text.replace("_", " ")

    if not keep_punctuation_exact_string:
        stripped = spaced.translate(trans_punctuation)
    else:
        # Clean each piece between protected occurrences, then stitch the
        # protected string back in between them.
        pieces = spaced.split(keep_punctuation_exact_string)
        cleaned = [piece.translate(trans_punctuation) for piece in pieces]
        stripped = keep_punctuation_exact_string.join(cleaned)

    words = stripped.lower().split()
    return " ".join(words).strip()
|
|
|
|
@adaptor_registry.register_adaptor("siglip2")
def create_siglip2_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Factory registered under the adaptor name "siglip2".

    Forwards all arguments unchanged to ``SigLIP2Adaptor`` and returns the new
    instance.
    """
    return SigLIP2Adaptor(
        main_config=main_config,
        adaptor_config=adaptor_config,
        state=state,
    )
|
|