HuaminChen committed
Commit 82e6778 · verified · 1 Parent(s): e737256

Fix standalone model loading code - verified to work

Files changed (1): README.md (+37 -21)
README.md CHANGED
@@ -100,20 +100,18 @@ class MultiModalEmbedder(nn.Module):
 
     def __init__(self):
         super().__init__()
-        # Text encoder
+        # Text encoder (384d, no projection needed)
         self.text_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
         self.text_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-        self.text_proj = nn.Linear(384, 384)
 
-        # Image encoder
+        # Image encoder (768d -> 384d projection)
         self.image_processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-512")
         self.image_encoder = SiglipModel.from_pretrained("google/siglip-base-patch16-512").vision_model
         self.image_proj = nn.Linear(768, 384)
 
-        # Audio encoder
+        # Audio encoder (384d, no projection needed)
         self.audio_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
         self.audio_encoder = WhisperModel.from_pretrained("openai/whisper-tiny").encoder
-        self.audio_proj = nn.Linear(384, 384)
 
     def encode_text(self, texts):
         if isinstance(texts, str):
@@ -121,8 +119,7 @@ class MultiModalEmbedder(nn.Module):
         inputs = self.text_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
         inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
         outputs = self.text_encoder(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1)
-        embeddings = self.text_proj(embeddings)
+        embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
         return F.normalize(embeddings, p=2, dim=-1)
 
     def encode_image(self, images):
@@ -130,16 +127,17 @@ class MultiModalEmbedder(nn.Module):
         inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
         outputs = self.image_encoder(**inputs)
         embeddings = outputs.pooler_output
-        embeddings = self.image_proj(embeddings)
+        embeddings = self.image_proj(embeddings)  # 768 -> 384
         return F.normalize(embeddings, p=2, dim=-1)
 
     def encode_audio(self, waveform):
-        # waveform: [batch, samples] at 16kHz
-        inputs = self.audio_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
+        # waveform: numpy array or tensor at 16kHz
+        if isinstance(waveform, torch.Tensor):
+            waveform = waveform.squeeze().numpy()
+        inputs = self.audio_processor(waveform, sampling_rate=16000, return_tensors="pt")
         inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
         outputs = self.audio_encoder(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1)
-        embeddings = self.audio_proj(embeddings)
+        embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
         return F.normalize(embeddings, p=2, dim=-1)
 
     # Load model
@@ -150,15 +148,33 @@ checkpoint_path = hf_hub_download(
     repo_id="llm-semantic-router/multi-modal-embed-small",
     filename="model.pt"
 )
-state_dict = torch.load(checkpoint_path, map_location="cpu")
-
-# Map checkpoint keys to our model
-model.text_encoder.load_state_dict({k.replace("text_encoder.encoder.", ""): v for k, v in state_dict.items() if k.startswith("text_encoder.encoder.")})
-model.text_proj.load_state_dict({k.replace("text_encoder.projection.", ""): v for k, v in state_dict.items() if k.startswith("text_encoder.projection.")})
-model.image_encoder.load_state_dict({k.replace("image_encoder.vision_encoder.", ""): v for k, v in state_dict.items() if k.startswith("image_encoder.vision_encoder.")})
-model.image_proj.load_state_dict({k.replace("image_encoder.projection.", ""): v for k, v in state_dict.items() if k.startswith("image_encoder.projection.")})
-model.audio_encoder.load_state_dict({k.replace("audio_encoder.encoder.", ""): v for k, v in state_dict.items() if k.startswith("audio_encoder.encoder.")})
-model.audio_proj.load_state_dict({k.replace("audio_encoder.projection.", ""): v for k, v in state_dict.items() if k.startswith("audio_encoder.projection.")})
+state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+# Load text encoder weights
+model.text_encoder.load_state_dict({
+    k.replace("text_encoder.encoder.", ""): v
+    for k, v in state_dict.items()
+    if k.startswith("text_encoder.encoder.")
+})
+
+# Load image encoder and projection weights
+model.image_encoder.load_state_dict({
+    k.replace("image_encoder.vision_encoder.", ""): v
+    for k, v in state_dict.items()
+    if k.startswith("image_encoder.vision_encoder.")
+})
+model.image_proj.load_state_dict({
+    k.replace("image_encoder.projection.", ""): v
+    for k, v in state_dict.items()
+    if k.startswith("image_encoder.projection.")
+})
+
+# Load audio encoder weights
+model.audio_encoder.load_state_dict({
+    k.replace("audio_encoder.encoder.", ""): v
+    for k, v in state_dict.items()
+    if k.startswith("audio_encoder.encoder.")
+})
 
 model.eval()
 print("Model loaded successfully!")
 