Nsan05 commited on
Commit
49eccdf
·
verified ·
1 Parent(s): d8c8924

Fix num_heads=0 error and missing WHISPER_ATTENTION_CLASSES import for the tiny model

Browse files
Files changed (1) hide show
  1. modeling_minicpmo.py +6 -2
modeling_minicpmo.py CHANGED
@@ -56,7 +56,11 @@ from transformers.cache_utils import StaticCache
56
  from transformers.modeling_outputs import BaseModelOutputWithPast
57
  from transformers.modeling_outputs import ModelOutput
58
  from transformers.models.whisper.modeling_whisper import ACT2FN
59
- from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
 
 
 
 
60
  from transformers.models.whisper.modeling_whisper import WhisperConfig
61
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
62
 
@@ -206,7 +210,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
206
  return Resampler(
207
  num_queries=self.config.query_num,
208
  embed_dim=embed_dim,
209
- num_heads=embed_dim // 128,
210
  kv_dim=vision_dim,
211
  adaptive=True,
212
  )
 
56
  from transformers.modeling_outputs import BaseModelOutputWithPast
57
  from transformers.modeling_outputs import ModelOutput
58
  from transformers.models.whisper.modeling_whisper import ACT2FN
59
+ try:
60
+ from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
61
+ except ImportError:
62
+ from transformers.models.whisper.modeling_whisper import WhisperAttention
63
+ WHISPER_ATTENTION_CLASSES = {"sdpa": WhisperAttention, "eager": WhisperAttention, "flash_attention_2": WhisperAttention}
64
  from transformers.models.whisper.modeling_whisper import WhisperConfig
65
  from transformers.models.whisper.modeling_whisper import WhisperEncoder
66
 
 
210
  return Resampler(
211
  num_queries=self.config.query_num,
212
  embed_dim=embed_dim,
213
+ num_heads=max(1, embed_dim // 128),
214
  kv_dim=vision_dim,
215
  adaptive=True,
216
  )