Fix num_heads=0 and WHISPER_ATTENTION_CLASSES import for tiny model
Browse files- modeling_minicpmo.py +6 -2
modeling_minicpmo.py
CHANGED
|
@@ -56,7 +56,11 @@ from transformers.cache_utils import StaticCache
|
|
| 56 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 57 |
from transformers.modeling_outputs import ModelOutput
|
| 58 |
from transformers.models.whisper.modeling_whisper import ACT2FN
|
| 59 |
- from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
| 60 |
from transformers.models.whisper.modeling_whisper import WhisperConfig
|
| 61 |
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
| 62 |
|
|
@@ -206,7 +210,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
| 206 |
return Resampler(
|
| 207 |
num_queries=self.config.query_num,
|
| 208 |
embed_dim=embed_dim,
|
| 209 |
- num_heads=embed_dim // 128,
|
| 210 |
kv_dim=vision_dim,
|
| 211 |
adaptive=True,
|
| 212 |
)
|
|
|
|
| 56 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 57 |
from transformers.modeling_outputs import ModelOutput
|
| 58 |
from transformers.models.whisper.modeling_whisper import ACT2FN
|
| 59 |
+ try:
|
| 60 |
+     from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
|
| 61 |
+ except ImportError:
|
| 62 |
+     from transformers.models.whisper.modeling_whisper import WhisperAttention
|
| 63 |
+     WHISPER_ATTENTION_CLASSES = {"sdpa": WhisperAttention, "eager": WhisperAttention, "flash_attention_2": WhisperAttention}
|
| 64 |
from transformers.models.whisper.modeling_whisper import WhisperConfig
|
| 65 |
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
| 66 |
|
|
|
|
| 210 |
return Resampler(
|
| 211 |
num_queries=self.config.query_num,
|
| 212 |
embed_dim=embed_dim,
|
| 213 |
+ num_heads=max(1, embed_dim // 128),
|
| 214 |
kv_dim=vision_dim,
|
| 215 |
adaptive=True,
|
| 216 |
)
|