| |
| |
| """ |
| ================================================ |
| @author: Jaron |
| @time: 2024/08/21 17:51:45 |
| @email: fjjth98@163.com |
| @description: |
| ================================================ |
| """ |
| from typing import Union |
|
|
| from transformers import PretrainedConfig |
| from transformers.models.auto import CONFIG_MAPPING |
|
|
|
|
class CCAMConfig(PretrainedConfig):
    """Hyper-parameter container for the CCAM projector.

    Each named keyword argument is copied unchanged onto the instance as an
    attribute of the same name; nothing is validated or derived here. Any
    remaining keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of each field is defined by the model code that
    reads this config, which is not part of this file. The groupings used in
    ``__init__`` are inferred from the parameter names only.
    """

    def __init__(
        self,
        num_query: int = 1024,
        num_heads: int = 16,
        hidden_size: int = 1024,
        intermediate_size: int = 4096,
        num_key_value_heads: int = 16,
        dropout: float = 0.1,
        mlp_bias: bool = True,
        hidden_act: str = 'swiglu',
        output_size: int = None,
        attention_bias: bool = True,
        layer_norm_eps: float = 1e-5,
        cross_hidden_size: int = None,
        attention_dropout: float = 0.1,
        _attn_implementation: str = 'flash_attention_2',
        **kwargs
    ):
        # The base class consumes the generic options first, so the explicit
        # assignments below are applied afterwards.
        super().__init__(**kwargs)

        # Sizes and counts (presumably queries, heads and layer widths).
        self.num_query = num_query
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_key_value_heads = num_key_value_heads
        self.output_size = output_size
        self.cross_hidden_size = cross_hidden_size

        # Feed-forward / normalisation options.
        self.dropout = dropout
        self.mlp_bias = mlp_bias
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        # Attention options.
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self._attn_implementation = _attn_implementation
|
|
|
|
class VideoCCAMConfig(PretrainedConfig):
    """Composite configuration grouping vision, text and projector configs.

    Args:
        vision_config: Either a config object or a ``dict``. A ``dict`` is
            turned into a config by looking up its ``'model_type'`` entry in
            ``CONFIG_MAPPING`` and calling the resulting class with the dict
            unpacked as keyword arguments. Any other value (including
            ``None``) is stored unchanged.
        text_config: Handled exactly like ``vision_config``.
        projector_config: A ``dict`` is unpacked into ``CCAMConfig``; any
            other value (including ``None``) is stored unchanged.
        image_token_id: Stored unchanged (presumably the id of the image
            placeholder token -- confirm against the model code).
        video_token_id: Stored unchanged (presumably the id of the video
            placeholder token -- confirm against the model code).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: If ``vision_config`` or ``text_config`` is a ``dict``
            without a ``'model_type'`` key. A ``'model_type'`` that
            ``CONFIG_MAPPING`` does not know is expected to fail the lookup
            as well.
    """

    model_type = 'videoccam'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        vision_config: Union[dict, PretrainedConfig] = None,
        text_config: Union[dict, PretrainedConfig] = None,
        projector_config: dict = None,
        image_token_id: int = None,
        video_token_id: int = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Sub-configs supplied as plain dicts (e.g. after JSON round-tripping)
        # are rebuilt into config objects; everything else passes through.
        self.vision_config = (
            CONFIG_MAPPING[vision_config['model_type']](**vision_config)
            if isinstance(vision_config, dict)
            else vision_config
        )
        self.text_config = (
            CONFIG_MAPPING[text_config['model_type']](**text_config)
            if isinstance(text_config, dict)
            else text_config
        )
        self.projector_config = (
            CCAMConfig(**projector_config)
            if isinstance(projector_config, dict)
            else projector_config
        )

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
|
|