support use_flash_attn in from_pretrained (#2)
Browse files

- support flash attn in from_pretrained (d7c984ce33a82aa27b9bb6cf4e6a0ef775577760)
- change use_flash_attn and add x_attention attribute (ab448a5fe4db0f489546be2a56a8fd0e64f73d5b)
- set use_flash_attn at different position (a3f5a6005182cd3d5a4be6a9695c09f3952cc0d5)
- remove imports used for testing (853dc7d429ec17e8c8b8a7778453062e4cbcff16)
Co-authored-by: Michael Günther <michael-guenther@users.noreply.huggingface.co>
- configuration_clip.py +4 -0
- modeling_clip.py +5 -0
configuration_clip.py
CHANGED
|
@@ -155,6 +155,8 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
| 155 |
add_projections: bool = False,
|
| 156 |
projection_dim: int = 768,
|
| 157 |
logit_scale_init_value: float = 2.6592,
|
|
|
|
|
|
|
| 158 |
**kwargs,
|
| 159 |
):
|
| 160 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
|
@@ -163,6 +165,8 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
| 163 |
|
| 164 |
text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
|
| 165 |
vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
|
|
|
|
|
|
|
| 166 |
|
| 167 |
super().__init__(**kwargs)
|
| 168 |
|
|
|
|
| 155 |
add_projections: bool = False,
|
| 156 |
projection_dim: int = 768,
|
| 157 |
logit_scale_init_value: float = 2.6592,
|
| 158 |
+
use_text_flash_attn: Optional[bool] = None,
|
| 159 |
+
use_vision_xformers: Optional[bool] = None,
|
| 160 |
**kwargs,
|
| 161 |
):
|
| 162 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
|
|
|
| 165 |
|
| 166 |
text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
|
| 167 |
vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
|
| 168 |
+
self.use_text_flash_attn = use_text_flash_attn
|
| 169 |
+
self.use_vision_xformers = use_vision_xformers
|
| 170 |
|
| 171 |
super().__init__(**kwargs)
|
| 172 |
|
modeling_clip.py
CHANGED
|
@@ -213,6 +213,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
| 213 |
text_config = config.text_config
|
| 214 |
vision_config = config.vision_config
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
self.add_projections = config.add_projections
|
| 217 |
self.projection_dim = config.projection_dim
|
| 218 |
self.text_embed_dim = text_config.embed_dim
|
|
|
|
| 213 |
text_config = config.text_config
|
| 214 |
vision_config = config.vision_config
|
| 215 |
|
| 216 |
+
if config.use_text_flash_attn is not None:
|
| 217 |
+
text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
|
| 218 |
+
if config.use_vision_xformers is not None:
|
| 219 |
+
vision_config.x_attention = config.use_vision_xformers
|
| 220 |
+
|
| 221 |
self.add_projections = config.add_projections
|
| 222 |
self.projection_dim = config.projection_dim
|
| 223 |
self.text_embed_dim = text_config.embed_dim
|