fix clip_vision trouble
Browse files
src/configuration_medclip.py
CHANGED
|
@@ -2,7 +2,7 @@ import copy
|
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
from transformers.utils import logging
|
| 5 |
-
|
| 6 |
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
|
@@ -69,12 +69,12 @@ class MedCLIPConfig(PretrainedConfig):
|
|
| 69 |
text_model_type = text_config.pop("model_type")
|
| 70 |
vision_model_type = vision_config.pop("model_type")
|
| 71 |
|
| 72 |
-
from transformers import AutoConfig
|
| 73 |
-
|
| 74 |
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
| 75 |
|
| 76 |
if vision_model_type == "clip":
|
| 77 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
|
|
|
|
|
|
| 78 |
else:
|
| 79 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
| 80 |
|
|
|
|
| 2 |
|
| 3 |
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
from transformers.utils import logging
|
| 5 |
+
from transformers import AutoConfig, CLIPVisionConfig
|
| 6 |
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
|
|
|
| 69 |
text_model_type = text_config.pop("model_type")
|
| 70 |
vision_model_type = vision_config.pop("model_type")
|
| 71 |
|
|
|
|
|
|
|
| 72 |
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
| 73 |
|
| 74 |
if vision_model_type == "clip":
|
| 75 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
| 76 |
+
elif vision_model_type == "clip_vision_model":
|
| 77 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
| 78 |
else:
|
| 79 |
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
| 80 |
|