Feature Extraction
Transformers
Safetensors
custom_code
File size: 3,307 Bytes
4eb2761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
import string
from typing import List

import torch
from torch import nn
import torch.nn.functional as F

from .adaptor_registry import adaptor_registry, dict_t, state_t

from .adaptor_generic import GenericAdaptor
from .utils import rank_gate


# Maps the short model names accepted in `adaptor_config['model']` to the full
# Hugging Face Hub repo ids passed to `from_pretrained`.
_VERSION_MAP = {
    'siglip2-g-384': 'google/siglip2-giant-opt-patch16-384',
    'siglip2-so400m': 'google/siglip2-so400m-patch16-naflex',
}


class SigLIP2Adaptor(GenericAdaptor):
    """Adaptor wrapping the text tower of a SigLIP2 checkpoint.

    Loads the Hugging Face model named by ``adaptor_config['model']`` (a short
    key from ``_VERSION_MAP``), keeps only its text encoder plus a wrapped
    processor-based tokenizer, and discards the rest of the model.
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        super().__init__(main_config, adaptor_config, state)

        version = adaptor_config['model']
        try:
            version = _VERSION_MAP[version]
        except KeyError:
            # Fail with an actionable message instead of a bare KeyError.
            raise ValueError(
                f"Unsupported SigLIP2 model {version!r}. "
                f"Supported models: {sorted(_VERSION_MAP)}"
            ) from None

        # Imported lazily so `transformers` is only required when this adaptor
        # is actually instantiated.
        from transformers import AutoModel, AutoProcessor
        with rank_gate():
            model = AutoModel.from_pretrained(version, trust_remote_code=True)
            proc = AutoProcessor.from_pretrained(version, trust_remote_code=True)

        self.tokenizer = SigLIP2WrappedTokenizer(proc)
        # Keep only the text tower; the vision tower is not used by this adaptor.
        self.text_model = model.text_model

        del model

    def encode_text(self, text, normalize: bool = False):
        """Encode tokenized text into pooled embeddings.

        Args:
            text: Mapping of tokenizer outputs (e.g. input_ids) forwarded as
                keyword arguments to the underlying text model.
            normalize: If True, L2-normalize each embedding along the last dim.

        Returns:
            The text model's ``pooler_output`` tensor.
        """
        output = self.text_model(**text, return_dict=True)
        token = output.pooler_output

        if normalize:
            token = F.normalize(token, dim=-1)

        return token


class SigLIP2WrappedTokenizer:
    """Callable wrapper that canonicalizes text before running the processor.

    Tokenizes to fixed-length (64-token, padded, truncated) PyTorch tensors.
    """

    def __init__(self, proc):
        self._proc = proc

    def __call__(self, text: List[str]):
        cleaned = list(map(canonicalize_text, text))
        return self._proc(
            text=cleaned,
            return_tensors='pt',
            max_length=64,
            padding='max_length',
            truncation=True,
        )


def canonicalize_text(
    text: str,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Returns canonicalized `text` (lowercase and punctuation removed).

    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94

    Args:
      text: string to be canonicalized.
      keep_punctuation_exact_string: If provided, then this exact string kept.
        For example providing '{}' will keep any occurrences of '{}' (but will
        still remove '{' and '}' that appear separately).
    """
    # Underscores become spaces so snake_case labels read as words.
    without_underscores = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # Split around the protected string, strip punctuation from each piece,
        # then stitch the pieces back together with the protected string.
        pieces = without_underscores.split(keep_punctuation_exact_string)
        stripped = keep_punctuation_exact_string.join(
            piece.translate(trans_punctuation) for piece in pieces
        )
    else:
        stripped = without_underscores.translate(trans_punctuation)
    lowered = stripped.lower()
    # Collapse runs of whitespace to single spaces and trim the ends.
    return " ".join(lowered.split()).strip()


@adaptor_registry.register_adaptor("siglip2")
def create_siglip2_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    return SigLIP2Adaptor(main_config, adaptor_config, state)