Add validation for minimum embed_dim in init_resampler

#57
by Haziq-exe - opened
Files changed (1) hide show
  1. modeling_minicpmo.py +9 -1
modeling_minicpmo.py CHANGED
@@ -203,10 +203,18 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
203
  return model
204
 
205
  def init_resampler(self, embed_dim, vision_dim):
 
 
 
 
 
 
 
 
206
  return Resampler(
207
  num_queries=self.config.query_num,
208
  embed_dim=embed_dim,
209
- num_heads=embed_dim // 128,
210
  kv_dim=vision_dim,
211
  adaptive=True,
212
  )
 
203
  return model
204
 
205
  def init_resampler(self, embed_dim, vision_dim):
206
+ MIN_EMBED_DIM = 128
207
+ if embed_dim < MIN_EMBED_DIM:
208
+ raise ValueError(
209
+ f"Resampler requires embed_dim >= {MIN_EMBED_DIM} "
210
+ f"(needed for num_heads = embed_dim // 128 >= 1). "
211
+ f"Got embed_dim={embed_dim}."
212
+ )
213
+
214
  return Resampler(
215
  num_queries=self.config.query_num,
216
  embed_dim=embed_dim,
217
+ num_heads=embed_dim // MIN_EMBED_DIM,
218
  kv_dim=vision_dim,
219
  adaptive=True,
220
  )