|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import Namespace |
|
|
import string |
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .adaptor_registry import adaptor_registry, dict_t, state_t |
|
|
|
|
|
from .adaptor_generic import GenericAdaptor |
|
|
from .utils import rank_gate |
|
|
|
|
|
|
|
|
# Maps short model aliases (as written in adaptor configs) to their full
# HuggingFace Hub checkpoint identifiers.
_VERSION_MAP = {


    'siglip2-g-384': 'google/siglip2-giant-opt-patch16-384',


    'siglip2-so400m': 'google/siglip2-so400m-patch16-naflex',


}
|
|
|
|
|
|
|
|
class SigLIP2Adaptor(GenericAdaptor):
    """Text-encoder adaptor wrapping Google's SigLIP2 models.

    Resolves the short model alias in ``adaptor_config['model']`` through
    ``_VERSION_MAP``, loads the HuggingFace model and processor, and keeps
    only the text tower (the rest of the model is discarded to save memory).
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        """Load the SigLIP2 text model and tokenizer.

        Args:
            main_config: global run configuration (forwarded to the base class).
            adaptor_config: must contain ``'model'``, one of the aliases in
                ``_VERSION_MAP``.
            state: shared adaptor state (forwarded to the base class).

        Raises:
            ValueError: if ``adaptor_config['model']`` is not a known alias.
        """
        super().__init__(main_config, adaptor_config, state)

        alias = adaptor_config['model']
        try:
            version = _VERSION_MAP[alias]
        except KeyError:
            # A bare KeyError gives no hint of valid options; name them.
            raise ValueError(
                f"Unknown SigLIP2 model {alias!r}; expected one of "
                f"{sorted(_VERSION_MAP)}"
            ) from None

        # Imported lazily so transformers is only required when this adaptor
        # is actually instantiated.
        from transformers import AutoModel, AutoProcessor
        with rank_gate():
            model = AutoModel.from_pretrained(version, trust_remote_code=True)
            proc = AutoProcessor.from_pretrained(version, trust_remote_code=True)

        self.tokenizer = SigLIP2WrappedTokenizer(proc)
        self.text_model = model.text_model

        # Release the full model (including the vision tower); only the text
        # tower referenced above is kept alive.
        del model

    def encode_text(self, text, normalize: bool = False):
        """Encode tokenized text into pooled embeddings.

        Args:
            text: mapping of tokenizer outputs (e.g. ``input_ids``) unpacked
                as keyword args into the text model — presumably produced by
                ``self.tokenizer``; confirm against callers.
            normalize: if True, L2-normalize embeddings along the last dim.

        Returns:
            The model's ``pooler_output`` tensor, optionally normalized.
        """
        output = self.text_model(**text, return_dict=True)
        token = output.pooler_output

        if normalize:
            token = F.normalize(token, dim=-1)

        return token
|
|
|
|
|
|
|
|
class SigLIP2WrappedTokenizer:
    """Callable wrapper giving a HuggingFace processor a tokenizer interface.

    Each input string is canonicalized (lowercased, punctuation stripped)
    before tokenization, and the output is padded/truncated to a fixed
    sequence length.
    """

    def __init__(self, proc, max_length: int = 64):
        """Wrap *proc*.

        Args:
            proc: HuggingFace processor/tokenizer accepting the
                ``text=``/``return_tensors=``/``padding=``/``truncation=``
                call signature.
            max_length: fixed sequence length to pad/truncate to. Previously
                hard-coded; the default of 64 preserves existing behavior.
        """
        self._proc = proc
        self._max_length = max_length

    def __call__(self, text: List[str]):
        """Canonicalize and tokenize a batch of strings into PyTorch tensors."""
        text = [canonicalize_text(t) for t in text]
        ret = self._proc(
            text=text,
            return_tensors='pt',
            max_length=self._max_length,
            padding='max_length',
            truncation=True,
        )
        return ret
|
|
|
|
|
|
|
|
def canonicalize_text(
    text: str,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Canonicalize *text*: lowercase it, drop punctuation, collapse spaces.

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
        text: string to be canonicalized.
        keep_punctuation_exact_string: If provided, then this exact string kept.
            For example providing '{}' will keep any occurrences of '{}' (but will
            still remove '{' and '}' that appear separately).
        trans_punctuation: translation table used to strip punctuation;
            defaults to removing all of ``string.punctuation``.

    Returns:
        The canonicalized string.
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # Split around the protected string, strip punctuation from each
        # fragment, then stitch the protected string back in.
        fragments = text.split(keep_punctuation_exact_string)
        cleaned = [frag.translate(trans_punctuation) for frag in fragments]
        text = keep_punctuation_exact_string.join(cleaned)
    else:
        text = text.translate(trans_punctuation)
    # Lowercase, collapse all whitespace runs to single spaces, and trim.
    return " ".join(text.lower().split()).strip()
|
|
|
|
|
|
|
|
@adaptor_registry.register_adaptor("siglip2")
def create_siglip2_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Registry factory: build a SigLIP2Adaptor from the given configs."""
    adaptor = SigLIP2Adaptor(main_config, adaptor_config, state)
    return adaptor
|
|
|