Update modeling_tinyllava_phi.py
Browse files- modeling_tinyllava_phi.py +8 -10
modeling_tinyllava_phi.py
CHANGED
|
@@ -283,16 +283,14 @@ class TinyLlavaPreTrainedModel(PreTrainedModel):
|
|
| 283 |
|
| 284 |
|
| 285 |
class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
self.post_init()
|
| 295 |
-
|
| 296 |
|
| 297 |
def get_input_embeddings(self):
|
| 298 |
return self.language_model.get_input_embeddings()
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
|
| 286 |
+
def __init__(self, config: TinyLlavaConfig):
|
| 287 |
+
self._supports_sdpa = True # Ligne déplacée pour être exécutée en premier
|
| 288 |
+
super().__init__(config)
|
| 289 |
+
|
| 290 |
+
self.language_model = PhiForCausalLM(config.text_config)
|
| 291 |
+
self.vision_tower = VisionTower(config.vision_config, config.vision_model_name_or_path)
|
| 292 |
+
self.connector = Connector(config)
|
| 293 |
+
self.post_init()
|
|
|
|
|
|
|
| 294 |
|
| 295 |
def get_input_embeddings(self):
|
| 296 |
return self.language_model.get_input_embeddings()
|