feat-matryoshka-support

- configuration_clip.py  +5 -1
- modeling_clip.py       +43 -5
configuration_clip.py
CHANGED

@@ -6,7 +6,7 @@
 
 import os
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from transformers import PretrainedConfig, logging
 

@@ -157,6 +157,8 @@ class JinaCLIPConfig(PretrainedConfig):
         logit_scale_init_value: float = 2.6592,
         use_text_flash_attn: Optional[bool] = None,
         use_vision_xformers: Optional[bool] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
         **kwargs,
     ):
         # If `_config_dict` exist, we use them for the backward compatibility.

@@ -167,6 +169,8 @@ class JinaCLIPConfig(PretrainedConfig):
         vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
         self.use_text_flash_attn = use_text_flash_attn
         self.use_vision_xformers = use_vision_xformers
+        self.matryoshka_dimensions = matryoshka_dimensions
+        self.truncate_dim = truncate_dim
 
         super().__init__(**kwargs)
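The commit threads Matryoshka support through the config first: `matryoshka_dimensions` lists the output sizes a checkpoint supports, and `truncate_dim` sets a default truncation applied by the encode methods. A minimal sketch of constructing the config (the dimension values are illustrative assumptions, not part of this commit, and it assumes the default text/vision sub-configs load without arguments):

from configuration_clip import JinaCLIPConfig

# Hypothetical values: a real checkpoint ships its supported dimension list
# in config.json; this only illustrates the two new fields.
config = JinaCLIPConfig(
    matryoshka_dimensions=[64, 128, 256, 512, 768],
    truncate_dim=256,
)
assert config.truncate_dim in config.matryoshka_dimensions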
modeling_clip.py
CHANGED
@@ -4,12 +4,13 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
 # and adjusted for Jina CLIP
 
+import base64
 from functools import partial
-from typing import List, Optional, Tuple, Union
 from io import BytesIO
-import
-
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
+import requests
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
@@ -39,9 +40,14 @@ except ImportError:
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
 from .hf_model import HFTextEncoder
+
 # needed for HF to correctly import in cache
 from .rope_embeddings import VisionRotaryEmbeddingFast  # noqa: F401
-from .transform import
+from .transform import (  # noqa: F401
+    OPENAI_DATASET_MEAN,
+    OPENAI_DATASET_STD,
+    image_transform,
+)
 
 logger = logging.get_logger(__name__)
 
@@ -280,6 +286,25 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         )
         return self.visual_projection(self.vision_model(x=x))
 
+    def truncate_embeddings(self, embeddings, truncate_dim):
+        if "jina-clip-v1" in self.config._name_or_path:
+            logger.warning(
+                "Matryoshka embeddings are not supported for jina-clip-v1, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif not self.config.matryoshka_dimensions:
+            logger.warning(
+                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif truncate_dim in self.config.matryoshka_dimensions:
+            return embeddings[:, :truncate_dim]
+        else:
+            raise ValueError(
+                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
+                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+            )
+
     @torch.inference_mode()
     def encode_text(
         self,
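Matryoshka-trained models pack the most informative components into the leading coordinates, so truncation is plain slicing of the trailing dimensions. A standalone sketch of the same check-then-slice logic on made-up tensors (the dimension list and batch are fabricated for illustration):

import torch
import torch.nn.functional as F

matryoshka_dimensions = [64, 128, 256]  # hypothetical supported sizes
embeddings = torch.randn(4, 256)        # fake batch of full-size embeddings

truncate_dim = 64
if truncate_dim not in matryoshka_dimensions:
    raise ValueError(f'unsupported truncate_dim {truncate_dim}')
truncated = F.normalize(embeddings[:, :truncate_dim], p=2, dim=1)

print(truncated.shape)        # torch.Size([4, 64])
print(truncated.norm(dim=1))  # all ones: re-normalized after slicing

The encode paths below apply the same order, truncate first and L2-normalize second, so truncated vectors remain unit length.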
@@ -290,6 +315,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -315,6 +341,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -364,6 +392,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         else:
             range_iter = range(0, len(sentences), batch_size)
 
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
            encoded_input = self.tokenizer(
                sentences[i : i + batch_size],
@@ -372,6 +401,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             ).to(self.device)
 
             embeddings = self.get_text_features(input_ids=encoded_input)
+
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
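With the new parameter, the config fallback, and the truncation call wired in, callers can request smaller text embeddings per call. A usage sketch (the repo id is an assumption, and 64 must appear in the checkpoint's matryoshka_dimensions):

from transformers import AutoModel

# Assumed repo id; any checkpoint whose config defines matryoshka_dimensions works.
model = AutoModel.from_pretrained('jinaai/jina-clip-v2', trust_remote_code=True)

full = model.encode_text(['a photo of a cat'])
small = model.encode_text(['a photo of a cat'], truncate_dim=64)
print(full[0].shape, small[0].shape)  # e.g. torch.Size([1024]) vs torch.Size([64])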
@@ -406,6 +438,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
@@ -431,6 +464,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate image embeddings to. `None` does no truncation.
         Returns:
             By default, a list of tensors is returned.
             If convert_to_tensor, a stacked tensor is returned.
@@ -476,7 +511,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         range_iter = range(0, len(images), batch_size)
 
         from PIL import Image
-
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             batch_images = images[i:i+batch_size]
             processed_inputs = []
@@ -501,6 +537,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             processed_inputs = processed_inputs.to(self.device)
             embeddings = self.get_image_features(processed_inputs)
 
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy: