Haziq-exe committed on
Commit 2bc0447 · verified · 1 Parent(s): 509805e

Add validation for minimum embed_dim in init_resampler

`Resampler` is initialized with the hard-coded expression `num_heads = embed_dim // 128`. When `config.hidden_size < 128`, integer division yields `num_heads = 0`, and initialization fails later with a cryptic error stating that `num_heads=0` while `embed_dim={embed_dimension value}`.
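
To make the failure mode concrete, here is a minimal sketch of the head-count arithmetic in plain Python. The helper name `heads_for` is hypothetical and not part of `modeling_minicpmo.py`; it only mirrors the `embed_dim // 128` expression quoted above.

```python
def heads_for(embed_dim: int, head_dim: int = 128) -> int:
    """Mirror the hard-coded expression `embed_dim // 128` (hypothetical helper)."""
    return embed_dim // head_dim


# Floor division rounds down, so any embed_dim below 128 yields zero heads.
for dim in (64, 127, 128, 256):
    print(dim, heads_for(dim))
```

A head count of zero is what the downstream attention setup later rejects, which is why the original error appears far from its cause.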

This change adds explicit validation with a clear error message stating the minimum `embed_dim`, so the failure surfaces where it originates and is easier to debug.

Affected use cases:
- Model compression/distillation research
- Creating tiny test models for CI/CD

Files changed (1)
  1. modeling_minicpmo.py +9 -1
modeling_minicpmo.py CHANGED
@@ -203,10 +203,18 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         return model

     def init_resampler(self, embed_dim, vision_dim):
+        MIN_EMBED_DIM = 128
+        if embed_dim < MIN_EMBED_DIM:
+            raise ValueError(
+                f"Resampler requires embed_dim >= {MIN_EMBED_DIM} "
+                f"(needed for num_heads = embed_dim // 128 >= 1). "
+                f"Got embed_dim={embed_dim}."
+            )
+
         return Resampler(
             num_queries=self.config.query_num,
             embed_dim=embed_dim,
-            num_heads=embed_dim // 128,
+            num_heads=embed_dim // MIN_EMBED_DIM,
             kv_dim=vision_dim,
             adaptive=True,
         )
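
The check added in this commit can be exercised in isolation. The following is a minimal sketch, not the model code itself: `check_embed_dim` is a hypothetical standalone function that copies the validation logic from the diff so that it runs without loading `MiniCPMO` or `Resampler`.

```python
MIN_EMBED_DIM = 128


def check_embed_dim(embed_dim: int) -> int:
    """Return the head count, or raise if embed_dim is too small (hypothetical helper)."""
    if embed_dim < MIN_EMBED_DIM:
        raise ValueError(
            f"Resampler requires embed_dim >= {MIN_EMBED_DIM} "
            f"(needed for num_heads = embed_dim // 128 >= 1). "
            f"Got embed_dim={embed_dim}."
        )
    return embed_dim // MIN_EMBED_DIM


# A tiny test-sized dimension now fails immediately with the descriptive message.
try:
    check_embed_dim(64)
except ValueError as err:
    print(err)
```

A tiny CI model would therefore need `embed_dim` of at least 128 to pass this check; smaller values are rejected up front rather than failing deeper in initialization.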