HuaminChen committed on
Commit
75eae3f
·
verified ·
1 Parent(s): b4136e7

Fix loading example to match checkpoint format

Browse files
Files changed (1) hide show
  1. README.md +20 -10
README.md CHANGED
@@ -88,25 +88,35 @@ pip install torch transformers pillow safetensors
88
  import torch
89
  from huggingface_hub import hf_hub_download
90
 
91
- # Download checkpoint
92
  checkpoint_path = hf_hub_download(
93
  repo_id="llm-semantic-router/multi-modal-embed-small",
94
  filename="model.pt"
95
  )
 
 
 
 
96
 
97
- # Load model
 
98
  import sys
99
  sys.path.append("path/to/2DMSE-Multimodal-Embedder")
100
- from src.models import create_multimodal_model
101
-
102
- model = create_multimodal_model(
103
- text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
104
- image_encoder_name="google/siglip-base-patch16-512",
105
- audio_encoder_name="openai/whisper-tiny",
106
- output_dim=384,
 
 
 
 
 
107
  )
108
  state_dict = torch.load(checkpoint_path, map_location="cpu")
109
- model.load_state_dict(state_dict["model_state_dict"])
110
  model.eval()
111
  ```
112
 
 
88
  import torch
89
  from huggingface_hub import hf_hub_download
90
 
91
+ # Download checkpoint and config
92
  checkpoint_path = hf_hub_download(
93
  repo_id="llm-semantic-router/multi-modal-embed-small",
94
  filename="model.pt"
95
  )
96
+ config_path = hf_hub_download(
97
+ repo_id="llm-semantic-router/multi-modal-embed-small",
98
+ filename="config.json"
99
+ )
100
 
101
+ # Load model with matching architecture
102
+ import json
103
  import sys
104
  sys.path.append("path/to/2DMSE-Multimodal-Embedder")
105
+ from src.models import MultimodalEmbedder
106
+
107
+ with open(config_path) as f:
108
+ config = json.load(f)
109
+
110
+ model = MultimodalEmbedder(
111
+ text_encoder_name=config["text_encoder_name"],
112
+ image_encoder_name=config["image_encoder_name"],
113
+ audio_encoder_name=config["audio_encoder_name"],
114
+ output_dim=config["output_dim"],
115
+ fusion_type=config["fusion_type"],
116
+ num_fusion_layers=config["num_fusion_layers"],
117
  )
118
  state_dict = torch.load(checkpoint_path, map_location="cpu")
119
+ model.load_state_dict(state_dict)
120
  model.eval()
121
  ```
122