vectominist commited on
Commit
d42aa85
·
verified ·
1 Parent(s): 1a17124

Update usad_modules.py

Browse files
Files changed (1) hide show
  1. usad_modules.py +461 -130
usad_modules.py CHANGED
@@ -1,25 +1,38 @@
1
- # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
 
15
  import contextlib
16
  import math
17
  from collections import defaultdict
18
- from typing import Dict, List, Optional, Tuple, Union
19
 
20
  import torch
21
  import torch.nn.functional as F
22
  from torch import nn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class SamePad(nn.Module):
@@ -66,6 +79,20 @@ class GLU(nn.Module):
66
  return outputs * gate.sigmoid()
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class ResidualConnectionModule(nn.Module):
70
  def __init__(
71
  self,
@@ -79,11 +106,15 @@ class ResidualConnectionModule(nn.Module):
79
  self.input_factor = input_factor
80
 
81
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
82
- return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
 
 
83
 
84
 
85
  class Linear(nn.Module):
86
- def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
 
 
87
  super(Linear, self).__init__()
88
  self.linear = nn.Linear(in_features, out_features, bias=bias)
89
  nn.init.xavier_uniform_(self.linear.weight)
@@ -122,10 +153,15 @@ class FeedForwardModule(nn.Module):
122
  encoder_dim: int = 512,
123
  expansion_factor: int = 4,
124
  dropout_p: float = 0.1,
 
125
  ) -> None:
126
  super(FeedForwardModule, self).__init__()
127
  self.sequential = nn.Sequential(
128
- nn.LayerNorm(encoder_dim),
 
 
 
 
129
  Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
130
  Swish(),
131
  nn.Dropout(p=dropout_p),
@@ -195,15 +231,22 @@ class ConformerConvModule(nn.Module):
195
  kernel_size: int = 31,
196
  expansion_factor: int = 2,
197
  dropout_p: float = 0.1,
 
198
  ) -> None:
199
  super(ConformerConvModule, self).__init__()
200
  assert (
201
  kernel_size - 1
202
  ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
203
- assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
 
 
204
 
205
  self.sequential = nn.Sequential(
206
- nn.LayerNorm(in_channels),
 
 
 
 
207
  Transpose(shape=(1, 2)),
208
  PointwiseConv1d(
209
  in_channels,
@@ -222,7 +265,9 @@ class ConformerConvModule(nn.Module):
222
  ),
223
  nn.BatchNorm1d(in_channels),
224
  Swish(),
225
- PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
 
 
226
  nn.Dropout(p=dropout_p),
227
  )
228
 
@@ -249,13 +294,19 @@ class FramewiseConv2dSubampling(nn.Module):
249
  )
250
 
251
  def forward(
252
- self, inputs: torch.Tensor, input_lengths: torch.LongTensor
253
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
254
  # inputs: (B, T, C) -> (B, 1, T, C)
255
  if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
256
  inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
 
 
 
 
257
  outputs = self.cnn(inputs.unsqueeze(1))
258
- batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size()
 
 
259
 
260
  outputs = outputs.permute(0, 2, 1, 3)
261
  outputs = outputs.contiguous().view(
@@ -263,12 +314,21 @@ class FramewiseConv2dSubampling(nn.Module):
263
  )
264
 
265
  if self.subsample_rate == 4:
266
- output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
267
  else:
268
  output_lengths = input_lengths >> 1
269
 
270
  return outputs, output_lengths
271
 
 
 
 
 
 
 
 
 
 
272
 
273
  class PatchwiseConv2dSubampling(nn.Module):
274
  def __init__(
@@ -292,9 +352,13 @@ class PatchwiseConv2dSubampling(nn.Module):
292
  padding=0,
293
  )
294
  self.cnn = nn.Sequential(
295
- nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
 
 
296
  nn.ReLU(),
297
- nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
 
 
298
  nn.ReLU(),
299
  )
300
 
@@ -303,8 +367,8 @@ class PatchwiseConv2dSubampling(nn.Module):
303
  return self.patch_size_time * self.patch_size_freq // self.mel_dim
304
 
305
  def forward(
306
- self, inputs: torch.Tensor, input_lengths: torch.LongTensor
307
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
308
  assert (
309
  inputs.shape[2] == self.mel_dim
310
  ), "inputs.shape[2] should be equal to mel_dim"
@@ -326,11 +390,10 @@ class PatchwiseConv2dSubampling(nn.Module):
326
 
327
 
328
  class RelPositionalEncoding(nn.Module):
329
- def __init__(self, d_model: int, max_len: int = 10000) -> None:
330
  super(RelPositionalEncoding, self).__init__()
331
  self.d_model = d_model
332
  self.pe = None
333
- self.extend_pe(torch.tensor(0.0).expand(1, max_len))
334
 
335
  def extend_pe(self, x: torch.Tensor) -> None:
336
  if self.pe is not None:
@@ -339,11 +402,14 @@ class RelPositionalEncoding(nn.Module):
339
  self.pe = self.pe.to(dtype=x.dtype, device=x.device)
340
  return
341
 
342
- pe_positive = torch.zeros(x.size(1), self.d_model)
343
- pe_negative = torch.zeros(x.size(1), self.d_model)
344
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
 
 
 
345
  div_term = torch.exp(
346
- torch.arange(0, self.d_model, 2, dtype=torch.float32)
347
  * -(math.log(10000.0) / self.d_model)
348
  )
349
  pe_positive[:, 0::2] = torch.sin(position * div_term)
@@ -359,9 +425,13 @@ class RelPositionalEncoding(nn.Module):
359
  def forward(self, x: torch.Tensor) -> torch.Tensor:
360
  # x: (B, T, C)
361
  self.extend_pe(x)
 
362
  pos_emb = self.pe[
363
  :,
364
- self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
 
 
 
365
  ]
366
  return pos_emb
367
 
@@ -393,90 +463,171 @@ class RelativeMultiHeadAttention(nn.Module):
393
 
394
  self.out_proj = Linear(d_model, d_model)
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  def forward(
397
  self,
398
  query: torch.Tensor,
399
  key: torch.Tensor,
400
  value: torch.Tensor,
401
  pos_embedding: torch.Tensor,
402
- mask: Optional[torch.Tensor] = None,
403
- ) -> Tuple[torch.Tensor, torch.Tensor]:
404
- batch_size = value.size(0)
405
-
406
- query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
407
- key = (
408
- self.key_proj(key)
409
- .view(batch_size, -1, self.num_heads, self.d_head)
410
- .permute(0, 2, 1, 3)
411
- )
412
- value = (
413
- self.value_proj(value)
414
- .view(batch_size, -1, self.num_heads, self.d_head)
415
- .permute(0, 2, 1, 3)
416
- )
417
- pos_embedding = self.pos_proj(pos_embedding).view(
418
- batch_size, -1, self.num_heads, self.d_head
419
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  content_score = torch.matmul(
422
- (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
423
- )
424
- pos_score = torch.matmul(
425
- (query + self.v_bias).transpose(1, 2),
426
- pos_embedding.permute(0, 2, 3, 1),
427
- )
428
- pos_score = self._relative_shift(pos_score)
429
-
430
- score = (content_score + pos_score) / self.sqrt_dim
431
 
432
- if mask is not None:
433
- mask = mask.unsqueeze(1)
434
- score.masked_fill_(mask, -1e9)
435
 
436
- attn = F.softmax(score, -1)
437
  attn = self.dropout(attn)
438
 
439
- context = torch.matmul(attn, value).transpose(1, 2)
440
- context = context.contiguous().view(batch_size, -1, self.d_model)
441
-
442
- return self.out_proj(context), attn
443
-
444
- def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
445
- batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
446
- zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
447
- padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
448
-
449
- padded_pos_score = padded_pos_score.view(
450
- batch_size, num_heads, seq_length2 + 1, seq_length1
451
  )
452
- pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
453
- :, :, :, : seq_length2 // 2 + 1
454
- ]
455
 
456
- return pos_score
457
 
458
 
459
  class MultiHeadedSelfAttentionModule(nn.Module):
460
- def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
 
 
 
 
 
 
461
  super(MultiHeadedSelfAttentionModule, self).__init__()
462
  self.positional_encoding = RelPositionalEncoding(d_model)
463
- self.layer_norm = nn.LayerNorm(d_model)
464
- self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
 
 
 
 
465
  self.dropout = nn.Dropout(p=dropout_p)
466
 
467
  def forward(
468
- self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
469
- ) -> Tuple[torch.Tensor, torch.Tensor]:
470
- batch_size = inputs.size(0)
471
- pos_embedding = self.positional_encoding(inputs)
472
- pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
 
 
473
 
474
  inputs = self.layer_norm(inputs)
475
  outputs, attn = self.attention(
476
- inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
 
 
 
 
477
  )
478
 
479
- return self.dropout(outputs), attn
480
 
481
 
482
  class ConformerBlock(nn.Module):
@@ -485,10 +636,6 @@ class ConformerBlock(nn.Module):
485
  encoder_dim: int = 512,
486
  attention_type: str = "mhsa",
487
  num_attention_heads: int = 8,
488
- mamba_d_state: int = 16,
489
- mamba_d_conv: int = 4,
490
- mamba_expand: int = 2,
491
- mamba_bidirectional: bool = True,
492
  feed_forward_expansion_factor: int = 4,
493
  conv_expansion_factor: int = 2,
494
  feed_forward_dropout_p: float = 0.1,
@@ -497,29 +644,37 @@ class ConformerBlock(nn.Module):
497
  conv_kernel_size: int = 31,
498
  half_step_residual: bool = True,
499
  transformer_style: bool = False,
 
 
 
500
  ):
501
  super(ConformerBlock, self).__init__()
502
 
503
  self.transformer_style = transformer_style
504
  self.attention_type = attention_type
 
 
505
 
506
  if half_step_residual and not transformer_style:
507
  self.feed_forward_residual_factor = 0.5
508
  else:
509
  self.feed_forward_residual_factor = 1
510
 
511
- assert attention_type in ["mhsa", "mamba"]
512
- if attention_type == "mhsa":
513
- attention = MultiHeadedSelfAttentionModule(
514
- d_model=encoder_dim,
515
- num_heads=num_attention_heads,
516
- dropout_p=attention_dropout_p,
517
- )
 
 
518
 
519
  self.ffn_1 = FeedForwardModule(
520
  encoder_dim=encoder_dim,
521
  expansion_factor=feed_forward_expansion_factor,
522
  dropout_p=feed_forward_dropout_p,
 
523
  )
524
  self.attention = attention
525
  if not transformer_style:
@@ -528,28 +683,49 @@ class ConformerBlock(nn.Module):
528
  kernel_size=conv_kernel_size,
529
  expansion_factor=conv_expansion_factor,
530
  dropout_p=conv_dropout_p,
 
531
  )
532
  self.ffn_2 = FeedForwardModule(
533
  encoder_dim=encoder_dim,
534
  expansion_factor=feed_forward_expansion_factor,
535
  dropout_p=feed_forward_dropout_p,
 
536
  )
537
- self.layernorm = nn.LayerNorm(encoder_dim)
 
 
 
 
 
 
 
 
538
 
539
- def forward(
540
- self, x: torch.Tensor
541
- ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  # FFN 1
543
  ffn_1_out = self.ffn_1(x)
544
  x = ffn_1_out * self.feed_forward_residual_factor + x
545
 
546
  # Attention
547
- if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
548
- # MAMBA
549
- attn_out = self.attention(x)
550
- attn = None
551
- else:
552
- attn_out, attn = self.attention(x)
553
  x = attn_out + x
554
 
555
  if self.transformer_style:
@@ -575,10 +751,85 @@ class ConformerBlock(nn.Module):
575
  "attn": attn,
576
  "conv": conv_out,
577
  "ffn_2": ffn_2_out,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
578
  }
579
 
580
  return x, other
581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
  class ConformerEncoder(nn.Module):
584
  def __init__(self, cfg):
@@ -599,7 +850,7 @@ class ConformerEncoder(nn.Module):
599
  )
600
  self.framewise_in_proj = nn.Sequential(
601
  Linear(
602
- cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
603
  cfg.encoder_dim,
604
  ),
605
  nn.Dropout(p=cfg.input_dropout_p),
@@ -619,7 +870,8 @@ class ConformerEncoder(nn.Module):
619
  nn.Dropout(p=cfg.input_dropout_p),
620
  )
621
  assert not cfg.use_framewise_subsample or (
622
- cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
 
623
  ), (
624
  f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate"
625
  f"({self.patchwise_subsample.subsample_rate})"
@@ -628,12 +880,21 @@ class ConformerEncoder(nn.Module):
628
  self.framewise_norm, self.patchwise_norm = None, None
629
  if getattr(cfg, "subsample_normalization", False):
630
  if cfg.use_framewise_subsample:
631
- self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
 
 
 
 
632
  if cfg.use_patchwise_subsample:
633
- self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)
 
 
 
 
634
 
635
  self.conv_pos = None
636
- if getattr(cfg, "conv_pos", False):
 
637
  num_pos_layers = cfg.conv_pos_depth
638
  k = max(3, cfg.conv_pos_width // num_pos_layers)
639
  self.conv_pos = nn.Sequential(
@@ -649,7 +910,9 @@ class ConformerEncoder(nn.Module):
649
  ),
650
  SamePad(k),
651
  TransposeLast(),
652
- nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
 
 
653
  TransposeLast(),
654
  nn.GELU(),
655
  )
@@ -657,7 +920,15 @@ class ConformerEncoder(nn.Module):
657
  ],
658
  TransposeLast(),
659
  )
660
- self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)
 
 
 
 
 
 
 
 
661
 
662
  self.layers = nn.ModuleList(
663
  [
@@ -665,10 +936,6 @@ class ConformerEncoder(nn.Module):
665
  encoder_dim=cfg.encoder_dim,
666
  attention_type=cfg.attention_type,
667
  num_attention_heads=cfg.num_attention_heads,
668
- mamba_d_state=cfg.mamba_d_state,
669
- mamba_d_conv=cfg.mamba_d_conv,
670
- mamba_expand=cfg.mamba_expand,
671
- mamba_bidirectional=cfg.mamba_bidirectional,
672
  feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
673
  conv_expansion_factor=cfg.conv_expansion_factor,
674
  feed_forward_dropout_p=cfg.feed_forward_dropout_p,
@@ -677,10 +944,29 @@ class ConformerEncoder(nn.Module):
677
  conv_kernel_size=cfg.conv_kernel_size,
678
  half_step_residual=cfg.half_step_residual,
679
  transformer_style=getattr(cfg, "transformer_style", False),
 
 
 
680
  )
681
  for _ in range(cfg.num_layers)
682
  ]
683
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
 
685
  def count_parameters(self) -> int:
686
  """Count parameters of encoder"""
@@ -696,6 +982,8 @@ class ConformerEncoder(nn.Module):
696
  self,
697
  inputs: torch.Tensor,
698
  input_lengths: Optional[torch.Tensor] = None,
 
 
699
  return_hidden: bool = False,
700
  freeze_input_layers: bool = False,
701
  target_layer: Optional[int] = None,
@@ -708,9 +996,13 @@ class ConformerEncoder(nn.Module):
708
  device=inputs.device,
709
  )
710
 
711
- with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
 
 
712
  frame_feat, patch_feat = None, None
 
713
  if self.framewise_subsample is not None:
 
714
  frame_feat, frame_lengths = self.framewise_subsample(
715
  inputs, input_lengths
716
  )
@@ -719,6 +1011,7 @@ class ConformerEncoder(nn.Module):
719
  frame_feat = self.framewise_norm(frame_feat)
720
 
721
  if self.patchwise_subsample is not None:
 
722
  patch_feat, patch_lengths = self.patchwise_subsample(
723
  inputs, input_lengths
724
  )
@@ -726,7 +1019,11 @@ class ConformerEncoder(nn.Module):
726
  if self.patchwise_norm is not None:
727
  patch_feat = self.patchwise_norm(patch_feat)
728
 
 
 
 
729
  if frame_feat is not None and patch_feat is not None:
 
730
  min_len = min(frame_feat.size(1), patch_feat.size(1))
731
  frame_feat = frame_feat[:, :min_len]
732
  patch_feat = patch_feat[:, :min_len]
@@ -744,21 +1041,55 @@ class ConformerEncoder(nn.Module):
744
  features = patch_feat
745
  output_lengths = patch_lengths
746
 
747
- if self.conv_pos is not None:
748
- features = features + self.conv_pos(features)
 
 
 
 
 
 
 
 
749
  features = self.conv_pos_post_ln(features)
750
 
751
- layer_results = defaultdict(list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
 
 
753
  outputs = features
 
754
  for i, layer in enumerate(self.layers):
755
- outputs, other = layer(outputs)
 
 
 
 
 
 
 
 
 
 
756
  if return_hidden:
757
  layer_results["hidden_states"].append(outputs)
758
  for k, v in other.items():
759
  layer_results[k].append(v)
760
 
761
- if target_layer is not None and i == target_layer:
762
  break
763
 
764
  return outputs, output_lengths, layer_results
 
1
+ # Reference: https://github.com/sooftware/conformer
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import contextlib
4
  import math
5
  from collections import defaultdict
6
+ from typing import Dict, List, Optional, Tuple
7
 
8
  import torch
9
  import torch.nn.functional as F
10
  from torch import nn
11
+ from torch.nn.attention import SDPBackend, sdpa_kernel
12
+
13
+
14
+ def lengths_to_padding_mask(
15
+ lengths: torch.Tensor, max_len: Optional[int] = None
16
+ ) -> torch.Tensor:
17
+ """Create padding mask from lengths.
18
+
19
+ Args:
20
+ lengths: A 1-D tensor of shape (B,).
21
+ max_len: An integer. It will be automatically set to the max value of lengths
22
+ if not given.
23
+
24
+ Returns:
25
+ A bool tensor of shape (B, max_len), where padded positions are indicated by True.
26
+ """
27
+ batch_size = lengths.size(0)
28
+ max_len = lengths.max().item() if max_len is None else max_len
29
+ seq_range = torch.arange(
30
+ 0, max_len, dtype=torch.long, device=lengths.device
31
+ )
32
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
33
+ lengths_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
34
+ padding_mask = seq_range_expand >= lengths_expand
35
+ return padding_mask
36
 
37
 
38
  class SamePad(nn.Module):
 
79
  return outputs * gate.sigmoid()
80
 
81
 
82
+ class RMSNorm(torch.nn.Module):
83
+ def __init__(self, dim: int, eps: float = 1e-5):
84
+ super().__init__()
85
+ self.eps = eps
86
+ self.weight = nn.Parameter(torch.ones(dim))
87
+
88
+ def _norm(self, x):
89
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
90
+
91
+ def forward(self, x):
92
+ output = self._norm(x.float()).type_as(x)
93
+ return output * self.weight
94
+
95
+
96
  class ResidualConnectionModule(nn.Module):
97
  def __init__(
98
  self,
 
106
  self.input_factor = input_factor
107
 
108
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
109
+ return (self.module(inputs) * self.module_factor) + (
110
+ inputs * self.input_factor
111
+ )
112
 
113
 
114
  class Linear(nn.Module):
115
+ def __init__(
116
+ self, in_features: int, out_features: int, bias: bool = True
117
+ ) -> None:
118
  super(Linear, self).__init__()
119
  self.linear = nn.Linear(in_features, out_features, bias=bias)
120
  nn.init.xavier_uniform_(self.linear.weight)
 
153
  encoder_dim: int = 512,
154
  expansion_factor: int = 4,
155
  dropout_p: float = 0.1,
156
+ rms_norm: bool = False,
157
  ) -> None:
158
  super(FeedForwardModule, self).__init__()
159
  self.sequential = nn.Sequential(
160
+ (
161
+ nn.LayerNorm(encoder_dim)
162
+ if not rms_norm
163
+ else RMSNorm(encoder_dim)
164
+ ),
165
  Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
166
  Swish(),
167
  nn.Dropout(p=dropout_p),
 
231
  kernel_size: int = 31,
232
  expansion_factor: int = 2,
233
  dropout_p: float = 0.1,
234
+ rms_norm: bool = False,
235
  ) -> None:
236
  super(ConformerConvModule, self).__init__()
237
  assert (
238
  kernel_size - 1
239
  ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
240
+ assert (
241
+ expansion_factor == 2
242
+ ), "Currently, Only Supports expansion_factor 2"
243
 
244
  self.sequential = nn.Sequential(
245
+ (
246
+ nn.LayerNorm(in_channels)
247
+ if not rms_norm
248
+ else RMSNorm(in_channels)
249
+ ),
250
  Transpose(shape=(1, 2)),
251
  PointwiseConv1d(
252
  in_channels,
 
265
  ),
266
  nn.BatchNorm1d(in_channels),
267
  Swish(),
268
+ PointwiseConv1d(
269
+ in_channels, in_channels, stride=1, padding=0, bias=True
270
+ ),
271
  nn.Dropout(p=dropout_p),
272
  )
273
 
 
294
  )
295
 
296
  def forward(
297
+ self, inputs: torch.Tensor, input_lengths: torch.Tensor
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
  # inputs: (B, T, C) -> (B, 1, T, C)
300
  if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
301
  inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
302
+ if self.subsample_rate == 4 and inputs.shape[1] % 4 < 3:
303
+ inputs = F.pad(
304
+ inputs, (0, 0, 0, 3 - inputs.shape[1] % 4), "constant", 0
305
+ )
306
  outputs = self.cnn(inputs.unsqueeze(1))
307
+ batch_size, channels, subsampled_lengths, sumsampled_dim = (
308
+ outputs.size()
309
+ )
310
 
311
  outputs = outputs.permute(0, 2, 1, 3)
312
  outputs = outputs.contiguous().view(
 
314
  )
315
 
316
  if self.subsample_rate == 4:
317
+ output_lengths = input_lengths >> 2
318
  else:
319
  output_lengths = input_lengths >> 1
320
 
321
  return outputs, output_lengths
322
 
323
+ def get_out_dim(self, input_dim: int) -> int:
324
+ # dummy input to get the output dimension
325
+ with torch.no_grad():
326
+ device = next(self.parameters()).device
327
+ inputs = torch.zeros(1, 16, input_dim, device=device)
328
+ input_lengths = torch.tensor([16], device=device)
329
+ outputs, _ = self.forward(inputs, input_lengths)
330
+ return outputs.size(-1)
331
+
332
 
333
  class PatchwiseConv2dSubampling(nn.Module):
334
  def __init__(
 
352
  padding=0,
353
  )
354
  self.cnn = nn.Sequential(
355
+ nn.Conv2d(
356
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
357
+ ),
358
  nn.ReLU(),
359
+ nn.Conv2d(
360
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
361
+ ),
362
  nn.ReLU(),
363
  )
364
 
 
367
  return self.patch_size_time * self.patch_size_freq // self.mel_dim
368
 
369
  def forward(
370
+ self, inputs: torch.Tensor, input_lengths: torch.Tensor
371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
372
  assert (
373
  inputs.shape[2] == self.mel_dim
374
  ), "inputs.shape[2] should be equal to mel_dim"
 
390
 
391
 
392
  class RelPositionalEncoding(nn.Module):
393
+ def __init__(self, d_model: int) -> None:
394
  super(RelPositionalEncoding, self).__init__()
395
  self.d_model = d_model
396
  self.pe = None
 
397
 
398
  def extend_pe(self, x: torch.Tensor) -> None:
399
  if self.pe is not None:
 
402
  self.pe = self.pe.to(dtype=x.dtype, device=x.device)
403
  return
404
 
405
+ length = x.size(1)
406
+ pe_positive = torch.zeros(length, self.d_model, device="cpu")
407
+ pe_negative = torch.zeros(length, self.d_model, device="cpu")
408
+ position = torch.arange(
409
+ 0, length, dtype=torch.float32, device="cpu"
410
+ ).unsqueeze(1)
411
  div_term = torch.exp(
412
+ torch.arange(0, self.d_model, 2, dtype=torch.float32, device="cpu")
413
  * -(math.log(10000.0) / self.d_model)
414
  )
415
  pe_positive[:, 0::2] = torch.sin(position * div_term)
 
425
  def forward(self, x: torch.Tensor) -> torch.Tensor:
426
  # x: (B, T, C)
427
  self.extend_pe(x)
428
+ assert self.pe is not None
429
  pos_emb = self.pe[
430
  :,
431
+ self.pe.size(1) // 2
432
+ - x.size(1)
433
+ + 1 : self.pe.size(1) // 2
434
+ + x.size(1),
435
  ]
436
  return pos_emb
437
 
 
463
 
464
  self.out_proj = Linear(d_model, d_model)
465
 
466
+ @staticmethod
467
+ def _relative_shift(pos_score: torch.Tensor) -> torch.Tensor:
468
+ # pos_score: (B, H, T, 2T-1)
469
+ B, H, T, L = pos_score.size()
470
+
471
+ # Pad on the left of the last dimension: (B, H, T, 2T)
472
+ pos_score = F.pad(pos_score, (1, 0))
473
+
474
+ # Reshape to (B, H, 2T, T)
475
+ pos_score = pos_score.view(B, H, L + 1, T)
476
+
477
+ # Slice and reshape back to (B, H, T, 2T-1)
478
+ pos_score = pos_score[:, :, 1:].view(B, H, T, L)
479
+
480
+ # Keep only first T positions => (B, H, T, T)
481
+ return pos_score[:, :, :, : (L // 2 + 1)]
482
+
483
  def forward(
484
  self,
485
  query: torch.Tensor,
486
  key: torch.Tensor,
487
  value: torch.Tensor,
488
  pos_embedding: torch.Tensor,
489
+ padding_mask: Optional[torch.Tensor] = None,
490
+ *,
491
+ need_weights: bool = False,
492
+ use_sdpa: Optional[bool] = None,
493
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
494
+ """
495
+ - If need_weights=True: returns (output, attn) like your original code.
496
+ - If need_weights=False: returns (output, None) and uses SDPA in eval for speed/memory.
497
+ """
498
+ B, Tq, _ = query.size()
499
+ _, Tk, _ = key.size()
500
+
501
+ # Project
502
+ q = self.query_proj(query) # (B, Tq, C)
503
+ k = self.key_proj(key) # (B, Tk, C)
504
+ v = self.value_proj(value) # (B, Tk, C)
505
+
506
+ # Reshape to (B, H, T, Dh)
507
+ q = q.view(B, Tq, self.num_heads, self.d_head).transpose(
508
+ 1, 2
509
+ ) # (B,H,Tq,Dh)
510
+ k = k.view(B, Tk, self.num_heads, self.d_head).transpose(
511
+ 1, 2
512
+ ) # (B,H,Tk,Dh)
513
+ v = v.view(B, Tk, self.num_heads, self.d_head).transpose(
514
+ 1, 2
515
+ ) # (B,H,Tk,Dh)
516
+
517
+ # Positional projection.
518
+ # IMPORTANT: allow pos_embedding to be (1, 2T-1, C) and broadcast across batch.
519
+ # pos_embedding expected length: 2Tq - 1 for self-attn.
520
+ pB = pos_embedding.size(0)
521
+ p = self.pos_proj(pos_embedding) # (pB, L, C)
522
+ p = p.view(pB, -1, self.num_heads, self.d_head).transpose(
523
+ 1, 2
524
+ ) # (pB,H,L,Dh)
525
+
526
+ # Compute position-based bias (scaled) to feed SDPA or add to scores
527
+ # q_pos: (B,H,Tq,Dh), p^T: (pB,H,Dh,L) -> broadcast on pB if pB==1
528
+ q_pos = q + self.v_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh)
529
+ pos_score = torch.matmul(q_pos, p.transpose(-2, -1)) # (B,H,Tq,L)
530
+ pos_bias = self._relative_shift(pos_score) # (B,H,Tq,Tq) for self-attn
531
+ pos_bias = pos_bias.mul(
532
+ 1.0 / self.sqrt_dim
533
+ ) # scale matches SDPA scaling
534
+
535
+ if padding_mask is not None:
536
+ # padding_mask: (B, T) -> (B, 1, 1, T) to broadcast with pos_bias (B, H, Tq, Tk)
537
+ # This masks out key positions that are padded across all heads and queries
538
+ if padding_mask.dtype != torch.bool:
539
+ padding_mask = padding_mask.to(torch.bool)
540
+ pos_bias = pos_bias.masked_fill(
541
+ padding_mask[:, None, None, :], -1e9
542
+ )
543
 
544
+ if use_sdpa is None:
545
+ use_sdpa = (not self.training) and (not need_weights)
546
+
547
+ # ---- Fast inference path: no attention matrix materialized ----
548
+ if use_sdpa:
549
+ # Content term uses u_bias
550
+ q_content = q + self.u_bias.unsqueeze(0).unsqueeze(
551
+ 2
552
+ ) # (B,H,Tq,Dh)
553
+
554
+ with sdpa_kernel(
555
+ [
556
+ SDPBackend.FLASH_ATTENTION,
557
+ SDPBackend.EFFICIENT_ATTENTION,
558
+ SDPBackend.MATH,
559
+ ]
560
+ ):
561
+ out = F.scaled_dot_product_attention(
562
+ q_content, # (B,H,Tq,Dh)
563
+ k, # (B,H,Tk,Dh)
564
+ v, # (B,H,Tk,Dh)
565
+ attn_mask=pos_bias, # (B,H,Tq,Tk) additive bias
566
+ dropout_p=0.0, # dropout disabled in inference
567
+ is_causal=False,
568
+ ) # (BH, Tq, Dh)
569
+
570
+ out = out.transpose(1, 2).contiguous().view(B, Tq, self.d_model)
571
+
572
+ return self.out_proj(out), None
573
+
574
+ # ---- Reference path (training / if you need attn weights): matches your math ----
575
+ q_content = q + self.u_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh)
576
  content_score = torch.matmul(
577
+ q_content, k.transpose(-2, -1)
578
+ ) # (B,H,Tq,Tk)
579
+ content_score = content_score.mul(1.0 / self.sqrt_dim)
 
 
 
 
 
 
580
 
581
+ score = content_score + pos_bias # already scaled
 
 
582
 
583
+ attn = F.softmax(score, dim=-1)
584
  attn = self.dropout(attn)
585
 
586
+ context = torch.matmul(attn, v) # (B,H,Tq,Dh)
587
+ context = (
588
+ context.transpose(1, 2).contiguous().view(B, Tq, self.d_model)
 
 
 
 
 
 
 
 
 
589
  )
 
 
 
590
 
591
+ return self.out_proj(context), attn
592
 
593
 
594
  class MultiHeadedSelfAttentionModule(nn.Module):
595
+ def __init__(
596
+ self,
597
+ d_model: int,
598
+ num_heads: int,
599
+ dropout_p: float = 0.1,
600
+ rms_norm: bool = False,
601
+ ):
602
  super(MultiHeadedSelfAttentionModule, self).__init__()
603
  self.positional_encoding = RelPositionalEncoding(d_model)
604
+ self.layer_norm = (
605
+ nn.LayerNorm(d_model) if not rms_norm else RMSNorm(d_model)
606
+ )
607
+ self.attention = RelativeMultiHeadAttention(
608
+ d_model, num_heads, dropout_p
609
+ )
610
  self.dropout = nn.Dropout(p=dropout_p)
611
 
612
  def forward(
613
+ self,
614
+ inputs: torch.Tensor,
615
+ padding_mask: Optional[torch.Tensor] = None,
616
+ pos_embedding: Optional[torch.Tensor] = None,
617
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
618
+ if pos_embedding is None:
619
+ pos_embedding = self.positional_encoding(inputs)
620
 
621
  inputs = self.layer_norm(inputs)
622
  outputs, attn = self.attention(
623
+ inputs,
624
+ inputs,
625
+ inputs,
626
+ pos_embedding=pos_embedding,
627
+ padding_mask=padding_mask,
628
  )
629
 
630
+ return self.dropout(outputs), attn, pos_embedding
631
 
632
 
633
  class ConformerBlock(nn.Module):
 
636
  encoder_dim: int = 512,
637
  attention_type: str = "mhsa",
638
  num_attention_heads: int = 8,
 
 
 
 
639
  feed_forward_expansion_factor: int = 4,
640
  conv_expansion_factor: int = 2,
641
  feed_forward_dropout_p: float = 0.1,
 
644
  conv_kernel_size: int = 31,
645
  half_step_residual: bool = True,
646
  transformer_style: bool = False,
647
+ usad_v2: bool = False,
648
+ pre_norm: bool = False,
649
+ rms_norm: bool = False,
650
  ):
651
  super(ConformerBlock, self).__init__()
652
 
653
  self.transformer_style = transformer_style
654
  self.attention_type = attention_type
655
+ self.usad_v2 = usad_v2
656
+ self.pre_norm = pre_norm
657
 
658
  if half_step_residual and not transformer_style:
659
  self.feed_forward_residual_factor = 0.5
660
  else:
661
  self.feed_forward_residual_factor = 1
662
 
663
+ assert (
664
+ attention_type == "mhsa"
665
+ ), "Only 'mhsa' attention is supported in this implementation."
666
+ attention = MultiHeadedSelfAttentionModule(
667
+ d_model=encoder_dim,
668
+ num_heads=num_attention_heads,
669
+ dropout_p=attention_dropout_p,
670
+ rms_norm=rms_norm,
671
+ )
672
 
673
  self.ffn_1 = FeedForwardModule(
674
  encoder_dim=encoder_dim,
675
  expansion_factor=feed_forward_expansion_factor,
676
  dropout_p=feed_forward_dropout_p,
677
+ rms_norm=rms_norm,
678
  )
679
  self.attention = attention
680
  if not transformer_style:
 
683
  kernel_size=conv_kernel_size,
684
  expansion_factor=conv_expansion_factor,
685
  dropout_p=conv_dropout_p,
686
+ rms_norm=rms_norm,
687
  )
688
  self.ffn_2 = FeedForwardModule(
689
  encoder_dim=encoder_dim,
690
  expansion_factor=feed_forward_expansion_factor,
691
  dropout_p=feed_forward_dropout_p,
692
+ rms_norm=rms_norm,
693
  )
694
+ self.layernorm = (
695
+ (
696
+ nn.LayerNorm(encoder_dim)
697
+ if not rms_norm
698
+ else RMSNorm(encoder_dim)
699
+ )
700
+ if not pre_norm
701
+ else nn.Identity()
702
+ )
703
 
704
+ def forward_attention(
705
+ self,
706
+ x: torch.Tensor,
707
+ pos_embedding: Optional[torch.Tensor] = None,
708
+ padding_mask: Optional[torch.Tensor] = None,
709
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
710
+ attn_out, attn, pos_embedding = self.attention(
711
+ x, pos_embedding=pos_embedding, padding_mask=padding_mask
712
+ )
713
+ return attn_out, attn, pos_embedding
714
+
715
+ def forward_legacy(
716
+ self,
717
+ x: torch.Tensor,
718
+ pos_embedding: Optional[torch.Tensor] = None,
719
+ padding_mask: Optional[torch.Tensor] = None,
720
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
721
  # FFN 1
722
  ffn_1_out = self.ffn_1(x)
723
  x = ffn_1_out * self.feed_forward_residual_factor + x
724
 
725
  # Attention
726
+ attn_out, attn, pos_embedding = self.forward_attention(
727
+ x, pos_embedding, padding_mask
728
+ )
 
 
 
729
  x = attn_out + x
730
 
731
  if self.transformer_style:
 
751
  "attn": attn,
752
  "conv": conv_out,
753
  "ffn_2": ffn_2_out,
754
+ "pos_embedding": pos_embedding,
755
+ }
756
+
757
+ return x, other
758
+
759
+ def forward_transformer(
760
+ self,
761
+ x: torch.Tensor,
762
+ pos_embedding: Optional[torch.Tensor] = None,
763
+ padding_mask: Optional[torch.Tensor] = None,
764
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
765
+ # Attention
766
+ attn_out, attn, pos_embedding = self.forward_attention(
767
+ x, pos_embedding, padding_mask
768
+ )
769
+ x = attn_out + x
770
+
771
+ # FFN
772
+ ffn_out = self.ffn_1(x)
773
+ x = ffn_out * self.feed_forward_residual_factor + x
774
+
775
+ x = self.layernorm(x)
776
+ return x, {
777
+ "ffn_1": ffn_out,
778
+ "attn": attn,
779
+ "conv": None,
780
+ "ffn_2": None,
781
+ "pos_embedding": pos_embedding,
782
+ }
783
+
784
+ def forward_conformer(
785
+ self,
786
+ x: torch.Tensor,
787
+ pos_embedding: Optional[torch.Tensor] = None,
788
+ padding_mask: Optional[torch.Tensor] = None,
789
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
790
+ # FFN 1
791
+ ffn_1_out = self.ffn_1(x)
792
+ x = ffn_1_out * self.feed_forward_residual_factor + x
793
+
794
+ # Attention
795
+ attn_out, attn, pos_embedding = self.forward_attention(
796
+ x, pos_embedding, padding_mask
797
+ )
798
+ x = attn_out + x
799
+
800
+ # Convolution
801
+ conv_out = self.conv(x)
802
+ x = conv_out + x
803
+
804
+ # FFN 2
805
+ ffn_2_out = self.ffn_2(x)
806
+ x = ffn_2_out * self.feed_forward_residual_factor + x
807
+ x = self.layernorm(x)
808
+
809
+ other = {
810
+ "ffn_1": ffn_1_out,
811
+ "attn": attn,
812
+ "conv": conv_out,
813
+ "ffn_2": ffn_2_out,
814
+ "pos_embedding": pos_embedding,
815
  }
816
 
817
  return x, other
818
 
819
+ def forward(
820
+ self,
821
+ x: torch.Tensor,
822
+ pos_embedding: Optional[torch.Tensor] = None,
823
+ padding_mask: Optional[torch.Tensor] = None,
824
+ ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
825
+ if not self.usad_v2:
826
+ return self.forward_legacy(x, pos_embedding, padding_mask)
827
+
828
+ if self.transformer_style:
829
+ return self.forward_transformer(x, pos_embedding, padding_mask)
830
+
831
+ return self.forward_conformer(x, pos_embedding, padding_mask)
832
+
833
 
834
  class ConformerEncoder(nn.Module):
835
  def __init__(self, cfg):
 
850
  )
851
  self.framewise_in_proj = nn.Sequential(
852
  Linear(
853
+ self.framewise_subsample.get_out_dim(cfg.input_dim),
854
  cfg.encoder_dim,
855
  ),
856
  nn.Dropout(p=cfg.input_dropout_p),
 
870
  nn.Dropout(p=cfg.input_dropout_p),
871
  )
872
  assert not cfg.use_framewise_subsample or (
873
+ cfg.conv_subsample_rate
874
+ == self.patchwise_subsample.subsample_rate
875
  ), (
876
  f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate"
877
  f"({self.patchwise_subsample.subsample_rate})"
 
880
  self.framewise_norm, self.patchwise_norm = None, None
881
  if getattr(cfg, "subsample_normalization", False):
882
  if cfg.use_framewise_subsample:
883
+ self.framewise_norm = (
884
+ nn.LayerNorm(cfg.encoder_dim)
885
+ if not getattr(cfg, "rms_norm", False)
886
+ else RMSNorm(cfg.encoder_dim)
887
+ )
888
  if cfg.use_patchwise_subsample:
889
+ self.patchwise_norm = (
890
+ nn.LayerNorm(cfg.encoder_dim)
891
+ if not getattr(cfg, "rms_norm", False)
892
+ else RMSNorm(cfg.encoder_dim)
893
+ )
894
 
895
  self.conv_pos = None
896
+ self.conv_pos_post_ln = None
897
+ if cfg.conv_pos:
898
  num_pos_layers = cfg.conv_pos_depth
899
  k = max(3, cfg.conv_pos_width // num_pos_layers)
900
  self.conv_pos = nn.Sequential(
 
910
  ),
911
  SamePad(k),
912
  TransposeLast(),
913
+ nn.LayerNorm(
914
+ cfg.encoder_dim, elementwise_affine=False
915
+ ),
916
  TransposeLast(),
917
  nn.GELU(),
918
  )
 
920
  ],
921
  TransposeLast(),
922
  )
923
+ self.conv_pos_post_ln = (
924
+ (
925
+ nn.LayerNorm(cfg.encoder_dim)
926
+ if not getattr(cfg, "rms_norm", False)
927
+ else RMSNorm(cfg.encoder_dim)
928
+ )
929
+ if not getattr(cfg, "pre_norm", False)
930
+ else nn.Identity()
931
+ )
932
 
933
  self.layers = nn.ModuleList(
934
  [
 
936
  encoder_dim=cfg.encoder_dim,
937
  attention_type=cfg.attention_type,
938
  num_attention_heads=cfg.num_attention_heads,
 
 
 
 
939
  feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
940
  conv_expansion_factor=cfg.conv_expansion_factor,
941
  feed_forward_dropout_p=cfg.feed_forward_dropout_p,
 
944
  conv_kernel_size=cfg.conv_kernel_size,
945
  half_step_residual=cfg.half_step_residual,
946
  transformer_style=getattr(cfg, "transformer_style", False),
947
+ usad_v2=getattr(cfg, "usad_v2", False),
948
+ pre_norm=getattr(cfg, "pre_norm", False),
949
+ rms_norm=getattr(cfg, "rms_norm", False),
950
  )
951
  for _ in range(cfg.num_layers)
952
  ]
953
  )
954
+ self.layerdrop_p = getattr(cfg, "layerdrop_p", 0.0)
955
+
956
+ if cfg.attention_type == "mhsa" and len(self.layers) > 0:
957
+ # Share positional encoding across layers
958
+ shared_pos = None
959
+ for layer in self.layers:
960
+ if isinstance(layer.attention, MultiHeadedSelfAttentionModule):
961
+ if shared_pos is None:
962
+ shared_pos = layer.attention.positional_encoding
963
+ else:
964
+ layer.attention.positional_encoding = shared_pos
965
+ if shared_pos is not None:
966
+ # precompute positional encodings
967
+ # expecting most mel inputs to be fewer than 2000 frames (20 seconds)
968
+ max_len = 2000 // cfg.conv_subsample_rate
969
+ shared_pos.extend_pe(torch.tensor(0.0).expand(1, max_len))
970
 
971
  def count_parameters(self) -> int:
972
  """Count parameters of encoder"""
 
982
  self,
983
  inputs: torch.Tensor,
984
  input_lengths: Optional[torch.Tensor] = None,
985
+ padding_mask: Optional[torch.Tensor] = None,
986
+ *,
987
  return_hidden: bool = False,
988
  freeze_input_layers: bool = False,
989
  target_layer: Optional[int] = None,
 
996
  device=inputs.device,
997
  )
998
 
999
+ with (
1000
+ torch.no_grad() if freeze_input_layers else contextlib.ExitStack()
1001
+ ):
1002
  frame_feat, patch_feat = None, None
1003
+ frame_lengths, patch_lengths = None, None
1004
  if self.framewise_subsample is not None:
1005
+ assert self.framewise_in_proj is not None
1006
  frame_feat, frame_lengths = self.framewise_subsample(
1007
  inputs, input_lengths
1008
  )
 
1011
  frame_feat = self.framewise_norm(frame_feat)
1012
 
1013
  if self.patchwise_subsample is not None:
1014
+ assert self.patchwise_in_proj is not None
1015
  patch_feat, patch_lengths = self.patchwise_subsample(
1016
  inputs, input_lengths
1017
  )
 
1019
  if self.patchwise_norm is not None:
1020
  patch_feat = self.patchwise_norm(patch_feat)
1021
 
1022
+ assert frame_feat is not None or patch_feat is not None
1023
+ assert frame_lengths is not None or patch_lengths is not None
1024
+
1025
  if frame_feat is not None and patch_feat is not None:
1026
+ assert frame_lengths is not None and patch_lengths is not None
1027
  min_len = min(frame_feat.size(1), patch_feat.size(1))
1028
  frame_feat = frame_feat[:, :min_len]
1029
  patch_feat = patch_feat[:, :min_len]
 
1041
  features = patch_feat
1042
  output_lengths = patch_lengths
1043
 
1044
+ assert features is not None
1045
+ assert output_lengths is not None
1046
+
1047
+ # Positional encoding with convolutional layers
1048
+ if self.conv_pos is not None and self.conv_pos_post_ln is not None:
1049
+ pos = self.conv_pos(features)
1050
+ if not self.training:
1051
+ features = features.add_(pos)
1052
+ else:
1053
+ features = features + pos
1054
  features = self.conv_pos_post_ln(features)
1055
 
1056
+ # Create padding mask for attention
1057
+ if padding_mask is not None:
1058
+ # downsample to match features length
1059
+ input_len = padding_mask.size(1)
1060
+ feat_len = features.size(1)
1061
+ factor = input_len / feat_len
1062
+ indices = (
1063
+ torch.arange(feat_len, device=padding_mask.device) * factor
1064
+ ).long()
1065
+ padding_mask = padding_mask.index_select(1, indices)
1066
+ else:
1067
+ # create from output_lengths
1068
+ padding_mask = lengths_to_padding_mask(
1069
+ output_lengths, max_len=features.size(1)
1070
+ )
1071
 
1072
+ layer_results = defaultdict(list)
1073
  outputs = features
1074
+ other = {}
1075
  for i, layer in enumerate(self.layers):
1076
+ if (
1077
+ self.training
1078
+ and self.layerdrop_p > 0
1079
+ and torch.rand(1).item() < self.layerdrop_p
1080
+ ):
1081
+ continue
1082
+ outputs, other = layer(
1083
+ outputs,
1084
+ pos_embedding=other.get("pos_embedding"),
1085
+ padding_mask=padding_mask,
1086
+ )
1087
  if return_hidden:
1088
  layer_results["hidden_states"].append(outputs)
1089
  for k, v in other.items():
1090
  layer_results[k].append(v)
1091
 
1092
+ if target_layer is not None and i + 1 == target_layer:
1093
  break
1094
 
1095
  return outputs, output_lengths, layer_results