bluestarburst commited on
Commit
2057037
·
1 Parent(s): f248e7b

Upload folder using huggingface_hub

Browse files
animatediff/models/motion_module.py CHANGED
@@ -290,7 +290,7 @@ class VersatileAttention(CrossAttention):
290
 
291
  query = self.to_q(hidden_states)
292
  dim = query.shape[-1]
293
- query = self.reshape_heads_to_batch_dim(query)
294
 
295
  if self.added_kv_proj_dim is not None:
296
  raise NotImplementedError
@@ -299,8 +299,8 @@ class VersatileAttention(CrossAttention):
299
  key = self.to_k(encoder_hidden_states)
300
  value = self.to_v(encoder_hidden_states)
301
 
302
- key = self.reshape_heads_to_batch_dim(key)
303
- value = self.reshape_heads_to_batch_dim(value)
304
 
305
  if attention_mask is not None:
306
  if attention_mask.shape[-1] != query.shape[1]:
 
290
 
291
  query = self.to_q(hidden_states)
292
  dim = query.shape[-1]
293
+ query = self.head_to_batch_dim(query)
294
 
295
  if self.added_kv_proj_dim is not None:
296
  raise NotImplementedError
 
299
  key = self.to_k(encoder_hidden_states)
300
  value = self.to_v(encoder_hidden_states)
301
 
302
+ key = self.head_to_batch_dim(key)
303
+ value = self.head_to_batch_dim(value)
304
 
305
  if attention_mask is not None:
306
  if attention_mask.shape[-1] != query.shape[1]: