BucketOfFish committed on
Commit 6be42f2 · 1 Parent(s): 7da2fc9

Removed non-flash classes

Files changed (2)
  1. configuration_phi.py +0 -4
  2. modeling_phi.py +64 -339
configuration_phi.py CHANGED
@@ -29,8 +29,6 @@ class PhiConfig(PretrainedConfig):
         n_head_kv: Optional[int] = None,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
-        flash_attn: bool = False,
-        flash_rotary: bool = False,
         fused_dense: bool = False,
         attn_pdrop: float = 0.0,
         embd_pdrop: float = 0.0,
@@ -50,8 +48,6 @@ class PhiConfig(PretrainedConfig):
         self.n_head_kv = n_head_kv
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
-        self.flash_attn = flash_attn
-        self.flash_rotary = flash_rotary
         self.fused_dense = fused_dense
         self.attn_pdrop = attn_pdrop
         self.embd_pdrop = embd_pdrop
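For reference, a minimal sketch of constructing the configuration after this change: the `flash_attn` and `flash_rotary` flags are simply no longer accepted, while the remaining keyword arguments shown in the hunks above are unchanged. The concrete values below are illustrative, not taken from the commit.

```python
from configuration_phi import PhiConfig

# Illustrative values only; this commit changes no defaults besides dropping
# the two flash-related flags.
config = PhiConfig(
    n_embd=2048,          # hidden size (example value)
    n_head=32,            # attention heads (example value)
    n_head_kv=None,
    rotary_dim=32,
    activation_function="gelu_new",
    fused_dense=False,
    attn_pdrop=0.0,
    embd_pdrop=0.0,
    # flash_attn=...   -> removed by this commit; flash attention is now always used
    # flash_rotary=... -> removed by this commit
)
```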
modeling_phi.py CHANGED
@@ -19,16 +19,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from .configuration_phi import PhiConfig
 
-try:
-    from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
-    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
-    from flash_attn.ops.fused_dense import FusedDense
-except:
-    pad_input, unpad_input = None, None
-    FlashRotaryEmbedding = None
-    FlashSelfAttention, FlashCrossAttention = None, None
-    FusedDense = None
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
+from flash_attn.ops.fused_dense import FusedDense
 
 
 @dataclass
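With the try/except fallback gone, `flash_attn` (including its `fused_dense` op) becomes an unconditional, import-time dependency of modeling_phi.py. A minimal sketch of checking for it up front in downstream code, assuming nothing beyond what the hunk above imports; the check itself is not part of this commit:

```python
import importlib.util

# modeling_phi.py now imports flash_attn unconditionally, so fail early with a
# clear message instead of an ImportError deep inside model loading.
if importlib.util.find_spec("flash_attn") is None:
    raise ImportError(
        "`flash_attn` is required: the non-flash attention and rotary-embedding "
        "fallbacks were removed from modeling_phi.py."
    )
```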
@@ -168,128 +162,6 @@ def _apply_rotary_emb_qkv(
     )
 
 
-class RotaryEmbedding(nn.Module):
-    """Rotary positional embedding (RoPE).
-
-    Reference:
-        RoFormer: Enhanced Transformer with Rotary Position Embedding.
-        https://arxiv.org/pdf/2104.09864.pdf.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        base: int = 10000,
-        scale_base: Optional[float] = None,
-        pos_idx_in_fp32: bool = True,
-        max_position_embeddings: int = 2048,
-        device: Optional[str] = None,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-
-        if scale_base is not None:
-            raise NotImplementedError
-
-        self.dim = dim
-        self.base = float(base)
-        self.scale_base = scale_base
-        self.pos_idx_in_fp32 = pos_idx_in_fp32
-        self.max_position_embeddings = max_position_embeddings
-        self.device = device
-
-        # Generate and save the inverse frequency buffer (non-trainable)
-        inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Generate and save the scale buffer (non-trainable)
-        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-            if scale_base is not None
-            else None
-        )
-        self.register_buffer("scale", scale, persistent=False)
-
-        # Initialize cached attributes since ONNX can't rely on dynamic initialization
-        self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
-
-    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
-        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-    def _update_cos_sin_cache(
-        self,
-        seqlen: int,
-        device: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-    ) -> None:
-        self._seq_len_cached = seqlen
-
-        # fp32 is preferred since the output of `torch.arange` can be quite large
-        # and bf16 would lose a lot of precision
-        if self.pos_idx_in_fp32:
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
-            if self.inv_freq.dtype != torch.float32:
-                inv_freq = self._compute_inv_freq(device=device)
-            else:
-                inv_freq = self.inv_freq
-        else:
-            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-            inv_freq = self.inv_freq
-
-        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-        freqs = torch.outer(t, inv_freq)
-        if self.scale is None:
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-        else:
-            power = (
-                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-            ) / self.scale_base
-            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-
-            # Force the scale multiplication to happen in fp32
-            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        kv: Optional[torch.Tensor] = None,
-        seqlen_offset: int = 0,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if (
-            self._seq_len_cached < qkv.shape[1] + seqlen_offset
-            or self._cos_cached.device != qkv.device
-            or self._cos_cached.dtype != qkv.dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-
-        if kv is None:
-            return _apply_rotary_emb_qkv(
-                qkv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-        else:
-            q = _apply_rotary_emb(
-                qkv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-            kv = _apply_rotary_emb_kv(
-                kv,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-            )
-
-            return q, kv
-
-
 class MLP(nn.Module):
     """Multi-Layer Perceptron.
 
@@ -324,139 +196,6 @@ class MLP(nn.Module):
         return hidden_states
 
 
-class SelfAttention(nn.Module):
-    """Self-attention layer (compatible with PyTorch).
-
-    Reference:
-        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-
-    """
-
-    def __init__(
-        self,
-        causal: bool = True,
-        softmax_scale: Optional[float] = None,
-        attention_dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-
-        self.causal = causal
-        self.softmax_scale = softmax_scale
-        self.drop = nn.Dropout(attention_dropout)
-
-    @torch.autocast("cpu", enabled=False)
-    @torch.autocast("cuda", enabled=False)
-    def forward(
-        self,
-        qkv: torch.FloatTensor,
-        causal: bool = None,
-        key_padding_mask: Optional[torch.BoolTensor] = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-        q, k, v = qkv.unbind(dim=2)
-
-        q = q.to(torch.float32)
-        k = k.to(torch.float32)
-
-        causal = self.causal if causal is None else causal
-        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-
-        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-        # using float16, which might lead to overflow
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-
-        if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
-
-        if causal:
-            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
-
-        attention = torch.softmax(scores, dim=-1).to(v.dtype)
-        attention = self.drop(attention)
-
-        output = torch.einsum("bhts,bshd->bthd", attention, v)
-
-        return output
-
-
-class CrossAttention(nn.Module):
-    """Cross-attention layer (compatible with PyTorch).
-
-    Reference:
-        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-
-    """
-
-    def __init__(
-        self,
-        causal: bool = True,
-        softmax_scale: Optional[float] = None,
-        attention_dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-
-        self.causal = causal
-        self.softmax_scale = softmax_scale
-        self.drop = nn.Dropout(attention_dropout)
-
-    @torch.autocast("cpu", enabled=False)
-    @torch.autocast("cuda", enabled=False)
-    def forward(
-        self,
-        q: torch.FloatTensor,
-        kv: torch.FloatTensor,
-        causal: bool = None,
-        key_padding_mask: Optional[torch.BoolTensor] = None,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
-        seqlen_k = kv.shape[1]
-
-        if kv.shape[3] != q.shape[2]:
-            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
-        k, v = kv.unbind(dim=2)
-
-        q = q.to(torch.float32)
-        k = k.to(torch.float32)
-
-        causal = self.causal if causal is None else causal
-        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-
-        # Autocast is manually disabled to avoid `torch.einsum` performing the operation
-        # using float16, which might lead to overflow
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-
-        if key_padding_mask is not None:
-            padding_mask = torch.full(
-                (batch_size, seqlen_k),
-                -10000.0,
-                dtype=scores.dtype,
-                device=scores.device,
-            )
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
-
-        if causal:
-            rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
-            cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
-            causal_mask = cols > rows + seqlen_k - seqlen_q
-
-            scores = scores.masked_fill(causal_mask, -10000.0)
-
-        attention = torch.softmax(scores, dim=-1).to(v.dtype)
-        attention = self.drop(attention)
-
-        output = torch.einsum("bhts,bshd->bthd", attention, v)
-
-        return output
-
-
 def _find_mha_dims(
     config: PretrainedConfig,
     n_head: Optional[int] = None,
@@ -532,14 +271,8 @@ class MHA(nn.Module):
         # Rotary embedding
         self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
         if self.rotary_dim > 0:
-            rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
-            if rotary_cls is None:
-                rotary_cls = RotaryEmbedding
-
+            rotary_cls = FlashRotaryEmbedding
             rotary_kwargs = {}
-            if rotary_cls is RotaryEmbedding:
-                rotary_kwargs["max_position_embeddings"] = config.n_positions
-
             self.rotary_emb = rotary_cls(
                 self.rotary_dim,
                 base=rotary_base,
@@ -563,13 +296,8 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
-        attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
-        if attn_cls is None:
-            attn_cls = SelfAttention
-
-        cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
-        if cross_attn_cls is None:
-            cross_attn_cls = CrossAttention
+        attn_cls = FlashSelfAttention
+        cross_attn_cls = FlashCrossAttention
 
         self.inner_attn = attn_cls(
             causal=causal,
@@ -582,7 +310,6 @@ class MHA(nn.Module):
             attention_dropout=config.attn_pdrop,
         )
 
-        self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
         self.layer_idx = layer_idx
         self.return_residual = return_residual
         self.checkpointing = checkpointing
@@ -596,24 +323,23 @@ class MHA(nn.Module):
         if self.rotary_dim > 0:
             qkv = self.rotary_emb(qkv)
 
-        if self.flash_attn:
-            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-
-            cu_seqlens, max_seqlen = None, None
-            if key_padding_mask is not None:
-                # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
-                # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
-                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
-                )
-            else:
-                attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
-
-            # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
-            return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+
+        cu_seqlens, max_seqlen = None, None
+        if key_padding_mask is not None:
+            # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
+            # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
+            qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
+
+        if self.checkpointing:
+            attn_output = torch.utils.checkpoint.checkpoint(
+                self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+            )
+        else:
+            attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
+
+        # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
+        return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
 
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
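The dedented branch above drives flash-attn's packed-QKV interface directly. A standalone sketch of the same call pattern, with illustrative shapes and dtype (qkv packed as (batch, seqlen, 3, n_head, head_dim) on a CUDA device, as `MHA` packs it above):

```python
import torch
from flash_attn.modules.mha import FlashSelfAttention

# Packed QKV as used by the self-attention path above: fp16/bf16 on CUDA.
qkv = torch.randn(2, 128, 3, 32, 64, dtype=torch.float16, device="cuda")

inner_attn = FlashSelfAttention(causal=True, softmax_scale=None, attention_dropout=0.0)

# Without a key_padding_mask, cu_seqlens/max_seqlen stay None, exactly as in
# the padded branch of the new code.
out = inner_attn(qkv, cu_seqlens=None, max_seqlen=None)
print(out.shape)  # (2, 128, 32, 64)
```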
@@ -644,54 +370,53 @@ class MHA(nn.Module):
         if past_key_values is not None:
             kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
 
-        if self.flash_attn:
-            batch_size, seqlen_q = q.shape[0], q.shape[1]
-            seqlen_k = kv.shape[1]
-
-            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
-                None,
-                None,
-                None,
-                None,
-            )
-            if key_padding_mask is not None:
-                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
-
-                if seqlen_q == 1:
-                    key_padding_mask = torch.ones(batch_size, 1, device=q.device)
-                elif seqlen_q != seqlen_k:
-                    key_padding_mask = key_padding_mask[:, -seqlen_q:]
-
-                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
-
-            if self.checkpointing:
-                attn_output = torch.utils.checkpoint.checkpoint(
-                    self.inner_cross_attn,
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-            else:
-                attn_output = self.inner_cross_attn(
-                    q,
-                    kv,
-                    causal=causal,
-                    cu_seqlens=cu_seqlens_q,
-                    max_seqlen=max_seqlen_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_k=max_seqlen_k,
-                )
-
-            return (
-                pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
-                if key_padding_mask is not None
-                else attn_output
-            )
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = kv.shape[1]
+
+        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
+            None,
+            None,
+            None,
+            None,
+        )
+        if key_padding_mask is not None:
+            kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
+
+            if seqlen_q == 1:
+                key_padding_mask = torch.ones(batch_size, 1, device=q.device)
+            elif seqlen_q != seqlen_k:
+                key_padding_mask = key_padding_mask[:, -seqlen_q:]
+
+            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
+
+        if self.checkpointing:
+            attn_output = torch.utils.checkpoint.checkpoint(
+                self.inner_cross_attn,
+                q,
+                kv,
+                causal=causal,
+                cu_seqlens=cu_seqlens_q,
+                max_seqlen=max_seqlen_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_k=max_seqlen_k,
+            )
+        else:
+            attn_output = self.inner_cross_attn(
+                q,
+                kv,
+                causal=causal,
+                cu_seqlens=cu_seqlens_q,
+                max_seqlen=max_seqlen_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_k=max_seqlen_k,
+            )
+
+        return (
+            pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
+            if key_padding_mask is not None
+            else attn_output
+        )
 
         if self.checkpointing:
             return torch.utils.checkpoint.checkpoint(
                 self.inner_cross_attn,
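Taken together, both files now assume flash-attn is installed and usable on the target device. A minimal sketch of loading a checkpoint that ships this remote code; the repo id, dtype, and generation settings below are placeholders, not taken from the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/phi-flash"  # hypothetical repo id carrying these files

# trust_remote_code makes transformers execute configuration_phi.py and
# modeling_phi.py from the repo; after this commit that import fails unless
# flash-attn is installed, and its kernels require a CUDA GPU with fp16/bf16.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```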