Update configuration.py
Browse files - configuration.py +4 -5
configuration.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
from transformers import CONFIG_MAPPING
|
| 3 |
from transformers import AutoConfig
|
|
|
|
| 4 |
|
| 5 |
IGNORE_INDEX = -100
|
| 6 |
IMAGE_TOKEN_INDEX = -200
|
|
@@ -79,9 +80,9 @@ class TinyLlavaConfig(PretrainedConfig):
|
|
| 79 |
def _load_text_config(self, text_config=None):
|
| 80 |
if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
|
| 81 |
self.text_config = CONFIG_MAPPING['llama']()
|
| 82 |
-
|
| 83 |
else:
|
| 84 |
-
self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
|
| 85 |
if text_config is not None:
|
| 86 |
self.text_config = self.text_config.from_dict(text_config)
|
| 87 |
|
|
@@ -111,6 +112,4 @@ class TinyLlavaConfig(PretrainedConfig):
|
|
| 111 |
|
| 112 |
self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
|
| 113 |
self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
|
| 114 |
-
self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
from transformers import CONFIG_MAPPING
|
| 3 |
from transformers import AutoConfig
|
| 4 |
+
from transformers.models.phi.configuration_phi import PhiConfig # Ligne ajoutée
|
| 5 |
|
| 6 |
IGNORE_INDEX = -100
|
| 7 |
IMAGE_TOKEN_INDEX = -200
|
|
|
|
| 80 |
def _load_text_config(self, text_config=None):
|
| 81 |
if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
|
| 82 |
self.text_config = CONFIG_MAPPING['llama']()
|
| 83 |
+
|
| 84 |
else:
|
| 85 |
+
self.text_config = PhiConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True) # Ligne modifiée
|
| 86 |
if text_config is not None:
|
| 87 |
self.text_config = self.text_config.from_dict(text_config)
|
| 88 |
|
|
|
|
| 112 |
|
| 113 |
self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
|
| 114 |
self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
|
| 115 |
+
self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
|
|
|
|
|
|