| |
| |
| |
| |
| |
|
|
| import os |
| from copy import deepcopy |
| from typing import Any, Dict, Optional, Union |
|
|
| from transformers import PretrainedConfig, logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| """ Jina CLIP model configuration """ |
|
|
|
|
| class JinaCLIPTextConfig(PretrainedConfig): |
| model_type = 'jina_clip_text' |
|
|
| def __init__( |
| self, |
| embed_dim: int = 768, |
| hf_model_name_or_path: str = 'jinaai/jina-bert-v2-base-en-flash', |
| hf_model_config_kwargs: Optional[Dict[str, Any]] = None, |
| pooler_type: Optional[str] = None, |
| proj_type: Optional[str] = None, |
| proj_bias: bool = False, |
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
|
|
| self.embed_dim = embed_dim |
| self.hf_model_name_or_path = hf_model_name_or_path |
| self.hf_model_config_kwargs = hf_model_config_kwargs or {} |
| self.pooler_type = pooler_type |
| self.proj_type = proj_type |
| self.proj_bias = proj_bias |
|
|
| @classmethod |
| def from_pretrained( |
| cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs |
| ) -> 'PretrainedConfig': |
| cls._set_token_in_kwargs(kwargs) |
|
|
| configdict, kwargs = cls.get_config_dict( |
| pretrained_model_name_or_path, **kwargs |
| ) |
|
|
| |
| if configdict.get('model_type') == 'jina_clip': |
| configdict = configdict['text_config'] |
|
|
| if ( |
| 'model_type' in configdict |
| and hasattr(cls, 'model_type') |
| and configdict['model_type'] != cls.model_type |
| ): |
| logger.warning( |
| f'You are using a model of type {configdict["model_type"]} to ' |
| f'instantiate a model of type {cls.model_type}. This is not supported ' |
| 'for all configurations of models and can yield errors.' |
| ) |
|
|
| return cls.from_dict(configdict, **kwargs) |
|
|
|
|
| class JinaCLIPVisionConfig(PretrainedConfig): |
| model_type = 'jina_clip_vision' |
|
|
| def __init__( |
| self, |
| embed_dim: int = 768, |
| width: int = 768, |
| image_size: int = 224, |
| patch_size: int = 16, |
| layers: int = 12, |
| head_width: int = 64, |
| mlp_ratio: float = 4.0, |
| ls_init_value: Optional[float] = None, |
| patch_dropout: float = 0.0, |
| qkv_bias: bool = True, |
| fused_layer_norm: bool = False, |
| x_attention: bool = False, |
| post_norm: bool = False, |
| rope_embeddings: bool = False, |
| pt_hw_seq_len: int = 16, |
| intp_freq: bool = False, |
| naive_swiglu: bool = False, |
| subln: bool = False, |
| drop_path_rate: float = 0.0, |
| proj_type: Optional[str] = None, |
| **kwargs, |
| ): |
| super().__init__(**kwargs) |
|
|
| self.layers = layers |
| self.embed_dim = embed_dim |
| self.width = width |
| self.head_width = head_width |
| self.mlp_ratio = mlp_ratio |
| self.image_size = image_size |
| self.patch_size = patch_size |
| self.ls_init_value = ls_init_value |
| self.patch_dropout = patch_dropout |
| self.qkv_bias = qkv_bias |
| self.fused_layer_norm = fused_layer_norm |
| self.x_attention = x_attention |
| self.post_norm = post_norm |
| self.rope_embeddings = rope_embeddings |
| self.pt_hw_seq_len = pt_hw_seq_len |
| self.intp_freq = intp_freq |
| self.naive_swiglu = naive_swiglu |
| self.subln = subln |
| self.drop_path_rate = drop_path_rate |
| self.proj_type = proj_type |
|
|
| @classmethod |
| def from_pretrained( |
| cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs |
| ) -> 'PretrainedConfig': |
| cls._set_token_in_kwargs(kwargs) |
|
|
| configdict, kwargs = cls.get_config_dict( |
| pretrained_model_name_or_path, **kwargs |
| ) |
|
|
| |
| if configdict.get('model_type') == 'jina_clip': |
| configdict = configdict['vision_config'] |
|
|
| if ( |
| 'model_type' in configdict |
| and hasattr(cls, 'model_type') |
| and configdict['model_type'] != cls.model_type |
| ): |
| logger.warning( |
| f'You are using a model of type {configdict["model_type"]} to ' |
| f'instantiate a model of type {cls.model_type}. This is not supported ' |
| 'for all configurations of models and can yield errors.' |
| ) |
|
|
| return cls.from_dict(configdict, **kwargs) |
|
|
|
|
| class JinaCLIPConfig(PretrainedConfig): |
| model_type = 'jina_clip' |
| is_composition = True |
|
|
| def __init__( |
| self, |
| text_config: Optional[Dict] = None, |
| vision_config: Optional[Dict] = None, |
| add_projections: bool = False, |
| projection_dim: int = 768, |
| logit_scale_init_value: float = 2.6592, |
| **kwargs, |
| ): |
| |
| |
| |
|
|
| text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None) |
| vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None) |
|
|
| super().__init__(**kwargs) |
|
|
| if text_config_dict is not None: |
| if text_config is None: |
| text_config = {} |
|
|
| |
| _text_config_dict = JinaCLIPTextConfig(**text_config_dict).to_dict() |
|
|
| |
| |
| for key, value in _text_config_dict.items(): |
| if ( |
| key in text_config |
| and value != text_config[key] |
| and key not in ['transformers_version'] |
| ): |
| |
| if key in text_config_dict: |
| message = ( |
| f'`{key}` is found in both `text_config_dict` and ' |
| f'`text_config` but with different values. ' |
| f'The value `text_config_dict["{key}"]` will be used ' |
| f'instead.' |
| ) |
| |
| |
| else: |
| message = ( |
| f'`text_config_dict` is provided which will be used to ' |
| f'initialize `JinaCLIPTextConfig`. The ' |
| f'value `text_config["{key}"]` will be overriden.' |
| ) |
| logger.info(message) |
|
|
| |
| text_config.update(_text_config_dict) |
|
|
| if vision_config_dict is not None: |
| if vision_config is None: |
| vision_config = {} |
|
|
| |
| _vision_config_dict = JinaCLIPVisionConfig(**vision_config_dict).to_dict() |
| |
| if 'id2label' in _vision_config_dict: |
| _vision_config_dict['id2label'] = { |
| str(key): value |
| for key, value in _vision_config_dict['id2label'].items() |
| } |
|
|
| |
| |
| for key, value in _vision_config_dict.items(): |
| if ( |
| key in vision_config |
| and value != vision_config[key] |
| and key not in ['transformers_version'] |
| ): |
| |
| if key in vision_config_dict: |
| message = ( |
| f'`{key}` is found in both `vision_config_dict` and ' |
| f'`vision_config` but with different ' |
| f'values. The value `vision_config_dict["{key}"]` will ' |
| f'be used instead.' |
| ) |
| |
| |
| else: |
| message = ( |
| f'`vision_config_dict` is provided which will be used to ' |
| f'initialize `JinaCLIPVisionConfig`. ' |
| f'The value `vision_config["{key}"]` will be overriden.' |
| ) |
| logger.info(message) |
|
|
| |
| |
| vision_config.update(_vision_config_dict) |
|
|
| if text_config is None: |
| text_config = {} |
| logger.info( |
| '`text_config` is `None`. Initializing the `JinaCLIPTextConfig` with ' |
| 'default values.' |
| ) |
|
|
| if vision_config is None: |
| vision_config = {} |
| logger.info( |
| '`vision_config` is `None`. initializing the `JinaCLIPVisionConfig` ' |
| 'with default values.' |
| ) |
|
|
| self.text_config = JinaCLIPTextConfig(**text_config) |
| self.vision_config = JinaCLIPVisionConfig(**vision_config) |
|
|
| self.add_projections = add_projections |
| self.projection_dim = projection_dim |
| self.logit_scale_init_value = logit_scale_init_value |
| self.initializer_factor = 1.0 |
|
|
| if not self.add_projections: |
| if self.text_config.embed_dim != self.vision_config.embed_dim: |
| raise ValueError( |
| 'When projections are disabled (`add_projections=False`), text ' |
| 'and vision towers need to have the same embedding dimensionality. ' |
| f'Currently text embedding dim is {self.text_config.embed_dim} != ' |
| f'{self.vision_config.embed_dim} of the vision tower. ' |
| 'Either set the same output dim for both towers, or enable ' |
| 'projections with `add_projections=True`.' |
| ) |
|
|
| @classmethod |
| def from_text_vision_configs( |
| cls, |
| text_config: JinaCLIPTextConfig, |
| vision_config: JinaCLIPVisionConfig, |
| **kwargs, |
| ): |
| return cls( |
| text_config=text_config.to_dict(), |
| vision_config=vision_config.to_dict(), |
| projection_dim=text_config.projection_dim, |
| **kwargs, |
| ) |
|
|
| def to_dict(self): |
| output = deepcopy(self.__dict__) |
| output['text_config'] = self.text_config.to_dict() |
| output['vision_config'] = self.vision_config.to_dict() |
| output['model_type'] = self.__class__.model_type |
| return output |
|
|