gberton committed on
Commit
3fef103
·
verified ·
1 Parent(s): 96352c1

Upload modeling_tips.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tips.py +161 -0
modeling_tips.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TIPSv2 model for HuggingFace — wraps vision and text encoders."""
2
+
3
import importlib
import importlib.util
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_tips import TIPSv2Config
15
+
16
+ _this_dir = Path(__file__).parent
17
+ _sibling_cache = {}
18
+
19
+
20
def _load_sibling(name, repo_id=None):
    """Import a sibling ``.py`` module from this file's directory.

    Falls back to downloading ``{name}.py`` from the HF Hub repo ``repo_id``
    when the file is not present locally. Loaded modules are memoized in
    ``_sibling_cache`` so each one is executed at most once.

    Args:
        name: Module name without the ``.py`` suffix.
        repo_id: Optional HF Hub repo id to download the file from.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the file is missing locally and no
            ``repo_id`` was provided.
        ImportError: If an import spec cannot be built for the file.
    """
    if name in _sibling_cache:
        return _sibling_cache[name]
    path = _this_dir / f"{name}.py"
    if not path.exists():
        if not repo_id:
            # Fail early with a clear message instead of letting
            # exec_module() raise on the missing file.
            raise FileNotFoundError(f"{path} not found and no repo_id given")
        path = Path(hf_hub_download(repo_id, f"{name}.py"))
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    _sibling_cache[name] = mod
    return mod
32
+
33
+
34
@dataclass
class TIPSv2ImageOutput:
    """Output from the vision encoder.

    Shape legend: B = batch, D = embedding dim, R = number of register
    tokens, N = number of patch tokens.
    """

    # Global image embedding token.
    cls_token: torch.Tensor  # (B, 1, D)
    # Non-spatial register tokens produced by the backbone.
    register_tokens: torch.Tensor  # (B, R, D)
    # Spatial per-patch embeddings.
    patch_tokens: torch.Tensor  # (B, N, D)
40
+
41
+
42
@dataclass
class TIPSv2Output:
    """Output from the full model.

    A field is ``None`` when the corresponding modality was not passed
    to ``forward``.
    """

    # Vision-encoder outputs; None when no pixel_values were given.
    image_features: Optional[TIPSv2ImageOutput] = None
    # Text embeddings, (B, D); None when no input_ids were given.
    text_embeds: Optional[torch.Tensor] = None
    # Value of config.temperature, echoed back to the caller.
    temperature: Optional[float] = None
48
+
49
+
50
class TIPSv2Model(PreTrainedModel):
    """TIPSv2 vision-language model.

    Usage::

        model = AutoModel.from_pretrained("google/tipsv2-b14", trust_remote_code=True)

        # Image features
        out = model.encode_image(pixel_values)  # pixel_values in [0, 1]
        cls = out.cls_token  # (B, 1, D)
        spatial = out.patch_tokens  # (B, N, D)

        # Text features
        text_emb = model.encode_text(["a photo of a cat"])  # (B, D)
    """

    config_class = TIPSv2Config
    _no_split_modules = []
    _supports_cache_class = False
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # No weights are tied between the vision and text towers.
        return {}

    def __init__(self, config: TIPSv2Config):
        """Build the vision and text encoders described by ``config``."""
        super().__init__(config)

        # The encoder implementations live in sibling .py files of the
        # repo; when running via trust_remote_code they may need to be
        # fetched from the Hub, so pass the repo id along.
        repo_id = getattr(config, "_name_or_path", None)
        ie = _load_sibling("image_encoder", repo_id)
        te = _load_sibling("text_encoder", repo_id)

        # Vision tower: the builder function (a backbone variant) is
        # selected by name from the config.
        build_fn = getattr(ie, config.vision_fn)
        self.vision_encoder = build_fn(
            img_size=config.img_size,
            patch_size=config.patch_size,
            ffn_layer=config.ffn_layer,
            block_chunks=0,
            init_values=config.init_values,
            interpolate_antialias=True,
            interpolate_offset=0.0,
        )

        # Text tower.
        self.text_encoder = te.TextEncoder(
            config={
                "hidden_size": config.text_hidden_size,
                "mlp_dim": config.text_mlp_dim,
                "num_heads": config.text_num_heads,
                "num_layers": config.text_num_layers,
            },
            vocab_size=config.vocab_size,
        )

        # Tokenizer is loaded lazily on the first encode_text() call
        # with string input.
        self._tokenizer = None
        self._te_mod = te

    def _load_tokenizer(self):
        """Lazy-load the SentencePiece tokenizer (local file or HF Hub)."""
        tok_path = _this_dir / "tokenizer.model"
        if not tok_path.exists():
            tok_path = hf_hub_download(self.name_or_path, "tokenizer.model")
        return self._te_mod.Tokenizer(str(tok_path))

    @torch.no_grad()
    def encode_image(self, pixel_values: torch.Tensor) -> TIPSv2ImageOutput:
        """Encode a batch of images.

        Args:
            pixel_values: (B, 3, H, W) tensor with values in [0, 1].

        Returns:
            TIPSv2ImageOutput with cls, register, and patch tokens.
        """
        pixel_values = pixel_values.to(self.device)
        cls_token, register_tokens, patch_tokens = self.vision_encoder(pixel_values)
        return TIPSv2ImageOutput(
            cls_token=cls_token,
            register_tokens=register_tokens,
            patch_tokens=patch_tokens,
        )

    @torch.no_grad()
    def encode_text(
        self,
        texts: Union[str, List[str], torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text.

        Args:
            texts: A string or list of strings (auto-tokenized), or a
                pre-tokenized id tensor.
            padding_mask: Optional padding mask matching a pre-tokenized
                ``texts`` tensor; ignored when strings are passed.

        Returns:
            Text embeddings of shape (B, D).
        """
        if isinstance(texts, (str, list)):
            if isinstance(texts, str):
                texts = [texts]
            if self._tokenizer is None:
                self._tokenizer = self._load_tokenizer()
            ids, paddings = self._tokenizer.tokenize(texts, max_len=self.config.max_len)
            ids = torch.from_numpy(ids).to(self.device)
            padding_mask = torch.from_numpy(paddings).to(self.device)
        else:
            ids = texts.to(self.device)
            # Fix: the original called padding_mask.to() unconditionally,
            # which raised AttributeError whenever a caller passed
            # pre-tokenized ids without a mask (e.g. forward(input_ids=...)).
            # NOTE(review): assumes the text encoder tolerates a None mask
            # — confirm against text_encoder.py.
            if padding_mask is not None:
                padding_mask = padding_mask.to(self.device)
        return self.text_encoder(ids, padding_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> TIPSv2Output:
        """Forward pass for both or either modality.

        Args:
            pixel_values: Optional (B, 3, H, W) image batch in [0, 1].
            input_ids: Optional pre-tokenized text id tensor.
            padding_mask: Optional padding mask for ``input_ids``.

        Returns:
            TIPSv2Output; fields for modalities not supplied are None.
        """
        image_features = None
        text_embeds = None
        if pixel_values is not None:
            image_features = self.encode_image(pixel_values)
        if input_ids is not None:
            text_embeds = self.encode_text(input_ids, padding_mask)
        return TIPSv2Output(
            image_features=image_features,
            text_embeds=text_embeds,
            temperature=self.config.temperature,
        )