Commit · a3f5a60
1 Parent(s): ab448a5
set use_flash_attn at different position
Browse files
- configuration_clip.py +0 -5
- modeling_clip.py +8 -0
configuration_clip.py
CHANGED
|
@@ -263,11 +263,6 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
| 263 |
'with default values.'
|
| 264 |
)
|
| 265 |
|
| 266 |
-
if use_text_flash_attn:
|
| 267 |
-
text_config.hf_model_config_kwargs.use_flash_attn = use_text_flash_attn
|
| 268 |
-
if use_vision_xformers:
|
| 269 |
-
vision_config.x_attention = use_vision_xformers
|
| 270 |
-
|
| 271 |
self.text_config = JinaCLIPTextConfig(**text_config)
|
| 272 |
self.vision_config = JinaCLIPVisionConfig(**vision_config)
|
| 273 |
|
|
|
|
| 263 |
'with default values.'
|
| 264 |
)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
self.text_config = JinaCLIPTextConfig(**text_config)
|
| 267 |
self.vision_config = JinaCLIPVisionConfig(**vision_config)
|
| 268 |
|
modeling_clip.py
CHANGED
|
@@ -39,6 +39,9 @@ except ImportError:
|
|
| 39 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
| 40 |
from .eva_model import EVAVisionTransformer
|
| 41 |
from .hf_model import HFTextEncoder
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
logger = logging.get_logger(__name__)
|
| 44 |
|
|
@@ -210,6 +213,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
| 210 |
text_config = config.text_config
|
| 211 |
vision_config = config.vision_config
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
self.add_projections = config.add_projections
|
| 214 |
self.projection_dim = config.projection_dim
|
| 215 |
self.text_embed_dim = text_config.embed_dim
|
|
|
|
| 39 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
| 40 |
from .eva_model import EVAVisionTransformer
|
| 41 |
from .hf_model import HFTextEncoder
|
| 42 |
+
from .rope_embeddings import rx
|
| 43 |
+
from .transform import rt
|
| 44 |
+
from .processing_clip import rp
|
| 45 |
|
| 46 |
logger = logging.get_logger(__name__)
|
| 47 |
|
|
|
|
| 213 |
text_config = config.text_config
|
| 214 |
vision_config = config.vision_config
|
| 215 |
|
| 216 |
+
if config.use_text_flash_attn is not None:
|
| 217 |
+
text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
|
| 218 |
+
if config.use_vision_xformers is not None:
|
| 219 |
+
vision_config.x_attention = config.use_vision_xformers
|
| 220 |
+
|
| 221 |
self.add_projections = config.add_projections
|
| 222 |
self.projection_dim = config.projection_dim
|
| 223 |
self.text_embed_dim = text_config.embed_dim
|