Update modeling_n2_eye.py
Browse files- modeling_n2_eye.py +10 -2
modeling_n2_eye.py
CHANGED
|
@@ -5,8 +5,11 @@ from transformers import (
|
|
| 5 |
AutoModelForCausalLM,
|
| 6 |
CLIPVisionModel,
|
| 7 |
PreTrainedModel,
|
| 8 |
-
PretrainedConfig
|
|
|
|
|
|
|
| 9 |
)
|
|
|
|
| 10 |
from typing import Optional
|
| 11 |
|
| 12 |
|
|
@@ -209,4 +212,9 @@ class MultimodalLFM2Model(PreTrainedModel):
|
|
| 209 |
projection_state_dict = torch.load(projection_path, map_location="cpu")
|
| 210 |
model.vision_projection.load_state_dict(projection_state_dict)
|
| 211 |
|
| 212 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
AutoModelForCausalLM,
|
| 6 |
CLIPVisionModel,
|
| 7 |
PreTrainedModel,
|
| 8 |
+
PretrainedConfig,
|
| 9 |
+
AutoConfig,
|
| 10 |
+
AutoModel
|
| 11 |
)
|
| 12 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
| 13 |
from typing import Optional
|
| 14 |
|
| 15 |
|
|
|
|
| 212 |
projection_state_dict = torch.load(projection_path, map_location="cpu")
|
| 213 |
model.vision_projection.load_state_dict(projection_state_dict)
|
| 214 |
|
| 215 |
+
return model
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# Register the model with transformers
|
| 219 |
+
AutoConfig.register("multimodal_lfm2", MultimodalLFM2Config)
|
| 220 |
+
AutoModelForCausalLM.register(MultimodalLFM2Config, MultimodalLFM2Model)
|