Add validation for minimum embed_dim in init_resampler
#57
by
Haziq-exe
- opened
- modeling_minicpmo.py +9 -1
modeling_minicpmo.py
CHANGED
|
@@ -203,10 +203,18 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
| 203 |
return model
|
| 204 |
|
| 205 |
def init_resampler(self, embed_dim, vision_dim):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
return Resampler(
|
| 207 |
num_queries=self.config.query_num,
|
| 208 |
embed_dim=embed_dim,
|
| 209 |
-
num_heads=embed_dim //
|
| 210 |
kv_dim=vision_dim,
|
| 211 |
adaptive=True,
|
| 212 |
)
|
|
|
|
| 203 |
return model
|
| 204 |
|
| 205 |
def init_resampler(self, embed_dim, vision_dim):
|
| 206 |
+
MIN_EMBED_DIM = 128
|
| 207 |
+
if embed_dim < MIN_EMBED_DIM:
|
| 208 |
+
raise ValueError(
|
| 209 |
+
f"Resampler requires embed_dim >= {MIN_EMBED_DIM} "
|
| 210 |
+
f"(needed for num_heads = embed_dim // 128 >= 1). "
|
| 211 |
+
f"Got embed_dim={embed_dim}."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
return Resampler(
|
| 215 |
num_queries=self.config.query_num,
|
| 216 |
embed_dim=embed_dim,
|
| 217 |
+
num_heads=embed_dim // MIN_EMBED_DIM,
|
| 218 |
kv_dim=vision_dim,
|
| 219 |
adaptive=True,
|
| 220 |
)
|