SafaaAI commited on
Commit
d6ebd84
·
verified ·
1 Parent(s): dda7c7d

Update configuration.py

Browse files
Files changed (1) hide show
  1. configuration.py +4 -5
configuration.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers import PretrainedConfig
2
  from transformers import CONFIG_MAPPING
3
  from transformers import AutoConfig
 
4
 
5
  IGNORE_INDEX = -100
6
  IMAGE_TOKEN_INDEX = -200
@@ -79,9 +80,9 @@ class TinyLlavaConfig(PretrainedConfig):
79
  def _load_text_config(self, text_config=None):
80
  if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
81
  self.text_config = CONFIG_MAPPING['llama']()
82
-
83
  else:
84
- self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
85
  if text_config is not None:
86
  self.text_config = self.text_config.from_dict(text_config)
87
 
@@ -111,6 +112,4 @@ class TinyLlavaConfig(PretrainedConfig):
111
 
112
  self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
113
  self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
114
- self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
115
-
116
-
 
1
  from transformers import PretrainedConfig
2
  from transformers import CONFIG_MAPPING
3
  from transformers import AutoConfig
4
+ from transformers.models.phi.configuration_phi import PhiConfig # Ligne ajoutée
5
 
6
  IGNORE_INDEX = -100
7
  IMAGE_TOKEN_INDEX = -200
 
80
  def _load_text_config(self, text_config=None):
81
  if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
82
  self.text_config = CONFIG_MAPPING['llama']()
83
+
84
  else:
85
+ self.text_config = PhiConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True) # Ligne modifiée
86
  if text_config is not None:
87
  self.text_config = self.text_config.from_dict(text_config)
88
 
 
112
 
113
  self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
114
  self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
115
+ self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)