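# Hugging Face model page metadata (not part of the module's logic):
#   Tags: Feature Extraction, Transformers, Safetensors, custom_code
#   File: C-RADIOv4-H / siglip2_adaptor.py
#   Last commit: "Upload model (#1)" by mranzinger (4eb2761, verified)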
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
import string
from typing import List
import torch
from torch import nn
import torch.nn.functional as F
from .adaptor_registry import adaptor_registry, dict_t, state_t
from .adaptor_generic import GenericAdaptor
from .utils import rank_gate
# Maps the short model names accepted in `adaptor_config['model']` to the
# identifiers that SigLIP2Adaptor passes to `AutoModel.from_pretrained` and
# `AutoProcessor.from_pretrained`. A name missing from this table makes the
# adaptor constructor fail with a KeyError.
_VERSION_MAP = {
'siglip2-g-384': 'google/siglip2-giant-opt-patch16-384',
'siglip2-so400m': 'google/siglip2-so400m-patch16-naflex',
}
class SigLIP2Adaptor(GenericAdaptor):
    """Generic adaptor extended with a SigLIP2 text encoder and tokenizer.

    On construction, the checkpoint selected by ``adaptor_config['model']``
    (resolved through ``_VERSION_MAP``) is loaded with Hugging Face
    ``transformers``. Only its ``text_model`` sub-module and a wrapped
    processor (``self.tokenizer``) are kept on the instance.
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        """Build the base adaptor, then load the SigLIP2 text components.

        Args:
            main_config: Forwarded unchanged to ``GenericAdaptor.__init__``.
            adaptor_config: Forwarded to the base class; must also contain a
                ``'model'`` entry whose value is a key of ``_VERSION_MAP``.
            state: Forwarded unchanged to ``GenericAdaptor.__init__``.

        Raises:
            KeyError: If ``'model'`` is absent from ``adaptor_config`` or its
                value is not a key of ``_VERSION_MAP``.
        """
        super().__init__(main_config, adaptor_config, state)

        # Resolve the short name to the identifier used for loading.
        checkpoint_id = _VERSION_MAP[adaptor_config['model']]

        # Imported lazily so `transformers` is only needed when this adaptor
        # is actually instantiated.
        from transformers import AutoModel, AutoProcessor

        # NOTE(review): rank_gate comes from .utils; presumably it staggers
        # the checkpoint download across distributed ranks - confirm there.
        with rank_gate():
            full_model = AutoModel.from_pretrained(checkpoint_id, trust_remote_code=True)
            processor = AutoProcessor.from_pretrained(checkpoint_id, trust_remote_code=True)

        self.tokenizer = SigLIP2WrappedTokenizer(processor)
        self.text_model = full_model.text_model

        # Only the text tower is kept; drop the local reference to the rest.
        del full_model

    def encode_text(self, text, normalize: bool = False):
        """Run the text model and return its pooled output.

        Args:
            text: Mapping that is unpacked as keyword arguments into
                ``self.text_model`` (presumably the output of
                ``self.tokenizer`` - verify against callers).
            normalize: When true, the pooled output is passed through
                ``F.normalize`` along the last dimension before returning.

        Returns:
            The ``pooler_output`` of the text model, normalized on request.
        """
        pooled = self.text_model(**text, return_dict=True).pooler_output
        return F.normalize(pooled, dim=-1) if normalize else pooled
class SigLIP2WrappedTokenizer:
    """Callable that canonicalizes prompts before handing them to a processor.

    Each prompt is first passed through ``canonicalize_text``. The cleaned
    batch is then given to the wrapped processor with fixed arguments:
    ``return_tensors='pt'``, ``max_length=64``, ``padding='max_length'`` and
    ``truncation=True``.
    """

    def __init__(self, proc):
        """Store the processor to delegate to.

        Args:
            proc: Callable accepting ``text=...`` plus the keyword arguments
                listed in the class docstring. SigLIP2Adaptor passes the
                object returned by ``AutoProcessor.from_pretrained``.
        """
        self._proc = proc

    def __call__(self, text: List[str]):
        """Canonicalize ``text`` and run the wrapped processor on it.

        Args:
            text: The prompts to encode. A single ``str`` is also accepted
                and treated as a batch containing that one prompt.

        Returns:
            Whatever the wrapped processor returns for the cleaned prompts.
        """
        # A bare string is itself iterable; without this guard it would be
        # split into one-character "prompts" by the comprehension below.
        if isinstance(text, str):
            text = [text]

        text = [canonicalize_text(t) for t in text]
        return self._proc(
            text=text,
            return_tensors='pt',
            max_length=64,
            padding='max_length',
            truncation=True,
        )
def canonicalize_text(
    text: str,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Return `text` lowercased, without punctuation, and whitespace-collapsed.

    Underscores are turned into spaces first, then punctuation is removed,
    the result is lowercased, and every run of whitespace is reduced to a
    single space with no leading or trailing whitespace.

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
      text: string to be canonicalized.
      keep_punctuation_exact_string: If provided (and non-empty), occurrences
        of this exact string survive punctuation removal. For example '{}'
        keeps every '{}' while a lone '{' or '}' is still removed.
      trans_punctuation: translation table applied with `str.translate`; the
        default deletes every character in `string.punctuation`.

    Returns:
      The canonicalized string.
    """
    spaced = text.replace("_", " ")

    if not keep_punctuation_exact_string:
        cleaned = spaced.translate(trans_punctuation)
    else:
        # Strip punctuation from the pieces around the protected string, then
        # stitch them back together with that string restored verbatim.
        pieces = [
            piece.translate(trans_punctuation)
            for piece in spaced.split(keep_punctuation_exact_string)
        ]
        cleaned = keep_punctuation_exact_string.join(pieces)

    words = cleaned.lower().split()
    return " ".join(words).strip()
@adaptor_registry.register_adaptor("siglip2")
def create_siglip2_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Factory registered under the name "siglip2" in `adaptor_registry`.

    Args:
        main_config: Passed through to `SigLIP2Adaptor`.
        adaptor_config: Passed through to `SigLIP2Adaptor`.
        state: Passed through to `SigLIP2Adaptor`.

    Returns:
        A new `SigLIP2Adaptor` built from the three arguments.
    """
    adaptor = SigLIP2Adaptor(main_config, adaptor_config, state)
    return adaptor