marcoyang commited on
Commit
4e3bc09
·
1 Parent(s): 6540ba2
1284-1180-0027.flac ADDED
Binary file (72.5 kB). View file
 
envirnoment.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch=2.1.1
2
+ lhotse=1.28.0
inference_600m.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+
4
+ from model import MultiKDModel
5
+ from scaling import ScheduledFloat
6
+ from subsampling import Conv2dSubsampling
7
+ from zipformer import Zipformer2
8
+
9
+ from lhotse import Fbank, FbankConfig
10
+ import torchaudio
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ LOG_EPS = math.log(1e-10)
15
+
16
class ZipformerConfig:
    """Attribute-style container for Zipformer encoder hyper-parameters.

    All values live in a single ``_config`` dict; attribute reads, writes
    and deletes are forwarded to it, so ``cfg.encoder_dim`` is equivalent
    to ``cfg._config["encoder_dim"]``.  The comma-separated string fields
    hold one value per encoder stack and are parsed by ``_to_int_tuple``
    at model-construction time.
    """

    def __init__(self):
        # Store all parameters in _config (defaults are the baseline setup).
        self._config = {
            "feature_dim": 128,
            "pos_dim": 48,
            "output_downsampling_factor": 2,
            "downsampling_factor": "1,2,4,8,4,2",
            "num_encoder_layers": "2,2,3,4,3,2",
            "feedforward_dim": "512,768,1024,1536,1024,768",
            "encoder_dim": "192,256,448,768,448,192",
            "encoder_unmasked_dim": "192,192,256,256,256,192",
            "cnn_module_kernel": "31,31,15,15,15,31",
            "num_heads": "4,4,4,8,4,4",
            "causal": True,
        }

    def __getattr__(self, key):
        # Guard against infinite recursion: if _config itself is missing
        # (e.g. during unpickling / copy, before __init__ has run), the
        # `self._config` lookup below would re-enter __getattr__ forever.
        if key == "_config":
            raise AttributeError(key)
        if key in self._config:
            return self._config[key]
        raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        # _config is the only real instance attribute; everything else is
        # routed into the dict.
        if key == "_config":
            super().__setattr__(key, value)
        else:
            self._config[key] = value

    def __delattr__(self, key):
        if key in self._config:
            del self._config[key]
        else:
            raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def to_dict(self):
        """Return a shallow copy of the underlying parameter dict."""
        return dict(self._config)

    def __repr__(self):
        return f"ZipformerConfig({self._config})"
55
+
56
+
57
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

    - yes, true, t, y, 1, to represent True
    - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
    """
    # Already a bool (e.g. the default value) — pass through untouched.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
74
+
75
def get_parser():
    """Build the CLI parser for offline (full-utterance) inference."""
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    p.add_argument("--model-version", type=str, default="600m_uniform_out_ds1")

    p.add_argument(
        "--causal",
        type=str2bool,
        default=False,
        help="If True, use causal version of model.",
    )

    p.add_argument(
        "--chunk-size",
        type=str,
        default="16,32,64,-1",
        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
        " Must be just -1 if --causal=False",
    )

    p.add_argument(
        "--left-context-frames",
        type=str,
        default="64,128,256,-1",
        help="Maximum left-contexts for causal training, measured in frames which will "
        "be converted to a number of chunks. If splitting into chunks, "
        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
    )

    p.add_argument("--ckpt-path", type=str, required=True)

    p.add_argument(
        "--audio",
        type=str,
        required=True,
        help="The path to the audio",
    )

    return p
124
+
125
+ def _to_int_tuple(s: str):
126
+ return tuple(map(int, s.split(",")))
127
+
128
def get_encoder_embed(params) -> nn.Module:
    """Build the Conv2dSubsampling front-end feeding the Zipformer encoder.

    Maps (N, T, 128) fbank features to (N, T', encoder_dim[0]); per the
    MultiKDModel docstring in model.py, T' = (T - 7) // 2.
    """
    encoder_embed = Conv2dSubsampling(
        in_channels=128,  # must match the fbank num_mel_bins used at inference
        out_channels=_to_int_tuple(params.encoder_dim)[0],  # dim of first encoder stack
        # Dropout schedule mirrors the training-time constructor; it has no
        # effect once the model is in eval mode.
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    return encoder_embed
135
+
136
def get_encoder_model(params) -> nn.Module:
    """Construct the Zipformer2 encoder from `params`.

    Comma-separated string fields in `params` carry one value per encoder
    stack and are parsed into int tuples here.
    """
    encoder = Zipformer2(
        output_downsampling_factor=params.output_downsampling_factor,
        downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
        # Fixed per-head dimensions (single value, broadcast across stacks).
        query_head_dim=_to_int_tuple("32"),
        pos_head_dim=_to_int_tuple("4"),
        value_head_dim=_to_int_tuple("12"),
        pos_dim=params.pos_dim,
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
        # Dropout / warmup only matter during training; kept to mirror the
        # training-time constructor.
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
        chunk_size=_to_int_tuple(params.chunk_size),
        left_context_frames=_to_int_tuple(params.left_context_frames),
    )
    return encoder
157
+
158
def get_params(args):
    """Build a ZipformerConfig for the requested model version.

    Args:
      args: argparse namespace providing `model_version`, `chunk_size`
        and `left_context_frames`.

    Returns:
      A populated ZipformerConfig.

    Raises:
      ValueError: if `args.model_version` is not a known 600M variant.
    """
    params = ZipformerConfig()
    params.chunk_size = args.chunk_size
    params.left_context_frames = args.left_context_frames
    # NOTE(review): args.causal is never applied here, so params.causal keeps
    # the ZipformerConfig default (True) — confirm this is intentional.

    # The two 600M variants share every setting except the output
    # downsampling factor, so only that differs per branch.
    model_version = args.model_version
    if model_version == "600m_uniform_out_ds1":
        params.output_downsampling_factor = 1
    elif model_version == "600m_uniform_out_ds2":
        params.output_downsampling_factor = 2
    else:
        raise ValueError(f"Unknown model version: {model_version}")

    params.downsampling_factor = "1,2,4,8,4,2,1"
    params.num_encoder_layers = "1,2,3,4,1,1,1"
    params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
    params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
    params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
    params.cnn_module_kernel = "31,31,15,15,15,31,31"
    params.num_heads = "8,8,8,8,8,8,8"
    return params
185
+
186
def get_model(model_version) -> nn.Module:
    """Build the MultiKDModel (encoder-only) for inference."""
    # initialise the encoder model
    # NOTE(review): despite the name, `model_version` receives the full
    # argparse namespace (callers pass `args`); get_params() reads
    # .model_version, .chunk_size and .left_context_frames from it.

    params = get_params(model_version)
    encoder_embed = get_encoder_embed(params)
    encoder = get_encoder_model(params)
    print(params)

    # num_codebooks=0 disables the MVQ distillation head, so only the
    # encoder branch of MultiKDModel is active.
    model = MultiKDModel(
        encoder_embed=encoder_embed,
        encoder=encoder,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        num_codebooks=0,
    )

    return model
202
+
203
def main(args):
    """Run offline (full-utterance) encoder inference on one audio file."""
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # Build the model, then load checkpoint weights.
    model = get_model(args)
    model.to(device)

    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only hosts;
    # the weights were already moved to `device` above.
    info = model.load_state_dict(
        torch.load(args.ckpt_path, map_location="cpu")["model"], strict=False
    )
    print(info)
    model.eval()

    # fbank extractor (128 mel bins, matching the model's input dimension)
    extractor = Fbank(FbankConfig(num_mel_bins=128))

    # load audio; the model expects 16 kHz input
    audio, fs = torchaudio.load(args.audio)
    assert fs == 16000
    audios = audio.squeeze()
    feature = [extractor.extract(audios, sampling_rate=fs)]
    feature_lens = [f.size(0) for f in feature]

    feature = torch.nn.utils.rnn.pad_sequence(
        feature, batch_first=True, padding_value=LOG_EPS
    ).to(device)
    feature_lens = torch.tensor(feature_lens, device=device)

    # batch inference — no autograd bookkeeping needed
    with torch.no_grad():
        encoder_out, encoder_out_lens = model.forward_encoder(
            feature,
            feature_lens,
        )
    print(encoder_out)
    print(encoder_out.shape)
239
+
240
if __name__ == "__main__":
    # Parse CLI arguments and run offline inference.
    main(get_parser().parse_args())
inference_600m.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run offline inference with the 600M Zipformer encoder on the bundled
# sample utterance, using the causal (streaming-capable) configuration.

# Make the project sources (model.py, zipformer.py, ...) importable.
export PYTHONPATH=./../../../:$PYTHONPATH

model_version=600m_uniform_out_ds1
causal=1
# chunk_size / left_context_frames are measured at the encoder frame rate.
left_context_frames=256
chunk_size=8

python inference_600m.py \
    --model-version $model_version \
    --ckpt-path iter-400000-avg-4.pt \
    --causal $causal \
    --left-context-frames $left_context_frames \
    --chunk-size $chunk_size \
    --audio 1284-1180-0027.flac
inference_600m_streaming_forward.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ from model import MultiKDModel
6
+ from scaling import ScheduledFloat
7
+ from subsampling import Conv2dSubsampling
8
+ from zipformer import Zipformer2
9
+
10
+ from lhotse import Fbank, FbankConfig
11
+ import torchaudio
12
+ import torch
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+ from utilities import make_pad_mask, str2bool, ZipformerConfig
17
+
18
+ LOG_EPS = math.log(1e-10)
19
+
20
def get_parser():
    """Build the CLI parser for chunk-by-chunk (streaming) inference."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-version",
        type=str,
        default="600m_uniform_out_ds1",
    )

    parser.add_argument(
        "--causal",
        type=str2bool,
        default=False,
        help="If True, use causal version of model.",
    )

    # NOTE(review): chunk_forward() applies int() to this value, so for the
    # streaming script a single integer (e.g. "8") must be passed — the
    # list-style default below would fail there.
    parser.add_argument(
        "--chunk-size",
        type=str,
        default="16,32,64,-1",
        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
        " Must be just -1 if --causal=False",
    )

    parser.add_argument(
        "--left-context-frames",
        type=str,
        default="64,128,256,-1",
        help="Maximum left-contexts for causal training, measured in frames which will "
        "be converted to a number of chunks. If splitting into chunks, "
        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--audio",
        type=str,
        required=True,
        help="The path to the audio"
    )

    return parser
69
+
70
+ def _to_int_tuple(s: str):
71
+ return tuple(map(int, s.split(",")))
72
+
73
def get_encoder_embed(params) -> nn.Module:
    """Build the Conv2dSubsampling front-end feeding the Zipformer encoder.

    Maps (N, T, 128) fbank features to (N, T', encoder_dim[0]); per the
    MultiKDModel docstring in model.py, T' = (T - 7) // 2.
    """
    encoder_embed = Conv2dSubsampling(
        in_channels=128,  # must match the fbank num_mel_bins used at inference
        out_channels=_to_int_tuple(params.encoder_dim)[0],  # dim of first encoder stack
        # Dropout schedule mirrors the training-time constructor; it has no
        # effect once the model is in eval mode.
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    return encoder_embed
80
+
81
def get_encoder_model(params) -> nn.Module:
    """Construct the Zipformer2 encoder from `params`.

    Comma-separated string fields in `params` carry one value per encoder
    stack and are parsed into int tuples here.
    """
    encoder = Zipformer2(
        output_downsampling_factor=params.output_downsampling_factor,
        downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
        # Fixed per-head dimensions (single value, broadcast across stacks).
        query_head_dim=_to_int_tuple("32"),
        pos_head_dim=_to_int_tuple("4"),
        value_head_dim=_to_int_tuple("12"),
        pos_dim=params.pos_dim,
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
        # Dropout / warmup only matter during training; kept to mirror the
        # training-time constructor.
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
        chunk_size=_to_int_tuple(params.chunk_size),
        left_context_frames=_to_int_tuple(params.left_context_frames),
    )
    return encoder
102
+
103
def get_params(args):
    """Build a ZipformerConfig for the requested model version.

    Args:
      args: argparse namespace providing `model_version`, `chunk_size`
        and `left_context_frames`.

    Returns:
      A populated ZipformerConfig.

    Raises:
      ValueError: if `args.model_version` is not a known 600M variant.
    """
    params = ZipformerConfig()
    params.chunk_size = args.chunk_size
    params.left_context_frames = args.left_context_frames
    # NOTE(review): args.causal is never applied here, so params.causal keeps
    # the ZipformerConfig default (True) — confirm this is intentional.

    # The two 600M variants share every setting except the output
    # downsampling factor, so only that differs per branch.
    model_version = args.model_version
    if model_version == "600m_uniform_out_ds1":
        params.output_downsampling_factor = 1
    elif model_version == "600m_uniform_out_ds2":
        params.output_downsampling_factor = 2
    else:
        raise ValueError(f"Unknown model version: {model_version}")

    params.downsampling_factor = "1,2,4,8,4,2,1"
    params.num_encoder_layers = "1,2,3,4,1,1,1"
    params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
    params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
    params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
    params.cnn_module_kernel = "31,31,15,15,15,31,31"
    params.num_heads = "8,8,8,8,8,8,8"
    return params
130
+
131
def get_model(model_version) -> nn.Module:
    """Build the MultiKDModel (encoder-only) for streaming inference."""
    # initialise the encoder model
    # NOTE(review): despite the name, `model_version` receives the full
    # argparse namespace (callers pass `args`); get_params() reads
    # .model_version, .chunk_size and .left_context_frames from it.

    params = get_params(model_version)
    encoder_embed = get_encoder_embed(params)
    encoder = get_encoder_model(params)
    print(params)

    # num_codebooks=0 disables the MVQ distillation head, so only the
    # encoder branch of MultiKDModel is active.
    model = MultiKDModel(
        encoder_embed=encoder_embed,
        encoder=encoder,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        num_codebooks=0,
    )

    return model
147
+
148
def get_init_states(
    model: nn.Module,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """
    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
    states[-2] is the cached left padding for ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs)
    states[-1] is processed_lens of shape (batch,), which records the number
    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
    """
    # Per-layer attention / convolution caches from the Zipformer encoder.
    states = model.encoder.get_init_states(batch_size, device)

    # Cached left padding for the Conv2dSubsampling front-end.
    embed_states = model.encoder_embed.get_init_states(batch_size, device)
    states.append(embed_states)

    # No frames have been processed yet for any element of the batch.
    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
    states.append(processed_lens)

    return states
170
+
171
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Merge per-utterance zipformer states into one batched state list.

    Each element of `state_list` is the state of a single utterance:
    6 cached tensors per encoder layer
    (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
    cached_conv1, cached_conv2), followed by the cached ConvNeXt left
    padding of shape (batch_size, num_channels, left_pad, num_freqs)
    and processed_lens of shape (batch,).

    Note:
        It is the inverse of :func:`unstack_states`.
    """
    batch_size = len(state_list)
    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
    tot_num_layers = (len(state_list[0]) - 2) // 6

    # Concatenation dim for each of a layer's 6 cached tensors: the four
    # attention caches (key, nonlin_attn, val1, val2) carry batch on dim 1;
    # the two conv caches carry batch on dim 0.
    cat_dims = (1, 1, 1, 1, 0, 0)

    batch_states: List[torch.Tensor] = []
    for layer in range(tot_num_layers):
        base = layer * 6
        for slot, dim in enumerate(cat_dims):
            batch_states.append(
                torch.cat([state[base + slot] for state in state_list], dim=dim)
            )

    # Cached ConvNeXt left padding: batch dim is 0.
    batch_states.append(torch.cat([state[-2] for state in state_list], dim=0))
    # processed_lens: shape (batch,).
    batch_states.append(torch.cat([state[-1] for state in state_list], dim=0))

    return batch_states
240
+
241
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
    """Split a batched zipformer state into one state list per utterance.

    `batch_states` holds 6 cached tensors per encoder layer
    (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
    cached_conv1, cached_conv2), then the cached ConvNeXt left padding of
    shape (batch_size, num_channels, left_pad, num_freqs), then
    processed_lens of shape (batch,).

    Note:
        It is the inverse of :func:`stack_states`.

    Returns:
        A list of per-utterance state lists, one per batch element.
    """
    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
    tot_num_layers = (len(batch_states) - 2) // 6

    # processed_lens has one entry per batch element.
    batch_size = batch_states[-1].shape[0]

    # Split dim for each of a layer's 6 cached tensors: attention caches
    # are batched on dim 1, conv caches on dim 0 (mirrors stack_states).
    split_dims = (1, 1, 1, 1, 0, 0)

    state_list: List[List[Tensor]] = [[] for _ in range(batch_size)]

    for layer in range(tot_num_layers):
        base = layer * 6
        per_slot = [
            batch_states[base + slot].chunk(chunks=batch_size, dim=dim)
            for slot, dim in enumerate(split_dims)
        ]
        for i in range(batch_size):
            state_list[i] += [chunks[i] for chunks in per_slot]

    # Cached ConvNeXt left padding (batched on dim 0).
    for i, pad in enumerate(batch_states[-2].chunk(chunks=batch_size, dim=0)):
        state_list[i].append(pad)

    # processed_lens (batched on dim 0).
    for i, lens in enumerate(batch_states[-1].chunk(chunks=batch_size, dim=0)):
        state_list[i].append(lens)

    return state_list
313
+
314
def streaming_forward(
    features: Tensor,
    feature_lens: Tensor,
    model: nn.Module,
    states: List[Tensor],
    chunk_size: int,
    left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    """
    Forward one chunk of fbank features through encoder_embed + encoder.

    Args:
        features: chunk of fbank features, shape (batch, T, num_mels).
        feature_lens: number of valid frames per batch element, shape (batch,).
        model: the MultiKDModel whose encoder_embed / encoder are used.
        states: batched streaming state as produced by stack_states().
        chunk_size: expected number of post-subsampling frames in the chunk.
        left_context_len: number of cached left-context frames.

    Returns encoder outputs, output lengths, and updated states.
    """
    # The last two states belong to the front-end / bookkeeping, not the
    # encoder layers (see get_init_states).
    cached_embed_left_pad = states[-2]
    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
        x=features,
        x_lens=feature_lens,
        cached_left_pad=cached_embed_left_pad,
    )
    # After subsampling the chunk must have exactly chunk_size frames.
    assert x.size(1) == chunk_size, (x.size(1), chunk_size)

    src_key_padding_mask = make_pad_mask(x_lens)

    # processed_mask is used to mask out initial states
    processed_mask = torch.arange(left_context_len, device=x.device).expand(
        x.size(0), left_context_len
    )
    processed_lens = states[-1]  # (batch,)
    # (batch, left_context_size); flip so the most recent context is last
    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
    # Update processed lengths
    new_processed_lens = processed_lens + x_lens

    # (batch, left_context_size + chunk_size)
    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
    encoder_states = states[:-2]
    (
        encoder_out,
        encoder_out_lens,
        new_encoder_states,
    ) = model.encoder.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=encoder_states,
        src_key_padding_mask=src_key_padding_mask,
    )
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

    # Re-append the front-end cache and updated frame counter so the result
    # has the same layout expected by unstack_states().
    new_states = new_encoder_states + [
        new_cached_embed_left_pad,
        new_processed_lens,
    ]
    return encoder_out, encoder_out_lens, new_states
367
+
368
def chunk_forward(
    audio: torch.Tensor,
    model: torch.nn.Module,
    feature_dim: int = 128,
    chunk_size: int = 8,
    left_context_frames: int = 256,
):
    """Run chunk-by-chunk streaming inference of the encoder on one utterance.

    Args:
        audio: waveform of shape (1, num_samples) at 16 kHz.
        model: the MultiKDModel (in eval mode).
        feature_dim: number of mel bins for the fbank extractor.
        chunk_size: chunk size at the encoder frame rate. NOTE(review):
            int() is applied below, so a single integer (not a
            comma-separated list) must be passed.
        left_context_frames: number of left-context frames kept in the cache.

    Returns:
        (encoder_outs, encoder_out_lens): concatenated encoder outputs of
        shape (1, T, C) and the accumulated output length.
    """
    # Perform chunk by chunk forward for the encoder. Each chunk is conditioned on the current chunk and left context (maintained by the states)
    # At each step, we take a chunk of audio and forward the encoder
    # For the first chunk, we wait until the accumulated audio duration to reach (buffer + chunk_duration), the buffer
    # is necessary for the convolution subsampling module in the encoder.
    # After the first chunk, we perform normal chunk-by-chunk inference when the accumulated audio reaches chunk_duration
    # An example of Buffer=2 frames, chunk=5 frames, the latency for the first chunk is 7 frames (as we need to accumulate 7 frames
    # for decoding), the rest chunks have latency of 5 frames.
    # Each time we feed (5 + 2) frames to the encoder, and then shift 5 frames
    # Chunk 1: AAAAAAA
    # Chunk 2:      AAAAAAA
    # Chunk 3:           AAAAAAA

    # NOTE: params.chunk_size is the chunk_size regarding to the input of the zipformer encoder, so at fbank level, the chunk size
    # is 2 * params.chunk_size

    # fbank extractor
    extractor = Fbank(FbankConfig(num_mel_bins=feature_dim))

    device = next(model.parameters()).device

    chunk_size = int(chunk_size)
    chunk_size_samples = int(chunk_size * 2 * 160)  # chunk size represented in audio samples of 16kHz sampling rate
    left_context_len = int(left_context_frames)
    pad_length = 7 + 2 * 3  # buffer required by encoder_embed module (i.e. convolution subsampling)
    pad_length_samples = (7 + 2 * 3) * 160

    # intialize states, to be maintained during chunk-wise forward
    initial_states = get_init_states(model=model, batch_size=1, device=device)

    # start forward chunk by chunk
    encoder_outs = []
    encoder_out_lens = 0
    states = initial_states

    num_chunk = 0
    num_processed_samples = 0  # audio samples

    # the actual loop performing the chunk-wise inference of the encoder
    while True:
        # prepare the input for processing current chunk
        # compute fbank for the current chunk (current chunk + lookahead buffer)
        audio_chunk = audio[:, num_processed_samples: num_processed_samples + (chunk_size_samples + pad_length_samples)]
        features = extractor.extract(audio_chunk, sampling_rate=16000)
        features = features.to(device)
        feature_lens = features.shape[0]

        feature_lens = torch.tensor([feature_lens], device=device)  # shape: (1)
        features = features.unsqueeze(0)  # shape: (1,T,num_mels)

        # the audio chunk could be shorter than the expected length, for example in the last two chunks
        # pad the chunk so that the input shape is (chunk_size + buffer)
        tail_length = chunk_size * 2 + 7 + 2 * 3  # each prepared chunk should have this length
        if features.size(1) < tail_length:
            # NOTE(review): this rebinds the outer `pad_length`; harmless
            # because the original value is never read again afterwards.
            pad_length = tail_length - features.size(1)
            feature_lens += pad_length
            features = torch.nn.functional.pad(
                features,
                (0, 0, 0, pad_length),
                mode="constant",
                value=LOG_EPS,
            )

        # states are kept per-utterance; batch them (batch=1) for the model
        states = stack_states([states])

        # forward current chunk in batch=1
        encoder_out, encoder_out_len, new_states = streaming_forward(
            features=features,
            feature_lens=feature_lens,
            model=model,
            states=states,
            chunk_size=chunk_size,
            left_context_len=left_context_len,
        )

        encoder_outs.append(encoder_out)
        encoder_out_lens += encoder_out_len

        # update the states (back to per-utterance layout)
        states = unstack_states(new_states)[0]

        num_chunk += 1
        num_processed_samples += chunk_size_samples

        if num_processed_samples > audio.shape[1]:
            print(f"Audio is exhausted.")
            break

    encoder_outs = torch.cat(encoder_outs, dim=1)  # shape: (1,T,C)

    return encoder_outs, encoder_out_lens
465
+
466
+
467
+
468
def main(args):
    """Run chunk-by-chunk (streaming) encoder inference on one audio file."""
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # Build the model, then load checkpoint weights.
    model = get_model(args)
    model.to(device)

    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only hosts;
    # the weights were already moved to `device` above.
    info = model.load_state_dict(
        torch.load(args.ckpt_path, map_location="cpu")["model"], strict=False
    )
    print(info)
    model.eval()

    # load audio; the model expects 16 kHz input of shape (1, num_samples)
    audio, fs = torchaudio.load(args.audio)
    assert fs == 16000

    # inference only — disable autograd bookkeeping
    with torch.no_grad():
        encoder_out, encoder_out_lens = chunk_forward(
            audio=audio,  # shape (1, num_samples)
            model=model,
            feature_dim=128,
            chunk_size=args.chunk_size,
            left_context_frames=args.left_context_frames,
        )

    print(encoder_out)
    print(encoder_out.shape)
    # torch.save(encoder_out, "streaming_forward_encoder_out_no_k2.pt")
499
+
500
if __name__ == "__main__":
    # Parse CLI arguments and run streaming inference.
    main(get_parser().parse_args())
iter-400000-avg-4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ab1b41eb7d85e83e01bf44249dfbac1c56c0f644aedc7f790268aaff0e287e4
3
+ size 2398204810
model.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
2
+ # Wei Kang,
3
+ # Zengwei Yao)
4
+ #
5
+ # Copyright 2025 University of Cambridge (authors: Xiaoyu Yang)
6
+ #
7
+ # See ../../../../LICENSE for clarification regarding multiple authors
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import logging
22
+ from typing import Optional, Tuple
23
+ import random
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # The mask is at least max_len wide, but wide enough for the longest
    # sequence in the batch.
    width = max(max_len, lengths.max())
    positions = torch.arange(0, width, device=lengths.device)
    # Broadcast (1, width) against (n, 1): a position is masked (True)
    # exactly when it lies at or beyond that row's sequence length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
56
+
57
+
58
class MultiKDModel(nn.Module):
    """Student model for MVQ (multi-vector-quantization) knowledge distillation.

    Fbank features go through ``encoder_embed`` (Conv2d subsampling) and
    ``encoder``; the encoder output feeds up to two losses: a codebook-index
    prediction loss against teacher codebook indexes, and a multi-label
    audio-tagging loss.  During training, wav2vec2-style or block masking
    can be applied to the input fbank features.
    """

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_dim: int,
        num_codebooks: int=8,
        distillation_layer: int=9,
        distillation_delta: int=0,
        teacher_frame_ratio: int = 2,
        interpolate_teacher: bool = False,
        n_mels: int = 128,
        num_events: int = 527,
        mask_mode: str = "w2v2",
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: float = 0.0,
        min_masks: int = 2,
        mask_channel_prob: float = 0.0,
        mask_channel_length: int = 10,
        mask_channel_selection: str = "static",
        mask_channel_other: float = 0.0,
        loss_only_mask: bool = False,
    ):
        """A model that performs MVQ KD pre-training.

        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          num_codebooks:
            The number of codebooks used in the target.
          distillation_layer:
            Which layer to use for MVQ pre-training.
          distillation_delta:
            How many frames to delay the alignment between the model and the
            target frames. Should be zero for non-streaming models, and a
            positive number for streaming models.
          teacher_frame_ratio:
            The frame rate ratio between the target and the model output.
          interpolate_teacher:
            If True, the teacher has a LOWER frame rate than the student and
            its indexes are upsampled by interpolation instead of concatenated.
          n_mels:
            Number of mel bins of the input fbank (size of the mask embedding).
          num_events:
            Number of audio-tagging classes (527 for AudioSet).
          mask_mode:
            The masking mode.
            w2v2: the wav2vec2 style of masking, allows overlap
            block: no overlap, therefore bigger masking ratio
          mask_prob:
            The probability of choosing one frame as the start index of a mask.
          mask_length:
            The length of each mask.
          mask_selection:
            How to determine the length of the mask, see ``compute_mask_indices``.
          loss_only_mask:
            If True, compute the codebook loss only on masked frames.
        """
        super().__init__()

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.distillation_layer = distillation_layer
        # the frame ratio between the teacher and student;
        # if larger than one, we are basically having more than one set of
        # codebooks for each frame
        self.num_codebooks= num_codebooks
        self.teacher_frame_ratio = teacher_frame_ratio
        self.interpolate_teacher = interpolate_teacher
        self.distillation_delta = distillation_delta

        if num_codebooks > 0:
            # imported lazily so the dependency is only required when KD is on
            from multi_quantization.prediction import JointCodebookLoss
            self.codebook_loss_net = JointCodebookLoss(
                predictor_channels=encoder_dim,
                num_codebooks=num_codebooks * self.teacher_frame_ratio,
                is_joint=False,
                reduction="none",
            )
        else:
            self.codebook_loss_net = None

        # clip-level audio tagging head applied frame-wise, pooled in
        # forward_audio_tagging()
        self.audio_tagging_proj = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(encoder_dim, num_events),
        ) # 527 classes

        # masking related
        assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}"
        self.mask_mode = mask_mode

        # learnable vector written over the fbank frames at masked positions
        self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_())
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.min_masks = min_masks

        # channel (feature-dimension) masking; disabled by default
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other

        self.loss_only_mask = loss_only_mask

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        codebook_indexes: torch.Tensor = None,
        at_targets: torch.Tensor = None,
        mask: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          codebook_indexes:
            Codebook indexes of teacher embeddings.
          at_targets:
            Multi-hot audio-tagging targets of shape (N, num_events).
          mask:
            If we perform w2v2 style of masking over the fbank frames
            (only has an effect in training mode).

        Returns:
          A tuple (codebook_loss, at_loss); each element is None when the
          corresponding target was not provided.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert codebook_indexes is not None or at_targets is not None

        # apply masking
        if self.training and mask:
            padding_mask = make_pad_mask(x_lens)

            # apply masking to the fbank features; x is cloned so the
            # caller's tensor is not modified in place
            x, mask_indices = self.apply_mask(
                x.clone(),
                padding_mask=padding_mask
            )  # (N,T,C), (N,T)
        else:
            mask_indices = None

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        if codebook_indexes is not None and self.codebook_loss_net is not None:
            codebook_loss = self.forward_codebook_loss(
                encoder_out, encoder_out_lens, codebook_indexes, reduction="none"
            )
            if self.loss_only_mask and mask_indices is not None:
                # downsample the mask to the encoder frame rate
                # NOTE(review): the factor 4 presumably matches the overall
                # subsampling of encoder_embed + encoder — confirm
                mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5
                assert mask_indices.size(1) >= codebook_loss.size(1)
                mask_indices = mask_indices[:, :codebook_loss.size(1)].float()
                # zero out the loss on un-masked frames
                codebook_loss = codebook_loss * mask_indices
            codebook_loss = codebook_loss.sum(dim=1)  # (B,)
        else:
            codebook_loss = None

        if at_targets is not None:
            at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False)
        else:
            at_loss = None

        return codebook_loss, at_loss

    def forward_codebook_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        codebook_indexes: torch.Tensor,
        reduction: str = "sum",
    ):
        """Compute the MVQ codebook-prediction loss.

        Args:
          encoder_out: student encoder output, shape (N, T, C).
          encoder_out_lens: valid lengths of ``encoder_out``, shape (N,).
          codebook_indexes: teacher codebook indexes; positions filled with
            -100 are ignored by the loss.
          reduction: "sum" -> returns shape (N,); "none" -> shape (N, T).
        """
        # align the encoder features with the codebook indexes
        if self.interpolate_teacher:
            # teacher frame rate is LOWER than the student's: upsample
            codebook_indexes = self.interpolate_codebook_indexes(
                encoder_out, codebook_indexes
            )
        else:
            if codebook_indexes.shape[1] != encoder_out.shape[1]:
                # align the codebook indexes to the frame rate of the student
                # encoder out by concatenating successive teacher frames
                codebook_indexes = self.concat_successive_codebook_indexes(
                    encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio
                )

        # the delta is associated with the frame-rate of the encoder,
        # so a bigger delta may be necessary for a 50Hz student encoder
        if self.distillation_delta > 0:
            # shift student frames forward relative to the teacher targets
            codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :]
            encoder_out = encoder_out[:, self.distillation_delta:, :]
            truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta)
            # -100 marks positions the codebook loss should ignore
            codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100)

        N,T,_ = encoder_out.shape
        codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes)
        codebook_loss = codebook_loss.reshape(N,T,-1)
        num_cb = codebook_loss.size(-1)
        # normalize the loss by the number of codebooks
        if reduction == "sum":
            codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb  # (B,)
        elif reduction == "none":
            codebook_loss = codebook_loss.sum(dim=2) / num_cb  # (B,T)
        else:
            raise NotImplementedError()

        return codebook_loss

    def forward_audio_tagging(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        target: torch.Tensor = None,
        return_logits: bool = False,
    ):
        """Compute clip-level audio-tagging logits or the BCE loss.

        Frame-level logits are mean-pooled over the non-padded frames to get
        one logit vector per utterance.

        Args:
          encoder_out: encoder output, shape (N, T, C).
          encoder_out_lens: number of valid frames per utterance, shape (N,).
          target: multi-hot targets of shape (N, num_events); required unless
            ``return_logits`` is True.
          return_logits: if True, return the pooled logits instead of the loss.

        Returns:
          Pooled logits of shape (N, num_events) if ``return_logits``,
          otherwise the per-element BCE loss of shape (N, num_events).
        """
        # target: (N, num_events)
        logits = self.audio_tagging_proj(encoder_out)  # (N, T, num_classes)
        padding_mask = make_pad_mask(encoder_out_lens)  # (N,T)
        # zero padded frames so they don't contribute to the mean pooling
        logits[padding_mask] = 0
        logits = logits.sum(dim=1)
        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)  # (N, num_events)
        if return_logits:
            return logits

        at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")

        return at_loss

    def apply_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply mask according to the mask_mode, return the masked features and the masked positions

        Args:
            x (torch.Tensor): The input fbank features
            padding_mask (torch.Tensor, optional): The padding mask

        Returns:
            The masked fbank feature and the masked_indices, with masked positions as 1
        """
        # apply mask to the fbank features, two modes applicable
        if self.mask_mode == "w2v2":
            x, masked_indices = self.apply_mask_w2v2(x, padding_mask)
        elif self.mask_mode == "block":
            x, masked_indices = self.apply_mask_block(x, padding_mask)
        else:
            raise NotImplementedError()

        # occasionally log the realised masking ratio (~3% of calls)
        # NOTE(review): if mask_prob == 0, apply_mask_w2v2 returns
        # masked_indices = None and this logging line would fail — confirm
        # that mask_prob > 0 is always guaranteed by the caller.
        if random.random() > 0.97:
            logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked")
        return x, masked_indices


    def apply_mask_block(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        # Non-overlapping block masking; see compute_mask_indices_block.
        B,T,C = x.shape
        assert self.mask_prob > 0.0

        mask_indices = compute_mask_indices_block(
            shape=(B,T),
            padding_mask=padding_mask,
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
            min_masks=self.min_masks,
        ).to(x.device)

        # replace masked frames with the learnable mask embedding
        x = index_put(x, mask_indices.bool(), self.mask_emb)

        return x, mask_indices

    def apply_mask_w2v2(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429
        # The masked indices have value 1
        B, T, C = x.shape

        # we mask channel first, then mask timestamps
        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            if random.random() > 0.98:
                logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked")
            # masked channels are zeroed (not replaced by mask_emb)
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                mask_type=self.mask_selection,
                mask_other=self.mask_other,
                min_masks=2,  # fixed
                no_overlap=False,  # False
                min_space=1,  # 1
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # replace masked time steps with the learnable mask embedding
            x = index_put(x, mask_indices, self.mask_emb)
            mask_indices = mask_indices.float()
        else:
            mask_indices = None

        return x, mask_indices

    @staticmethod
    def interpolate_codebook_indexes(middle_layer_output, codebook_indexes):
        # This function addresses the case where the teacher has a lower frame rate
        # than the student model: the indexes are upsampled (nearest-style
        # linear interpolation, then truncated back to int) to the student's
        # number of frames.
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape  # C should be 256

        codebook_indexes = codebook_indexes.permute(0,2,1).float()  # (N,C,T)
        codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected)
        codebook_indexes = codebook_indexes.permute(0,2,1).int()  # (N,T,C)

        assert codebook_indexes.shape[1] == middle_layer_output.shape[1]
        return codebook_indexes

    @staticmethod
    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2):
        # Output rate of hubert is 50 frames per second,
        # while that of current encoder is 25.
        # Following code handles two issues:
        # 1.
        # Roughly speaking, to generate another frame output,
        # hubert needs extra two frames,
        # while current encoder needs extra four frames.
        # Suppose there are only extra three frames provided,
        # hubert will generate another frame while current encoder does nothing.
        # 2.
        # codebook loss is a frame-wise loss; to enable 25 frames of student
        # output to learn from 50 frames of teacher output, two successive
        # frames of teacher model output are concatenated together.
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape  # C should be 256

        # Handling issue 1.
        if T >= t_expected * ratio:
            codebook_indexes = codebook_indexes[:, : t_expected * ratio, :]
        else:
            # pad with -100 (ignored by the loss) when the teacher is a few
            # frames short
            assert t_expected * ratio - T <= 5, (T, t_expected, ratio)
            diff = t_expected * ratio - T
            codebook_indexes = torch.cat(
                [
                    codebook_indexes,
                    torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype)
                ],
                dim=1,
            )
        assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio

        # Handling issue 2.
        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
470
+
471
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at the positions selected by ``indices``.

    The assignment is in place; the (mutated) tensor is also returned for
    convenience.  ``indices`` may be a boolean mask or integer indexes;
    ``value`` is broadcast over the selected positions.
    """
    tensor[indices] = value
    return tensor
474
+
475
def compute_mask_indices_block(
    shape,
    padding_mask,
    mask_prob: float = 0.5,
    mask_length: int = 10,
    min_masks: int = 2,
):
    """Compute non-overlapping block masks.

    Each utterance is divided into consecutive segments of ``mask_length``
    frames (the last partial segment is never masked); each segment is masked
    independently with probability ``mask_prob``, resampling until at least
    ``min_masks`` segments are masked (capped at the number of available
    segments).

    Args:
      shape: (batch_size, num_frames).
      padding_mask: optional (B, T) bool tensor with True at padded positions.
      mask_prob: probability that a segment is masked.
      mask_length: number of frames per segment.
      min_masks: minimum number of masked segments per utterance.

    Returns:
      A (B, T) float tensor with 1.0 at masked positions and 0.0 elsewhere.

    Fixes vs. the original:
      * ``num_segments`` was a 0-dim tensor when ``padding_mask`` was given,
        which ``torch.rand()`` rejects — now converted with ``int()``.
      * the resampling loop spun forever when an utterance had fewer than
        ``min_masks`` segments — the requirement is now capped.
    """
    B, T = shape
    mask_indices = []
    for i in range(B):
        if padding_mask is not None:
            # discard the last few frames; int() because the tensor sum
            # would otherwise make num_segments a 0-dim tensor
            num_segments = int(T - padding_mask[i].sum()) // mask_length
        else:
            num_segments = T // mask_length
        # cap so the loop below can terminate on short utterances
        required = min(min_masks, num_segments)
        segment_mask = torch.rand(num_segments) < mask_prob
        while int(segment_mask.sum()) < required:
            segment_mask = torch.rand(num_segments) < mask_prob
        segment_mask_expanded = segment_mask.unsqueeze(-1).expand(num_segments, mask_length)
        segment_mask_expanded = segment_mask_expanded.reshape(-1).float()
        if segment_mask_expanded.size(0) < T:
            # the trailing partial segment (and padding) is never masked
            pad = T - segment_mask_expanded.size(0)
            segment_mask_expanded = torch.cat([segment_mask_expanded, torch.zeros(pad)])
        mask_indices.append(segment_mask_expanded)

    mask_indices = torch.stack(mask_indices)
    return mask_indices
502
+
503
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Fixes vs. the original (adapted from fairseq):
      * ``np.int`` (removed in NumPy >= 1.24) replaced with builtin ``int``.
      * ``np.random.default_rng()`` Generators have no ``randint``; the
        "uniform" mask type and the ``no_overlap`` path now use
        ``Generator.integers`` (same [low, high) semantics).
      * the "entire sequence masked" error message now actually interpolates
        ``mask_idc``.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            # deterministic per-sample seeding for reproducible masking
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
            hard_max = sz // mask_length
            num_mask = min(hard_max, num_mask)  # prevent whole sequence being masked
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # Generator API: integers() replaces the legacy randint()
            lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # expand each chosen start index into a full span
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
696
+
697
def _test_w2v2_channel_mask():
    """Print the average channel-masking ratio for several configurations.

    Fix: removed a leftover ``import pdb; pdb.set_trace()`` that halted the
    sweep on every config iteration.
    """
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    configs = [(0.25, 15), (0.25, 20), (0.5, 15),]
    # configs = [(0.2, 20), (0.3, 20), (0.4, 20),]
    for config in configs:
        mask_channel_prob, mask_channel_length = config
        ratios = []
        for i in range(20):
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                mask_channel_prob,
                mask_channel_length,
                "static",
                0.0,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            # expand the per-channel mask across all time steps, as in
            # MultiKDModel.apply_mask_w2v2
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            ratio = mask_channel_indices.sum() / mask_channel_indices.numel()
            ratios.append(ratio)
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
730
+
731
def _test_w2v2_mask():
    """Print the average time-masking ratio of w2v2-style masks for several configs."""
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    mask_prob = 0.65
    mask_length = 10

    # configs = [(0.65, 10), (0.01, 40), (0.1, 40), (0.2, 40), (0.2, 20), (0.35, 10), (0.35, 20), (0.25, 20)]
    configs = [
        (0.05 + (i + 1) * 0.1, l)
        for i in range(6)
        for l in [10, 20, 30, 40]
    ]
    configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)]
    for mask_prob, mask_length in configs:
        ratios = []
        for _ in range(20):
            sampled = compute_mask_indices(
                (B, T),
                None,
                mask_prob,
                mask_length,
                mask_type="static",
                mask_other=0.0,
                min_masks=2,
                no_overlap=False,  # False
                min_space=1,  # 1
                require_same_masks=False,
            )
            sampled = torch.from_numpy(sampled)
            ratios.append(sampled.sum() / sampled.numel())
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
767
+
768
def _test_custom_mask():
    """Print the average masking ratio of block masks for several configurations.

    The mask length is re-sampled every iteration from
    ``{mask_length - 10, ..., mask_length + 10}`` in steps of 2 (based on the
    previous iteration's value).

    Fix: removed a leftover ``import pdb; pdb.set_trace()`` that halted the
    loop on every iteration.
    """
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)]
    for config in configs:
        mask_prob, mask_length = config
        ratios = []
        for i in range(20):
            all_possible_mask_lengths = [mask_length + i * 2 for i in range(-5, 6)]
            mask_length = random.sample(all_possible_mask_lengths, 1)[0]
            assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}"

            mask_indices = compute_mask_indices_block(
                shape=(B, T),
                padding_mask=None,
                mask_prob=mask_prob,
                mask_length=mask_length,
                min_masks=2,
            )
            ratio = mask_indices.sum() / mask_indices.numel()
            ratios.append(ratio)
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
794
+
795
+
796
if __name__=="__main__":
    # Ad-hoc masking-ratio sanity checks; enable the variant to inspect.
    _test_w2v2_channel_mask()
    # _test_w2v2_mask()
    # _test_custom_mask()
scaling.py ADDED
@@ -0,0 +1,1913 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import logging
19
+ import math
20
+ import random
21
+ from typing import Optional, Tuple, Union
22
+
23
+ # import k2
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch import Tensor
27
+ from torch.cuda.amp import custom_bwd, custom_fwd
28
+
29
+
30
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
    """ONNX-exportable log(exp(x) + exp(y)) via the max-subtraction trick."""
    larger = torch.max(x, y)
    gap = torch.abs(x - y)
    # log(exp(x) + exp(y)) = max + log(1 + exp(-|x - y|))
    return larger + torch.log1p(torch.exp(-gap))
34
+
35
+
36
# torch.logaddexp is not supported by ONNX export (opset 14), which raises
# "RuntimeError: Exporting the operator logaddexp to ONNX ... is not
# supported."  This wrapper dispatches to an ONNX-friendly formula only
# when exporting, and to torch.logaddexp otherwise.
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
    """Numerically stable log(exp(x) + exp(y)), safe for ONNX export."""
    # Caution: the is_scripting() check must come before
    # torch.onnx.is_in_onnx_export(); the reverse order breaks
    # torch.jit.script().
    if torch.jit.is_scripting():
        # Note: torch.jit.is_tracing() cannot be used here since it also
        # matches torch.onnx.export().
        return torch.logaddexp(x, y)
    if torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    # Eager mode and torch.jit.trace() both support the native op.
    return torch.logaddexp(x, y)
59
+
60
+
61
class PiecewiseLinear(object):
    """
    Piecewise linear function from float to float, specified as a nonempty
    list of (x, y) knots with strictly increasing x values.  Inputs below the
    first x map to the first y; inputs above the last x map to the last y.
    """

    def __init__(self, *args):
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            # Copy constructor.
            self.pairs = list(args[0].pairs)
        else:
            self.pairs = [(float(x), float(y)) for x, y in args]
        for x, y in self.pairs:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)
        # x values must be strictly increasing.
        for i, (p0, p1) in enumerate(zip(self.pairs[:-1], self.pairs[1:])):
            assert p1[0] > p0[0], (i, p0, p1)

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f"PiecewiseLinear({str(self.pairs)[1:-1]})"

    def __call__(self, x):
        """Evaluate the function at x, clamping outside the knot range."""
        pairs = self.pairs
        if x <= pairs[0][0]:
            return pairs[0][1]
        if x >= pairs[-1][0]:
            return pairs[-1][1]
        for (x0, y0), (x1, y1) in zip(pairs[:-1], pairs[1:]):
            if x0 <= x <= x1:
                # Linear interpolation within the segment [x0, x1].
                return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        assert False

    def __mul__(self, alpha):
        """Scale all y values by alpha."""
        return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])

    def __add__(self, x):
        """Add a constant, or another PiecewiseLinear, pointwise."""
        if isinstance(x, (float, int)):
            return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
        a, b = self.get_common_basis(x)
        return PiecewiseLinear(
            *[(pa[0], pa[1] + pb[1]) for pa, pb in zip(a.pairs, b.pairs)]
        )

    def max(self, x):
        """Pointwise maximum with a constant or another PiecewiseLinear."""
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        a, b = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            *[(pa[0], max(pa[1], pb[1])) for pa, pb in zip(a.pairs, b.pairs)]
        )

    def min(self, x):
        """Pointwise minimum with a constant or another PiecewiseLinear."""
        if isinstance(x, float) or isinstance(x, int):
            x = PiecewiseLinear((0, x))
        a, b = self.get_common_basis(x, include_crossings=True)
        return PiecewiseLinear(
            *[(pa[0], min(pa[1], pb[1])) for pa, pb in zip(a.pairs, b.pairs)]
        )

    def __eq__(self, other):
        return self.pairs == other.pairs

    def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
        """
        Returns (self_mod, p_mod): functions equivalent to self and p but
        sharing the same x knots.

        p: the other piecewise linear function
        include_crossings: if true, also insert x positions where the two
           functions cross, so that pointwise max/min stay piecewise linear.
        """
        assert isinstance(p, PiecewiseLinear), type(p)

        # Union of knot x-positions, sorted without repetition.
        xs = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
        ys_self = [self(x) for x in xs]
        ys_p = [p(x) for x in xs]

        if include_crossings:
            crossings = []
            for i in range(len(xs) - 1):
                if (ys_self[i] > ys_p[i]) != (ys_self[i + 1] > ys_p[i + 1]):
                    # The two segments cross somewhere in (xs[i], xs[i+1]);
                    # locate the crossing by similar triangles.
                    d_cur = abs(ys_self[i] - ys_p[i])
                    d_next = abs(ys_self[i + 1] - ys_p[i + 1])
                    frac = d_cur / (d_cur + d_next)
                    crossings.append(xs[i] + frac * (xs[i + 1] - xs[i]))
            if len(crossings) > 0:
                xs = sorted(set(xs + crossings))
                ys_self = [self(x) for x in xs]
                ys_p = [p(x) for x in xs]
        return (
            PiecewiseLinear(*zip(xs, ys_self)),
            PiecewiseLinear(*zip(xs, ys_p)),
        )
+ )
169
+
170
+
171
class ScheduledFloat(torch.nn.Module):
    """
    A scalar hyperparameter that follows a piecewise-linear schedule in
    `self.batch_count`.  Cast it to float (e.g. float(module.dropout_rate))
    to read the current value.

    It subclasses torch.nn.Module only so that it appears in
    [top_level_module].modules(); its forward() is not usable.

    The schedule is given as (x, y) pairs sorted on x, where x is the batch
    index; before the first x / after the last x the first / last y is used.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    `default` is returned whenever self.batch_count is unset, the module is
    not in training mode, or we are in torch.jit scripting/tracing mode.
    """

    def __init__(self, *args, default: float = 0.0):
        super().__init__()
        # The training loop is expected to write to batch_count and name.
        self.batch_count = None
        self.name = None
        self.default = default
        self.schedule = PiecewiseLinear(*args)

    def extra_repr(self) -> str:
        return (
            f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
        )

    def __float__(self):
        count = self.batch_count
        use_default = (
            count is None
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        )
        if use_default:
            return float(self.default)
        ans = self.schedule(self.batch_count)
        # Occasionally log the current value for debugging.
        if random.random() < 0.0002:
            logging.info(
                f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
            )
        return ans

    def __add__(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule + x, default=self.default)
        return ScheduledFloat(
            self.schedule + x.schedule, default=self.default + x.default
        )

    def max(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule.max(x), default=self.default)
        return ScheduledFloat(
            self.schedule.max(x.schedule), default=max(self.default, x.default)
        )
+ )
234
+
235
+
236
# Type alias: schedule-able scalar hyperparameters may be given either as a
# plain float or as a ScheduledFloat that varies with the batch count.
FloatLike = Union[float, ScheduledFloat]
237
+
238
+
239
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    Cast x to float16, stochastically rounding elements whose magnitude is
    below `min_abs` so that the expected value is preserved.
    """
    if x.dtype == torch.float16:
        return x
    magnitude = x.abs()
    tiny = magnitude < min_abs
    # For tiny elements, substitute +-min_abs with probability |x| / min_abs
    # and 0.0 otherwise; this keeps the expectation equal to x.
    keep = torch.rand_like(x) * min_abs < magnitude
    stochastic = min_abs * x.sign() * keep
    return torch.where(tiny, stochastic, x).to(torch.float16)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
252
+
253
+
254
class CutoffEstimator:
    """
    Online estimator of a cutoff for an arbitrary numerical quantity, chosen
    so that on average a proportion `p` of observed values lies above it.
    """

    def __init__(self, p: float):
        self.p = p
        self.count = 0        # number of values seen so far
        self.count_above = 0  # how many of them exceeded the cutoff
        self.cutoff = 0       # current cutoff estimate

    def __call__(self, x: float) -> bool:
        """Record x and return True iff x is above the current cutoff."""
        above = x > self.cutoff
        self.count += 1
        self.count_above += int(above)
        observed_p = self.count_above / self.count
        excess = observed_p - self.p
        # Move the cutoff toward x only when that corrects the error:
        # raise it on an "above" when too many were above, lower it on a
        # "below" when too few were above.  Step size tracks the error.
        if (excess > 0) == above:
            q = abs(excess)
            self.cutoff = x * q + self.cutoff * (1 - q)
        return above
+ return ans
285
+
286
+
287
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.

    Forward is a plain softmax; backward recomputes the gradient in float32
    (with autocast disabled) and returns it via the standard softmax-Jacobian
    formula.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.  Cast back so that the saved/returned tensor matches
        # the autocast dtype.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.get_autocast_gpu_dtype())
        # Only the output is saved; the softmax gradient can be computed
        # from the output alone.
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # Do the gradient math in float32 regardless of autocast, for
        # accuracy in half-precision training.
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            # d/dx softmax: y * (g - sum(y * g)) along `dim`.
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
+ return x_grad, None
315
+
316
+
317
def softmax(x: Tensor, dim: int):
    """Softmax along `dim`, using SoftmaxFunction's more accurate
    half-precision backward when a gradient is actually needed."""
    # The custom Function is not scripting/tracing friendly and is pointless
    # without grad, so fall back to the native op in those cases.
    needs_custom = (
        x.requires_grad
        and not torch.jit.is_scripting()
        and not torch.jit.is_tracing()
    )
    if needs_custom:
        return SoftmaxFunction.apply(x, dim)
    return x.softmax(dim=dim)
+ return SoftmaxFunction.apply(x, dim)
322
+
323
+
324
class MaxEigLimiterFunction(torch.autograd.Function):
    # Identity in the forward pass.  In the backward pass it adds to the
    # incoming gradient an extra term: the gradient of `variance_proportion`
    # (the fraction of variance explained by `direction`), rescaled so its
    # norm is grad_scale times the incoming gradient's norm.  The net effect
    # is a gradient pressure against one direction dominating the variance.
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        # Save detached copies; the penalty gradient is recomputed in backward.
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            # Flatten to (frames, channels) with the channel dim last.
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction.  This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        # Rescale the penalty gradient relative to the main gradient's norm.
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None
+ return x_grad + x_extra_grad.detach(), None, None, None, None
363
+
364
+
365
class BiasNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return x * scales
    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
    # it can just store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        bias: Tensor,
        log_scale: Tensor,
        channel_dim: int,
        store_output_for_backprop: bool,
    ) -> Tensor:
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.store_output_for_backprop = store_output_for_backprop
        ctx.channel_dim = channel_dim
        # Unsqueeze the 1-d bias so it broadcasts over dims after channel_dim.
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        scales = (
            torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
        ) * log_scale.exp()
        ans = x * scales
        # Save either the output or the input (whichever the caller expects
        # to be cheaper to keep), plus what is needed to invert/recompute.
        ctx.save_for_backward(
            ans.detach() if store_output_for_backprop else x,
            scales.detach(),
            bias.detach(),
            log_scale.detach(),
        )
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tensor:
        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
        if ctx.store_output_for_backprop:
            # We stored the output; recover the input by dividing out the scales.
            x = ans_or_x / scales
        else:
            x = ans_or_x
        x = x.detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # recompute scales from x, bias and log_scale.
            scales = (
                torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
            ) * log_scale.exp()
            ans = x * scales
            # Let autograd produce the three input gradients.
            ans.backward(gradient=ans_grad)
        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
419
+
420
+
421
class BiasNorm(torch.nn.Module):
    """
    A simpler, cheaper replacement for LayerNorm.

    Motivation: Transformer-type networks, especially with pre-norm, sometimes
    learn to put a large constant value in one feature dimension (e.g. 50),
    which "defeats" LayerNorm because the output magnitude then barely depends
    on the other (useful) features; the LayerNorm weight and bias are
    presumably needed to allow this.  BiasNorm instead provides a trainable
    bias used only when computing the normalization scale, plus a scalar
    trainable log-scale on the output.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this
          is learnable.
       log_scale_min: FloatLike, minimum allowed value of log_scale
       log_scale_max: FloatLike, maximum allowed value of log_scale
       store_output_for_backprop: only possibly affects memory use; recommend
          to set to True if you think the output of this module is more likely
          than the input of this module to be required to be stored for the
          backprop.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        log_scale: float = 1.0,
        log_scale_min: float = -1.5,
        log_scale_max: float = 1.5,
        store_output_for_backprop: bool = False,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
        # Bias starts near zero; it only affects the normalization statistics.
        self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4))
        self.log_scale_min = log_scale_min
        self.log_scale_max = log_scale_max
        self.store_output_for_backprop = store_output_for_backprop

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Scripted/traced path: compute the normalization inline, since
            # the custom autograd Function is not script-friendly.
            dim = self.channel_dim
            if dim < 0:
                dim += x.ndim
            bias = self.bias
            for _ in range(dim + 1, x.ndim):
                bias = bias.unsqueeze(-1)
            inv_rms = torch.mean((x - bias) ** 2, dim=dim, keepdim=True) ** -0.5
            return x * (inv_rms * self.log_scale.exp())

        # Clamp log_scale into [log_scale_min, log_scale_max] during training.
        clamped_log_scale = limit_param_value(
            self.log_scale,
            min=float(self.log_scale_min),
            max=float(self.log_scale_max),
            training=self.training,
        )

        return BiasNormFunction.apply(
            x,
            self.bias,
            clamped_log_scale,
            self.channel_dim,
            self.store_output_for_backprop,
        )
+ )
497
+
498
+
499
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Constructs an nn.Linear whose initial parameter magnitudes are scaled
    by `initial_scale`.

    Args:
       Accepts the standard args and kwargs that nn.Linear accepts,
       e.g. in_features, out_features, bias=False.

       initial_scale: override this to increase or decrease the initial
          magnitude of the module's output.  (Alternatively, you could
          simply re-initialize the parameters afterward.)
    """
    module = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        module.weight[:] *= initial_scale
        if module.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(module.bias, -bound, bound)
    return module
+ return ans
520
+
521
+
522
def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
    """
    Constructs an nn.Conv1d whose initial parameter magnitudes are scaled
    by `initial_scale`.

    Args:
       Accepts the standard args and kwargs that nn.Conv1d accepts,
       e.g. in_channels, out_channels, kernel_size, bias=False.

       initial_scale: override this to increase or decrease the initial
          magnitude of the module's output.  (Alternatively, you could
          simply re-initialize the parameters afterward.)
    """
    conv = nn.Conv1d(*args, **kwargs)
    with torch.no_grad():
        conv.weight[:] *= initial_scale
        if conv.bias is not None:
            limit = 0.1 * initial_scale
            torch.nn.init.uniform_(conv.bias, -limit, limit)
    return conv
+ return ans
543
+
544
+
545
def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
    """
    Constructs an nn.Conv2d whose initial parameter magnitudes are scaled
    by `initial_scale`.

    Args:
       Accepts the standard args and kwargs that nn.Conv2d accepts,
       e.g. in_channels, out_channels, kernel_size, bias=False, but:
       NO PADDING-RELATED ARGS.

       initial_scale: override this to increase or decrease the initial
          magnitude of the module's output.  (Alternatively, you could
          simply re-initialize the parameters afterward.)
    """
    conv = nn.Conv2d(*args, **kwargs)
    with torch.no_grad():
        conv.weight[:] *= initial_scale
        if conv.bias is not None:
            limit = 0.1 * initial_scale
            torch.nn.init.uniform_(conv.bias, -limit, limit)
    return conv
+ return ans
567
+
568
+
569
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
    """
    Behaves like a depthwise 1d convolution, except that it is causal in
    a chunkwise way, as if we had a block-triangular attention mask.
    The chunk size is provided at test time (it should probably be
    kept in sync with the attention mask).

    This has a little more than twice the parameters of a conventional
    depthwise conv1d module: we implement it by having one
    depthwise convolution, of half the width, that is causal (via
    right-padding); and one depthwise convolution that is applied only
    within chunks, that we multiply by a scaling factor which depends
    on the position within the chunk.

    Args:
       channels: number of input (and output) channels; the convolution
          is depthwise, so groups == channels.
       kernel_size: full kernel width; must be odd.
       initial_scale: multiplier on the initial weights (and the range of
          the initial causal-conv bias).
       bias: whether the chunkwise conv has a bias (the causal conv
          always does).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        initial_scale: float = 1.0,
        bias: bool = True,
    ):
        super().__init__()
        assert kernel_size % 2 == 1

        half_kernel_size = (kernel_size + 1) // 2
        # will pad manually, on one side.
        self.causal_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=half_kernel_size,
            padding=0,
            bias=True,
        )

        self.chunkwise_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

        # first row is correction factors added to the scale near the left edge of the chunk,
        # second row is correction factors added to the scale near the right edge of the chunk,
        # both of these are added to a default scale of 1.0.
        self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
        self.kernel_size = kernel_size

        with torch.no_grad():
            self.causal_conv.weight[:] *= initial_scale
            self.chunkwise_conv.weight[:] *= initial_scale
            if bias:
                torch.nn.init.uniform_(
                    self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
                )

    def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
        """Forward function.

        Args:
           x: a Tensor of shape (batch_size, channels, seq_len)
           chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
              Negative (default) means one chunk covering the whole sequence.
        """
        (batch_size, num_channels, seq_len) = x.shape

        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        if chunk_size < 0 or chunk_size > seq_len:
            chunk_size = seq_len
        # Right-pad so the (padded) length is a multiple of chunk_size.
        right_pad = -seq_len % chunk_size

        x = torch.nn.functional.pad(x, (left_pad, right_pad))

        x_causal = self.causal_conv(x[..., : left_pad + seq_len])
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        # Fold chunks into the batch dim so the chunkwise conv (which uses
        # symmetric zero padding) cannot see across chunk boundaries.
        x_chunk = x[..., left_pad:]
        num_chunks = x_chunk.shape[2] // chunk_size
        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
        x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
            batch_size * num_chunks, num_channels, chunk_size
        )
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        chunk_scale = self._get_chunk_scale(chunk_size)

        x_chunk = x_chunk * chunk_scale
        # Unfold chunks back into the time dim and drop the right padding.
        x_chunk = x_chunk.reshape(
            batch_size, num_chunks, num_channels, chunk_size
        ).permute(0, 2, 1, 3)
        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
            ..., :seq_len
        ]

        return x_chunk + x_causal

    def _get_chunk_scale(self, chunk_size: int):
        """Returns tensor of shape (num_channels, chunk_size) that will be used to
        scale the output of self.chunkwise_conv."""
        left_edge = self.chunkwise_conv_scale[0]
        right_edge = self.chunkwise_conv_scale[1]
        if chunk_size < self.kernel_size:
            # Chunk shorter than the learned edge corrections: truncate them.
            left_edge = left_edge[:, :chunk_size]
            right_edge = right_edge[:, -chunk_size:]
        else:
            # Zero-fill the middle of the chunk, where no correction applies.
            t = chunk_size - self.kernel_size
            channels = left_edge.shape[0]
            pad = torch.zeros(
                channels, t, device=left_edge.device, dtype=left_edge.dtype
            )
            left_edge = torch.cat((left_edge, pad), dim=-1)
            right_edge = torch.cat((pad, right_edge), dim=-1)
        return 1.0 + (left_edge + right_edge)

    def streaming_forward(
        self,
        x: Tensor,
        cache: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Streaming Forward function.

        Args:
           x: a Tensor of shape (batch_size, channels, seq_len)
           cache: cached left context of shape (batch_size, channels, left_pad)

        Returns:
           (output, updated cache); output has the same shape as x.
        """
        (batch_size, num_channels, seq_len) = x.shape

        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        # Pad cache
        assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
        x = torch.cat([cache, x], dim=2)
        # Update cache
        cache = x[..., -left_pad:]

        x_causal = self.causal_conv(x)
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        x_chunk = x[..., left_pad:]
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        # In streaming mode the whole segment is treated as a single chunk.
        chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
        x_chunk = x_chunk * chunk_scale

        return x_chunk + x_causal, cache
+ return x_chunk + x_causal, cache
734
+
735
+
736
class BalancerFunction(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass it nudges the
    # gradient so that each channel's (mean / stddev) and RMS statistics move
    # toward the ranges [min_mean, max_mean] and [min_rms, max_rms].
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min_mean: float,
        max_mean: float,
        min_rms: float,
        max_rms: float,
        grad_scale: float,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        ctx.save_for_backward(x)
        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
        (x,) = ctx.saved_tensors
        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

        try:
            # Compute the penalty gradient in float32 with autocast off;
            # any failure is logged and the original gradient passes through.
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    x = x.to(torch.float32)
                    x = x.detach()
                    x.requires_grad = True
                    # Per-channel statistics: reduce over every dim except
                    # the channel dim.
                    mean_dims = [i for i in range(x.ndim) if i != channel_dim]
                    uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
                    mean = x.mean(dim=mean_dims, keepdim=True)
                    stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
                    rms = uncentered_var.clamp(min=1.0e-20).sqrt()

                    m = mean / stddev
                    # part of loss that relates to mean / stddev
                    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()

                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
                    # m_loss are violated we fix the RMS loss first.
                    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
                    r_loss = (rms_clamped / rms).log().abs()

                    loss = m_loss + r_loss

                    loss.backward(gradient=torch.ones_like(loss))
                    loss_grad = x.grad
                    # Normalize the penalty gradient to per-channel RMS of 1,
                    # then apply the requested grad_scale.
                    loss_grad_rms = (
                        (loss_grad**2)
                        .mean(dim=mean_dims, keepdim=True)
                        .sqrt()
                        .clamp(min=1.0e-20)
                    )

                    loss_grad = loss_grad * (grad_scale / loss_grad_rms)

                    x_grad_float = x_grad.to(torch.float32)
                    # scale each element of loss_grad by the absolute value of the corresponding
                    # element of x_grad, which we view as a noisy estimate of its magnitude for that
                    # (frame and dimension).  later we can consider factored versions.
                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                    x_grad = x_grad_mod.to(x_grad.dtype)
        except Exception as e:
            logging.info(
                f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
            )

        return x_grad, None, None, None, None, None, None
+ return x_grad, None, None, None, None, None, None
806
+
807
+
808
class Balancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
          -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
          that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
          that (x > 0), above which we start to modify the derivatives.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs: the minimum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
          value per channel, which we allow, before we start to modify
          the derivatives to prevent this.
       prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
          on each forward().  This is done randomly to prevent all layers
          from doing it at the same time.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: FloatLike = 0.05,
        max_positive: FloatLike = 0.95,
        min_abs: FloatLike = 0.2,
        max_abs: FloatLike = 100.0,
        grad_scale: FloatLike = 0.04,
        prob: Optional[FloatLike] = None,
    ):
        super().__init__()

        if prob is None:
            # Default: apply more often early in training, decaying with
            # batch count.
            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
        # 5% of the time we will return and do nothing because memory usage is
        # too high.
        self.mem_cutoff = CutoffEstimator(0.05)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        # Skip entirely when scripting, when no grad is needed, or (on CUDA)
        # when memory usage is estimated to be in the top 5%.
        if (
            torch.jit.is_scripting()
            or not x.requires_grad
            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
        ):
            return _no_op(x)

        prob = float(self.prob)
        if random.random() < prob:
            # The following inner-functions convert from the way we historically specified
            # these limitations, as limits on the absolute value and the proportion of positive
            # values, to limits on the RMS value and the (mean / stddev).
            def _abs_to_rms(x):
                # for normally distributed data, if the expected absolute value is x, the
                # expected rms value will be sqrt(pi/2) * x.
                return 1.25331413732 * x

            def _proportion_positive_to_mean(x):
                def _atanh(x):
                    eps = 1.0e-10
                    # eps is to prevent crashes if x is exactly 0 or 1.
                    # we'll just end up returning a fairly large value.
                    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0

                def _approx_inverse_erf(x):
                    # 1 / (sqrt(pi) * ln(2)),
                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
                    # this approximation is extremely crude and gets progressively worse for
                    # x very close to -1 or +1, but we mostly care about the "middle" region
                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
                    # which is pretty close to 0.05.
                    return 0.8139535143 * _atanh(x)

                # first convert x from the range 0..1 to the range -1..1 which the error
                # function returns
                x = -1 + (2 * x)
                return _approx_inverse_erf(x)

            min_mean = _proportion_positive_to_mean(float(self.min_positive))
            max_mean = _proportion_positive_to_mean(float(self.max_positive))
            min_rms = _abs_to_rms(float(self.min_abs))
            max_rms = _abs_to_rms(float(self.max_abs))
            grad_scale = float(self.grad_scale)

            assert x.shape[self.channel_dim] == self.num_channels

            return BalancerFunction.apply(
                x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
            )
        else:
            return _no_op(x)
+ return _no_op(x)
922
+
923
+
924
def penalize_abs_values_gt(
    x: Tensor, limit: float, penalty: float, name: Optional[str] = None
) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training. For this reasons we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.

    The name is for randomly printed debug info.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit.  (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away).  The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss, name)
    # you must use x for something, or this will be ineffective.
    return x
954
+
955
+
956
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
957
+ if x.ndim == 2:
958
+ return x.diag()
959
+ else:
960
+ (batch, dim, dim) = x.shape
961
+ x = x.reshape(batch, dim * dim)
962
+ x = x[:, :: dim + 1]
963
+ assert x.shape == (batch, dim)
964
+ return x
965
+
966
+
967
def _whitening_metric(x: Tensor, num_groups: int) -> Tensor:
    """
    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
    of the centered feature covariance are the same within each group's covariance matrix
    and also between groups.
    Args:
        x: a Tensor of shape (*, num_channels)
        num_groups:  the number of groups of channels, a number >=1 that divides num_channels
    Returns:
        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
        greater than 1.0 otherwise.
    """
    # float16 would overflow / lose precision in the covariance products below.
    assert x.dtype != torch.float16
    x = x.reshape(-1, x.shape[-1])
    (num_frames, num_channels) = x.shape
    assert num_channels % num_groups == 0
    channels_per_group = num_channels // num_groups
    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
    # x now has shape (num_groups, num_frames, channels_per_group)
    # subtract the mean so we use the centered, not uncentered, covariance.
    # My experience has been that when we "mess with the gradients" like this,
    # it's better not do anything that tries to move the mean around, because
    # that can easily cause instability.
    x = x - x.mean(dim=1, keepdim=True)
    # x_covar: (num_groups, channels_per_group, channels_per_group)
    x_covar = torch.matmul(x.transpose(1, 2), x)
    x_covar_mean_diag = _diag(x_covar).mean()
    # the following expression is what we'd get if we took the matrix product
    # of each covariance and measured the mean of its trace, i.e.
    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
    x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
    # The 1.0e-20 guards against division by zero for all-zero input.
    metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
    return metric
1001
+
1002
+
1003
class WhiteningPenaltyFunction(torch.autograd.Function):
    """Identity in the forward pass.  In the backward pass, adds to the
    incoming gradient a scaled penalty-gradient that pushes the input's
    per-group feature covariance towards "white" (see _whitening_metric).
    Reads its configuration (num_groups, whitening_limit, grad_scale,
    min_prob/max_prob) from the Whiten module passed to forward(), and
    updates that module's `prob` attribute as a side effect."""

    @staticmethod
    def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
        ctx.save_for_backward(x)
        # Keep a reference to the Whiten module so backward() can read its
        # settings and update its `prob`.
        ctx.module = module
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        w = ctx.module

        try:
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    # Compute the metric in float32 on a detached copy so we can
                    # differentiate it w.r.t. the input independently of the
                    # main graph.
                    x_detached = x_orig.to(torch.float32).detach()
                    x_detached.requires_grad = True

                    metric = _whitening_metric(x_detached, w.num_groups)

                    # Occasionally log the metric for debugging.
                    if random.random() < 0.005 or __name__ == "__main__":
                        logging.info(
                            f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
                            f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
                        )

                    if metric < float(w.whitening_limit):
                        # Already white enough: no penalty; apply this function
                        # less often from now on.
                        w.prob = w.min_prob
                        return x_grad, None
                    else:
                        w.prob = w.max_prob
                        metric.backward()
                        penalty_grad = x_detached.grad
                        # Scale the penalty gradient relative to the norm of the
                        # incoming gradient, by grad_scale.
                        scale = float(w.grad_scale) * (
                            x_grad.to(torch.float32).norm()
                            / (penalty_grad.norm() + 1.0e-20)
                        )
                        penalty_grad = penalty_grad * scale
                        return x_grad + penalty_grad.to(x_grad.dtype), None
        except Exception as e:
            # Best-effort regularizer: never let a failure here (e.g. OOM in the
            # covariance computation) kill training.
            logging.info(
                f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
            )
            return x_grad, None
1047
+
1048
+
1049
class Whiten(nn.Module):
    """Identity in the forward pass; with probability `prob`, its backward
    pass applies WhiteningPenaltyFunction, which nudges the input's feature
    covariance towards white.  `prob` self-adjusts between min_prob and
    max_prob depending on whether the whitening metric currently exceeds
    `whitening_limit`."""

    def __init__(
        self,
        num_groups: int,
        whitening_limit: FloatLike,
        prob: Union[float, Tuple[float, float]],
        grad_scale: FloatLike,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening.  We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
          whitening_limit: a value greater than 1.0, that dictates how much
            freedom we have to violate the constraints.  1.0 would mean perfectly
            white, with exactly the same trace across groups; larger values
            give more freedom.  E.g. 2.0.
          prob: the probability with which we apply the gradient modification
            (also affects the grad scale).  May be supplied as a float,
            or as a pair (min_prob, max_prob)
          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert float(whitening_limit) >= 1
        assert float(grad_scale) >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        self.grad_scale = grad_scale

        # A scalar prob means min_prob == max_prob (no adaptation range).
        if isinstance(prob, float):
            prob = (prob, prob)
        (self.min_prob, self.max_prob) = prob
        assert 0 < self.min_prob <= self.max_prob <= 1
        self.prob = self.max_prob
        self.name = None  # will be set in training loop

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
          x: the input of shape (*, num_channels)

        Returns:
            x, unmodified.   You should make sure
            you use the returned value, or the graph will be freed
            and nothing will happen in backprop.
        """
        grad_scale = float(self.grad_scale)
        # Skip entirely when there is no grad, when randomly disabled this
        # call, or when the penalty scale is zero.
        if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
            return _no_op(x)
        else:
            return WhiteningPenaltyFunction.apply(x, self)
1112
+
1113
+
1114
class WithLoss(torch.autograd.Function):
    """Passes x through unchanged in the forward pass while effectively adding
    y.sum() to the loss: the backward pass sends a gradient of all-ones to y.
    Occasionally logs the auxiliary loss value when a name is given."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, name: str):
        ctx.y_shape = y.shape
        # Log rarely so training output stays readable.
        if name is not None and random.random() < 0.002:
            loss_sum = y.sum().item()
            logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ones_for_y = torch.ones(
            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
        )
        return ans_grad, ones_for_y, None
1130
+
1131
+
1132
def with_loss(x: Tensor, y: Tensor, name: Optional[str]) -> Tensor:
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y, name)
1135
+
1136
+
1137
class ScaleGradFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by
    `alpha` in the backward pass."""

    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        scaled = ctx.alpha * grad
        return scaled, None
1146
+
1147
+
1148
def scale_grad(x: Tensor, alpha: float) -> Tensor:
    """Returns x unchanged, but scales its gradient by `alpha` in backprop."""
    return ScaleGradFunction.apply(x, alpha)
1150
+
1151
+
1152
class ScaleGrad(nn.Module):
    """Module form of scale_grad(): identity in the forward pass, multiplies
    the gradient by `alpha` in the backward pass.  Acts as a pure identity
    when scripting, tracing, or in eval mode."""

    def __init__(self, alpha: float):
        super().__init__()
        # Factor applied to the gradient in backprop.
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return x
        return scale_grad(x, self.alpha)
1161
+
1162
+
1163
class LimitParamValue(torch.autograd.Function):
    """Identity in the forward pass.  In the backward pass, flips the sign of
    any gradient element that would push the corresponding element of x
    further outside the range [min, max], so x is nudged back towards the
    allowed range."""

    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float):
        ctx.save_for_backward(x)
        assert max >= min
        ctx.min = min
        ctx.max = max
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x,) = ctx.saved_tensors
        # where x < ctx.min, ensure all grads are negative (this will tend to make
        # x more positive).
        x_grad = x_grad * torch.where(
            torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
        )
        # where x > ctx.max, ensure all grads are positive (this will tend to make
        # x more negative).
        x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
        return x_grad, None, None
1184
+
1185
+
1186
def limit_param_value(
    x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
):
    """Keep a parameter tensor's elements (mostly) within [min, max] by
    modifying its gradients in backprop via LimitParamValue.  To save a
    little time, the limiting is applied randomly with probability `prob`,
    and only during training."""
    apply_limit = training and random.random() < prob
    if not apply_limit:
        return x
    return LimitParamValue.apply(x, min, max)
1198
+
1199
+
1200
+ def _no_op(x: Tensor) -> Tensor:
1201
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1202
+ return x
1203
+ else:
1204
+ # a no-op function that will have a node in the autograd graph,
1205
+ # to avoid certain bugs relating to backward hooks
1206
+ return x.chunk(1, dim=-1)[0]
1207
+
1208
+
1209
class Identity(torch.nn.Module):
    """Identity module implemented via _no_op(), so that in eager mode the
    output carries a distinct autograd-graph node (avoids certain bugs
    relating to backward hooks)."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return _no_op(x)
1215
+
1216
+
1217
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) =  x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) =  x * s'(x) + s(x).
                      =  x * s(x) * (1-s(x)) + s(x).
                     = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    To save memory, the derivative is stored quantized to one uint8 per
    element (with random dithering so quantization is expectation-preserving).
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s

            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.044 so it's a lower bound
            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.044
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        if x.dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        # NOTE(review): forward quantizes with floor = -0.044 but this
        # dequantizes with floor = -0.043637; the tiny mismatch slightly
        # biases the reconstructed derivative — presumably harmless, but
        # confirm it is intended.
        floor = -0.043637
        ceil = 1.2

        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
1276
+
1277
+
1278
class DoubleSwish(torch.nn.Module):
    """Module wrapper for double_swish(x) = x * sigmoid(x - 1)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x * torch.sigmoid(x - 1.0)
        # In eager mode, use the memory-efficient custom autograd version.
        return DoubleSwishFunction.apply(x)
1289
+
1290
+
1291
class Dropout2(nn.Module):
    """Just like normal dropout, except the rate `p` may be a FloatLike
    schedule, re-evaluated with float() on every forward call."""

    def __init__(self, p: FloatLike):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        rate = float(self.p)
        return torch.nn.functional.dropout(x, p=rate, training=self.training)
1299
+
1300
+
1301
class MulForDropout3(torch.autograd.Function):
    # returns (x * y * alpha) where alpha is a float and y doesn't require
    # grad and is zero-or-one.
    # Memory-efficient: only the output is saved; backward reconstructs the
    # zero-or-one mask as (ans != 0) rather than storing y.
    @staticmethod
    @custom_fwd
    def forward(ctx, x, y, alpha):
        assert not y.requires_grad
        ans = x * y * alpha
        ctx.save_for_backward(ans)
        ctx.alpha = alpha
        return ans

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad):
        (ans,) = ctx.saved_tensors
        # Gradient flows only where the mask kept the value (ans != 0).
        x_grad = ctx.alpha * ans_grad * (ans != 0)
        return x_grad, None, None
1319
+
1320
+
1321
class Dropout3(nn.Module):
    """Just like normal dropout, except (a) the rate `p` may be a FloatLike
    schedule, and (b) one chosen dimension `shared_dim` (e.g. the time axis)
    shares a single dropout mask, which makes the stored mask smaller and
    more memory-efficient."""

    def __init__(self, p: FloatLike, shared_dim: int):
        super().__init__()
        self.p = p
        self.shared_dim = shared_dim

    def forward(self, x: Tensor) -> Tensor:
        drop_p = float(self.p)
        if not self.training or drop_p == 0:
            return _no_op(x)
        keep_scale = 1.0 / (1 - drop_p)
        mask_shape = list(x.shape)
        mask_shape[self.shared_dim] = 1
        keep_mask = torch.rand(*mask_shape, device=x.device) > drop_p
        return MulForDropout3.apply(x, keep_mask, keep_scale)
1339
+
1340
+
1341
class SwooshLFunction(torch.autograd.Function):
    """
    swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035

    Computes the forward value with autograd on a detached copy so the exact
    derivative can be obtained, then stores that derivative quantized to one
    uint8 per element (with random dithering, expectation-preserving) for a
    memory-cheap backward pass.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        coeff = -0.08

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035

                if not requires_grad:
                    return y

                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                # The true derivative lies in [coeff, 1 + coeff]; widen the
                # ceiling slightly before quantizing to uint8.
                floor = coeff
                ceil = 1.0 + coeff + 0.005

                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                    grad
                )
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0

                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.get_autocast_gpu_dtype())
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.

        coeff = -0.08
        floor = coeff
        ceil = 1.0 + coeff + 0.005
        # Dequantize the stored uint8 derivative back to float.
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
1395
+
1396
+
1397
class SwooshL(torch.nn.Module):
    """Swoosh-L activation: swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035,
    computed via SwooshLFunction (memory-efficient custom backward)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation."""
        # NOTE(review): the original body contained unreachable code after this
        # unconditional return (a torch.jit branch and k2-based fast paths);
        # it has been removed.  Behavior is unchanged.
        return SwooshLFunction.apply(x)
1409
+
1410
+
1411
class SwooshLOnnx(torch.nn.Module):
    """ONNX-export-friendly Swoosh-L: same formula as SwooshL, but computed
    with logaddexp_onnx instead of a custom autograd function."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
1416
+
1417
+
1418
class SwooshRFunction(torch.autograd.Function):
    """
    swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687

    derivatives are between -0.08 and 0.92.

    Like SwooshLFunction: computes the exact derivative via autograd on a
    detached copy, then stores it quantized to one uint8 per element (with
    random dithering, expectation-preserving) for a memory-cheap backward.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad

        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

                if not requires_grad:
                    return y
                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                # Quantization range chosen to bound the true derivative.
                floor = -0.08
                ceil = 0.925

                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                    grad
                )
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0

                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.get_autocast_gpu_dtype())
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.08
        ceil = 0.925
        # Dequantize the stored uint8 derivative back to float.
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
1470
+
1471
+
1472
class SwooshR(torch.nn.Module):
    """Swoosh-R activation: swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687,
    computed via SwooshRFunction (memory-efficient custom backward)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation."""
        # NOTE(review): the original body contained unreachable code after this
        # unconditional return (an `if True:` fallback and k2-based fast
        # paths); it has been removed.  Behavior is unchanged.
        return SwooshRFunction.apply(x)
1485
+
1486
+
1487
class SwooshROnnx(torch.nn.Module):
    """ONNX-export-friendly Swoosh-R: same formula as SwooshR, but computed
    with logaddexp_onnx instead of a custom autograd function."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
1492
+
1493
+
1494
def SwooshLForward(x: Tensor):
    """Simple forward-only SwooshL: log(1 + exp(x-4)) - 0.08*x - 0.035.
    Does not redefine the backprop; used in
    ActivationDropoutAndLinearFunction."""
    shifted = x - 4.0
    softplus = (1.0 + shifted.exp()).log().to(x.dtype)
    # exp() overflows to inf for large x; there log(1 + e^t) is just t.
    softplus = torch.where(softplus == float("inf"), shifted, softplus)
    return softplus - 0.08 * x - 0.035
1501
+
1502
+
1503
def SwooshRForward(x: Tensor):
    """Simple forward-only SwooshR: log(1 + exp(x-1)) - 0.08*x - 0.313261687.
    Does not redefine the backprop; used in
    ActivationDropoutAndLinearFunction."""
    shifted = x - 1.0
    softplus = (1.0 + shifted.exp()).log().to(x.dtype)
    # exp() overflows to inf for large x; there log(1 + e^t) is just t.
    softplus = torch.where(softplus == float("inf"), shifted, softplus)
    return softplus - 0.08 * x - 0.313261687
1510
+
1511
+
1512
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
    """Fused activation (SwooshL/SwooshR via k2) + dropout + linear.  Saves
    memory by storing only the input and the (small) dropout mask, and
    recomputing the activation and its derivative in the backward pass."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        activation: str,
        dropout_p: float,
        dropout_shared_dim: Optional[int],
    ):
        if dropout_p != 0.0:
            dropout_shape = list(x.shape)
            if dropout_shared_dim is not None:
                dropout_shape[dropout_shared_dim] = 1
            # else it won't be very memory efficient.
            dropout_mask = (1.0 / (1.0 - dropout_p)) * (
                torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
            )
        else:
            dropout_mask = None

        ctx.save_for_backward(x, weight, bias, dropout_mask)

        ctx.activation = activation

        forward_activation_dict = {
            "SwooshL": k2.swoosh_l_forward,
            "SwooshR": k2.swoosh_r_forward,
        }
        # it will raise a KeyError if this fails. This will be an error.  We let it
        # propagate to the user.
        activation_func = forward_activation_dict[activation]
        x = activation_func(x)
        if dropout_mask is not None:
            x = x * dropout_mask
        x = torch.nn.functional.linear(x, weight, bias)
        return x

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        saved = ctx.saved_tensors
        (x, weight, bias, dropout_mask) = saved

        forward_and_deriv_activation_dict = {
            "SwooshL": k2.swoosh_l_forward_and_deriv,
            "SwooshR": k2.swoosh_r_forward_and_deriv,
        }
        # the following lines raise a KeyError if the activation is unrecognized.
        # This will be an error.  We let it propagate to the user.
        func = forward_and_deriv_activation_dict[ctx.activation]

        # Recompute the activation output and its derivative from the saved input.
        y, func_deriv = func(x)
        if dropout_mask is not None:
            y = y * dropout_mask
        # now compute derivative of y w.r.t. weight and bias..
        # y: (..., in_channels), ans_grad: (..., out_channels),
        (out_channels, in_channels) = weight.shape

        in_channels = y.shape[-1]
        g = ans_grad.reshape(-1, out_channels)
        weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
        y_deriv = torch.matmul(ans_grad, weight)
        bias_deriv = None if bias is None else g.sum(dim=0)
        x_deriv = y_deriv * func_deriv
        if dropout_mask is not None:
            # order versus func_deriv does not matter
            x_deriv = x_deriv * dropout_mask

        return x_deriv, weight_deriv, bias_deriv, None, None, None
1584
+
1585
+
1586
class ActivationDropoutAndLinear(torch.nn.Module):
    """
    This merges an activation function followed by dropout and then a nn.Linear module;
    it does so in a memory efficient way so that it only stores the input to the whole
    module.  If activation == SwooshL and dropout_shared_dim != None, this will be
    equivalent to:
      nn.Sequential(SwooshL(),
                    Dropout3(dropout_p, shared_dim=dropout_shared_dim),
                    ScaledLinear(in_channels, out_channels, bias=bias,
                                 initial_scale=initial_scale))
    If dropout_shared_dim is None, the dropout would be equivalent to
    Dropout2(dropout_p).  Note: Dropout3 will be more memory efficient as the dropout
    mask is smaller.

    Args:
       in_channels: number of input channels, e.g. 256
       out_channels: number of output channels, e.g. 256
       bias: if true, have a bias
       activation: the activation function, for now just support SwooshL.
       dropout_p: the dropout probability or schedule (happens after nonlinearity).
       dropout_shared_dim: the dimension, if any, across which the dropout mask is
            shared (e.g. the time dimension).  If None, this may be less memory
            efficient if there are modules before this one that cache the input
            for their backprop (e.g. Balancer or Whiten).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
        activation: str = "SwooshL",
        dropout_p: FloatLike = 0.0,
        dropout_shared_dim: Optional[int] = -1,
        initial_scale: float = 1.0,
    ):
        super().__init__()
        # create a temporary ScaledLinear that we'll steal the
        # weights and bias from
        linear = ScaledLinear(
            in_channels, out_channels, bias=bias, initial_scale=initial_scale
        )

        self.weight = linear.weight
        # register_parameter properly handles making it a parameter when
        # linear.bias is None (bias=False).
        self.register_parameter("bias", linear.bias)

        self.activation = activation
        self.dropout_p = dropout_p
        self.dropout_shared_dim = dropout_shared_dim

    def forward(self, x: Tensor):
        # NOTE(review): the original code forced this branch with `if True:`
        # (the condition was a commented-out torch.jit check), which made the
        # k2-based ActivationDropoutAndLinearFunction fallback unreachable —
        # presumably because this checkpoint is used for inference without
        # k2.  The unreachable code has been removed; behavior is unchanged.
        # Note that dropout is NOT applied on this path.
        if self.activation == "SwooshL":
            x = SwooshLForward(x)
        elif self.activation == "SwooshR":
            x = SwooshRForward(x)
        else:
            assert False, self.activation
        return torch.nn.functional.linear(x, self.weight, self.bias)
1659
+
1660
+
1661
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    """Make the last dimension of x have size num_channels, by truncating it
    or zero-padding it on the right."""
    current = x.shape[-1]
    if num_channels <= current:
        return x[..., :num_channels]
    pad_shape = list(x.shape)
    pad_shape[-1] = num_channels - current
    padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat((x, padding), dim=-1)
1669
+
1670
+
1671
def _test_whiten():
    """Self-test: Whiten should leave gradients unchanged for nearly-white
    data, and modify them when one direction strongly dominates."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        # Inject a rank-1 component whose strength is `proportion`.
        x += proportion * direction * coeffs

        x.requires_grad = True

        m = Whiten(
            1, 5.0, prob=1.0, grad_scale=0.1  # num_groups  # whitening_limit,
        )  # grad_scale

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
1695
+
1696
+
1697
def _test_balancer_sign():
    """Self-test (visual): sweep the proportion of positive values 0..1
    across channels and print how Balancer modifies sign-valued gradients."""
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        min_abs=0.0,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_sign: x = ", x)
    print("_test_balancer_sign: y grad = ", y_grad)
    print("_test_balancer_sign: x grad = ", x.grad)
1719
+
1720
+
1721
def _test_balancer_magnitude():
    """Self-test (visual): sweep per-channel magnitudes 0..1 and print how
    Balancer's min_abs/max_abs constraints modify the gradients."""
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = Balancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.2,
        max_abs=0.7,
        prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_balancer_magnitude: x = ", x)
    print("_test_balancer_magnitude: y grad = ", y_grad)
    print("_test_balancer_magnitude: x grad = ", x.grad)
1744
+
1745
+
1746
def _test_double_swish_deriv():
    """Self-test: gradcheck DoubleSwish, with the tolerance implied by its
    1-byte quantized derivative storage."""
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = DoubleSwish()

    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
1758
+
1759
+
1760
def _test_swooshl_deriv():
    """Self-test: gradcheck SwooshL, with tolerance matching its 1-byte
    quantized derivative storage."""
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = SwooshL()

    tol = 1.0 / 255.0
    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
1772
+
1773
+
1774
def _test_swooshr_deriv():
    """Self-test: gradcheck SwooshR, with tolerance matching its 1-byte
    quantized derivative storage."""
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = SwooshR()

    tol = 1.0 / 255.0
    torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
1786
+
1787
+
1788
def _test_softmax():
    """Self-test: gradients of the in-file softmax() must match those of
    Tensor.softmax on the same input."""
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)
1798
+
1799
+
1800
def _test_piecewise_linear():
    """Self-test for PiecewiseLinear: constant schedules, two-point
    interpolation, and the max / min / + combination operators."""
    p = PiecewiseLinear((0, 10.0))
    for x in [-100, 0, 100]:
        assert p(x) == 10.0
    p = PiecewiseLinear((0, 10.0), (1, 0.0))
    for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
        print("x, y = ", x, y)
        assert p(x) == y, (x, p(x), y)

    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
    x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
    pq = p.max(q)
    for x in x_vals:
        y1 = max(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p.min(q)
    for x in x_vals:
        y1 = min(p(x), q(x))
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
    pq = p + q
    for x in x_vals:
        y1 = p(x) + q(x)
        y2 = pq(x)
        assert abs(y1 - y2) < 0.001
1826
+
1827
+
1828
def _test_activation_dropout_and_linear():
    """Self-test: the fused ActivationDropoutAndLinear must match the
    equivalent nn.Sequential(activation, Dropout3, ScaledLinear) in both
    forward values and gradients (weights copied across)."""
    in_channels = 20
    out_channels = 30

    for bias in [True, False]:
        # actually we don't test for dropout_p != 0.0 because forward functions will give
        # different answers.  This is because we are using the k2 implementation of
        # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn()
        # internally, messing up the random state.
        for dropout_p in [0.0]:
            for activation in ["SwooshL", "SwooshR"]:
                m1 = nn.Sequential(
                    SwooshL() if activation == "SwooshL" else SwooshR(),
                    Dropout3(p=dropout_p, shared_dim=-1),
                    ScaledLinear(
                        in_channels, out_channels, bias=bias, initial_scale=0.5
                    ),
                )
                m2 = ActivationDropoutAndLinear(
                    in_channels,
                    out_channels,
                    bias=bias,
                    initial_scale=0.5,
                    activation=activation,
                    dropout_p=dropout_p,
                )
                with torch.no_grad():
                    m2.weight[:] = m1[2].weight
                    if bias:
                        m2.bias[:] = m1[2].bias
                # make sure forward gives same result.
                x1 = torch.randn(10, in_channels)
                x1.requires_grad = True

                # TEMP.
                assert torch.allclose(
                    SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
                )

                x2 = x1.clone().detach()
                x2.requires_grad = True
                seed = 10
                torch.manual_seed(seed)
                y1 = m1(x1)
                y_grad = torch.randn_like(y1)
                y1.backward(gradient=y_grad)
                torch.manual_seed(seed)
                y2 = m2(x2)
                y2.backward(gradient=y_grad)

                print(
                    f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
                )
                print("y1 = ", y1)
                print("y2 = ", y2)
                assert torch.allclose(y1, y2, atol=0.02)
                assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
                if bias:
                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
                print("x1.grad = ", x1.grad)
                print("x2.grad = ", x2.grad)

                def isclose(a, b):
                    # return true if cosine similarity is > 0.9.
                    return (a * b).sum() > 0.9 * (
                        (a**2).sum() * (b**2).sum()
                    ).sqrt()

                # the SwooshL() implementation has a noisy gradient due to 1-byte
                # storage of it.
                assert isclose(x1.grad, x2.grad)
1899
+
1900
+
1901
if __name__ == "__main__":
    # Run the module's self-tests, single-threaded for reproducibility.
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_piecewise_linear()
    _test_softmax()
    _test_whiten()
    _test_balancer_sign()
    _test_balancer_magnitude()
    _test_double_swish_deriv()
    _test_swooshr_deriv()
    _test_swooshl_deriv()
    _test_activation_dropout_and_linear()
subsampling.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
3
+ # Zengwei Yao)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from typing import Tuple
20
+ import warnings
21
+
22
+ import torch
23
+ from torch import Tensor, nn
24
+ from scaling import (
25
+ Balancer,
26
+ BiasNorm,
27
+ Dropout3,
28
+ FloatLike,
29
+ Optional,
30
+ ScaledConv2d,
31
+ ScaleGrad,
32
+ ScheduledFloat,
33
+ SwooshL,
34
+ SwooshR,
35
+ Whiten,
36
+ )
37
+
38
+
39
class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf

    A residual block over (N, C, H, W) inputs:
    depthwise conv -> 1x1 expansion conv -> Balancer -> SwooshL ->
    1x1 projection conv, added back to the input (bypass connection).
    During training, whole samples may skip the block ("layerdrop").

    Args:
      channels: number of input/output channels C.
      hidden_ratio: expansion factor for the hidden 1x1 conv
        (hidden_channels = channels * hidden_ratio).
      kernel_size: (time, freq) kernel of the depthwise conv; "same"
        padding is derived from it, so odd kernel sizes are expected.
      layerdrop_rate: probability of skipping the block per sample during
        training; defaults to a schedule decaying from 0.2 to 0.015.
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        # "Same" padding for the depthwise conv in both time and freq dims.
        self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            # ScheduledFloat is keyed on training batch count: starts at 0.2,
            # decays to 0.015 by batch 20000.
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        # Per-channel spatial convolution (groups == channels).
        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

        # 1x1 conv expanding to the hidden dimension.
        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        # Keeps the hidden activations' sign/magnitude statistics in range
        # (training-time regularizer from scaling.py).
        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        # 1x1 projection back to `channels`; small initial_scale so the
        # residual branch starts out close to identity.
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        # Whitening regularizer applied on the output (channel-last layout).
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor) -> Tensor:
        # In eval/scripted/traced mode there is no layerdrop: run directly.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            # Per-sample keep mask: True keeps the residual branch,
            # False zeroes it (sample effectively skips this block).
            mask = (
                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                > layerdrop_rate
            )
        else:
            mask = None
        # turns out this caching idea does not work with --world-size > 1
        # return caching_eval(self.forward_internal, x, mask)
        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
    ) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        layer_skip_mask: optional (N, 1, 1, 1) 0/1 mask; samples with 0 skip
        the residual branch (layerdrop).

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            x = x * layer_skip_mask

        # Residual connection.
        x = bypass + x
        x = self.out_balancer(x)

        # Whiten only when gradients flow (it is a training-time regularizer;
        # requires_grad is False in inference, so this is skipped there).
        if x.requires_grad:
            x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
            x = self.out_whiten(x)
            x = x.transpose(1, 3)  # (N, C, H, W)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Streaming variant: the time dimension is padded from the cache on the
        left (past context) and left unpadded on the right, so only frames
        with full context are emitted.

        Args:
          x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
          cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

        Returns:
          - The returned value has the same shape as x.
          - Updated cached_left_pad.
        """
        padding = self.padding

        # The length without right padding for depth-wise conv
        T = x.size(2) - padding[0]

        # Bypass matches the frames that will survive the un-right-padded conv.
        bypass = x[:, :, :T, :]

        # Pad left side
        assert cached_left_pad.size(2) == padding[0], (
            cached_left_pad.size(2),
            padding[0],
        )
        x = torch.cat([cached_left_pad, x], dim=2)
        # Update cached left padding: the last padding[0] frames preceding the
        # right-truncation point become the next chunk's left context.
        cached_left_pad = x[:, :, T : padding[0] + T, :]

        # depthwise_conv, applied functionally so we can pad only in the
        # frequency dim (time padding was handled via the cache above).
        x = torch.nn.functional.conv2d(
            x,
            weight=self.depthwise_conv.weight,
            bias=self.depthwise_conv.bias,
            padding=(0, padding[1]),
            groups=self.depthwise_conv.groups,
        )
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        # NOTE: the balancer/whitener are intentionally not applied here —
        # they are training-time regularizers and streaming is inference-only.
        x = bypass + x
        return x, cached_left_pad
192
+
193
+
194
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >=7, in_channels >=7
          out_channels
            Output dim. The output shape is (N, (T-3)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer1_channels:
            Number of channels in layer2
          bottleneck:
            bottleneck dimension for 1d squeeze-excite
        """
        assert in_channels >= 7
        super().__init__()

        # The ScaleGrad module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # exceeding the range of fp16 when using automatic mixed precision (amp)
        # training. (The second one is necessary to stop its bias from getting
        # a too-large gradient).

        # Three conv stages; the strides give the overall (T-7)//2 time
        # subsampling and ((idim-1)//2 - 1)//2 frequency reduction.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        # (in_channels-3)//4: number of frequency bins remaining after the
        # two stride-2 reductions in the frequency dimension.
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        # Flattens (channels x freq) into the output embedding dimension.
        self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
        # use a larger than normal grad_scale on this whitening module; there is
        # only one such module, so there is not a concern about adding together
        # many copies of this extra gradient term.
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
        # gradients.
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
        b, c, t, f = x.size()

        # Move time before channels, then flatten (channels, freq) per frame.
        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            # Suppress spurious warnings from integer division on tensors.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2
            assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Streaming variant: the ConvNeXt layer consumes its right context from
        the per-chunk cache instead of zero padding, so the output is 3 frames
        shorter than in forward().

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          cached_left_pad:
            Cached left context for the ConvNeXt module, as returned by
            get_init_states() or by a previous call.

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

        # T' = (T-7)//2
        x = self.conv(x)

        # T' = (T-7)//2-3
        x, cached_left_pad = self.convnext.streaming_forward(
            x, cached_left_pad=cached_left_pad
        )

        # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, T', out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, T', odim)
        # NOTE: out_whiten/dropout are skipped here on purpose — they are
        # training-time modules and streaming is inference-only.
        x = self.out_norm(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert self.convnext.padding[0] == 3
            # The ConvNeXt module needs 3 frames of right padding after subsampling
            x_lens = (x_lens - 7) // 2 - 3
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # The ConvNeXt module needs 3 frames of right padding after subsampling
                assert self.convnext.padding[0] == 3
                x_lens = (x_lens - 7) // 2 - 3

            assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.padding[0]
        freq = self.out_width
        channels = self.layer3_channels
        # Zero left-context: the model sees silence before the first chunk.
        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
            device
        )

        return cached_embed_left_pad
utilities.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
class ZipformerConfig:
    """Attribute-style view over a dict of Zipformer-600M hyperparameters.

    All parameters live in the ``_config`` dict. Attribute reads, writes and
    deletes are forwarded to that dict, so ``cfg.feature_dim`` reads
    ``cfg._config["feature_dim"]`` and ``cfg.new_key = v`` creates a new
    entry. Comma-separated strings mirror the command-line format used by
    the Zipformer training scripts.
    """

    def __init__(self):
        # 用 _config 存储所有参数
        # (all parameters are stored in _config; assigning "_config" goes
        # through the special-case branch of __setattr__ below)
        self._config = {
            "feature_dim": 128,
            "pos_dim": 48,
            "output_downsampling_factor": 2,
            "downsampling_factor": "1,2,4,8,4,2",
            "num_encoder_layers": "2,2,3,4,3,2",
            "feedforward_dim": "512,768,1024,1536,1024,768",
            "encoder_dim": "192,256,448,768,448,192",
            "encoder_unmasked_dim": "192,192,256,256,256,192",
            "cnn_module_kernel": "31,31,15,15,15,31",
            "num_heads": "4,4,4,8,4,4",
            "causal": True,
        }

    def __getattr__(self, key):
        # __getattr__ is only invoked when normal attribute lookup fails.
        # Guard against "_config" itself: copy/pickle (and any attribute
        # probe before __init__ has run) can reach here on an instance that
        # has no _config yet, and without this check the self._config access
        # below would re-enter __getattr__ and recurse infinitely.
        if key == "_config":
            raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")
        if key in self._config:
            return self._config[key]
        raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        if key == "_config":
            # Store the backing dict itself as a real instance attribute.
            super().__setattr__(key, value)
        else:
            # Everything else becomes an entry of the backing dict.
            self._config[key] = value

    def __delattr__(self, key):
        if key in self._config:
            del self._config[key]
        else:
            raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def to_dict(self):
        """Return a shallow copy of the parameter dict."""
        return dict(self._config)

    def __repr__(self):
        return f"ZipformerConfig({self._config})"
43
+
44
+
45
+
46
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

      - yes, true, t, y, 1, to represent True
      - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa

    Raises:
      argparse.ArgumentTypeError: if the string matches neither set.
    """
    # Real bools pass straight through (argparse may hand us the default).
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
63
+
64
+
65
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    batch = lengths.size(0)
    # Mask width: at least max_len, but never shorter than the longest sequence.
    width = max(max_len, lengths.max())
    # positions[b, t] == t; a position is padding iff t >= lengths[b].
    positions = torch.arange(0, width, device=lengths.device)
    positions = positions.unsqueeze(0).expand(batch, width)
    return positions >= lengths.unsqueeze(-1)
zipformer.py ADDED
@@ -0,0 +1,2469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
3
+ # Zengwei Yao)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import copy
20
+ import math
21
+ import warnings
22
+ from typing import List, Optional, Tuple, Union
23
+ import logging
24
+ import torch
25
+ import random
26
+ from scaling import (
27
+ Balancer,
28
+ BiasNorm,
29
+ Dropout2,
30
+ ChunkCausalDepthwiseConv1d,
31
+ ActivationDropoutAndLinear,
32
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
33
+ Whiten,
34
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
35
+ penalize_abs_values_gt,
36
+ softmax,
37
+ ScheduledFloat,
38
+ FloatLike,
39
+ limit_param_value,
40
+ convert_num_channels,
41
+ )
42
+ from torch import Tensor, nn
43
+
44
+
45
class EncoderInterface(nn.Module):
    """Abstract base class for speech encoders.

    Concrete encoders (e.g. Zipformer2 below) must override :meth:`forward`.
    """

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of padded feature sequences.

        Args:
          x:
            A tensor of shape (batch_size, input_seq_len, num_features)
            containing the input features.
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `x` before padding.

        Returns:
          A tuple of two tensors:
            - encoder_out, of shape (batch_size, out_seq_len, output_dim),
              containing unnormalized probabilities, i.e. the output of a
              linear layer.
            - encoder_out_lens, of shape (batch_size,), containing the number
              of frames in `encoder_out` before padding.

        Raises:
          NotImplementedError: always; subclasses must implement this method.
        """
        raise NotImplementedError("Please implement it in a subclass")
66
+
67
+
68
+ class Zipformer2(EncoderInterface):
69
+ """
70
+ Args:
71
+
72
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
73
+ as downsampling_factor if they are single ints or one-element tuples. The length of
74
+ downsampling_factor defines the number of stacks.
75
+
76
+ output_downsampling_factor (int): how much to downsample at the output. Note:
77
+ we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
78
+ You should probably leave this at 2.
79
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
80
+ Note: this is in addition to the downsampling factor of 2 that is applied in
81
+ the frontend (self.encoder_embed).
82
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
83
+ encoder stack.
84
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
85
+ encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
86
+ the encoder stacks for purposes of per-frame dropout (recommend 256 for
87
+ now).
88
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
89
+ head: per stack, if a tuple..
90
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
91
+ attention head
92
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
93
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
94
+ Must be at least 4.
95
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
96
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
97
+
98
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
99
+ e.g. 128.
100
+
101
+ dropout (float): dropout rate
102
+ warmup_batches (float): number of batches to warm up over; this controls
103
+ dropout of encoder layers.
104
+ causal (bool): if True, support chunkwise causal convolution. This should
105
+ not hurt WER as no modeling power is lost, but the convolution modules will be
106
+ slightly slower and use more memory. Enables use of the chunk_size and
107
+ left_context_chunks options in forward(), which simulates streaming
108
+ decoding.
109
+ chunk_size: (list of int): only set this to other than [-1] if causal;
110
+ the chunk size will be randomly chosen from this list. -1 means no chunking.
111
+ left_context_frames: (list of int): determines the number of left-
112
+ context chunks for causal training; will be rounded to a number of
113
+ chunks. Must not be less than cnn_module_kernel (after factoring in
114
+ rounding and downsampling); an error will be thrown if this is violated.
115
+ """
116
+
117
    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
    ) -> None:
        # See the class docstring for parameter semantics. All
        # "int or Tuple[int]" arguments are broadcast to one value per
        # encoder stack (len(downsampling_factor) stacks).
        super(Zipformer2, self).__init__()

        if dropout is None:
            # Default dropout decays from 0.3 to 0.1 over the first
            # 20000 training batches.
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim
        )  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        # The unmasked (never-dropped) part of each stack's embedding must
        # fit inside the stack's embedding dimension.
        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
            # Prototype layer for stack i; Zipformer2Encoder clones it
            # num_encoder_layers[i] times.
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something. Then we start to warm up the other encoders.
            # Each stack's warmup window is staggered: stack i warms up during
            # the (i+1)-th slice of warmup_batches.
            encoder = Zipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            # Stacks with downsampling_factor > 1 are wrapped so the sequence
            # is downsampled on entry and upsampled on exit.
            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
                    dropout=dropout,
                )

            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        # Optional final temporal downsampling of the concatenated output
        # (operates at the widest embedding dimension among the stacks).
        if output_downsampling_factor >= 2:
            self.downsample_output = SimpleDownsample(
                max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
            )
        else:
            self.downsample_output = None
223
+
224
+
225
    def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
        """
        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
        randomized feature masks, one per encoder.
        On e.g. 15% of frames, these masks will zero out all enocder dims larger than
        some supplied number, e.g. >256, so in effect on those frames we are using
        a smaller encoer dim.

        We generate the random masks at this level because we want the 2 masks to 'agree'
        all the way up the encoder stack. This will mean that the 1st mask will have
        mask values repeated self.zipformer_subsampling_factor times.

        Args:
          x: the embeddings (needed for the shape and dtype and device), of shape
             (1, batch_size, encoder_dims0)
        """
        num_encoders = len(self.encoder_dim)
        # No feature masking at inference time.
        if not self.training:
            return [1.0] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dim[0] == _encoder_dims0, (
            self.encoder_dim[0],
            _encoder_dims0,
        )

        feature_mask_dropout_prob = 0.125

        # mask1 shape: (1, batch_size, 1)
        # Per-sequence (not per-frame) keep decision; 1.0 keeps, 0.0 drops.
        mask1 = (
            torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
        ).to(x.dtype)

        # mask2 has additional sequences masked, about twice the number.
        # (mask2 is a subset of mask1, so the two masks "agree".)
        mask2 = torch.logical_and(
            mask1,
            (
                torch.rand(1, batch_size, 1, device=x.device)
                > feature_mask_dropout_prob
            ).to(x.dtype),
        )

        # dim: (1, batch_size, 2)
        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(
                1, batch_size, channels, dtype=x.dtype, device=x.device
            )
            # Channels [0, u1) are never masked; [u1, u2) follow mask1 and
            # [u2, channels) follow the stricter mask2.
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2

            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks
286
+
287
def get_chunk_info(self) -> Tuple[int, int]:
    """
    Returns (chunk_size, left_context_chunks) for chunk-wise (causal) attention.

    In non-causal mode both are -1 (meaning: no chunking, unlimited context).
    In training mode a chunk size / left-context size is sampled at random from
    the configured lists; under torch.jit scripting/tracing the lists must each
    contain exactly one value, which is used deterministically.
    """
    if not self.causal:
        return -1, -1

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        assert len(self.chunk_size) == 1, self.chunk_size
        chunk_size = self.chunk_size[0]
    else:
        chunk_size = random.choice(self.chunk_size)

    if chunk_size == -1:
        # Full-context attention for this batch; left context is unlimited.
        left_context_chunks = -1
    else:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert len(self.left_context_frames) == 1, self.left_context_frames
            left_context_frames = self.left_context_frames[0]
        else:
            left_context_frames = random.choice(self.left_context_frames)
        # Note: in Python, -1 // n == -1 for n > 0
        left_context_chunks = left_context_frames // chunk_size
        if left_context_chunks == 0:
            # Always keep at least one chunk of left context.
            left_context_chunks = 1

    return chunk_size, left_context_chunks
315
def forward(
    self,
    x: Tensor,
    x_lens: Tensor,
    src_key_padding_mask: Optional[Tensor] = None,
    return_middle_out: bool = False,
) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, List[Tensor]]]:
    """
    Run the stack of encoder modules on the (already embedded) input.

    Args:
        x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
        x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
        src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        return_middle_out:
            If True, additionally return the list of per-encoder-stack outputs.
    Returns:
        Return a tuple containing 2 tensors (plus the per-stack outputs when
        return_middle_out is True):
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
    """
    outputs = []
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Feature masking is a training-only regularizer; disabled for export.
        feature_masks = [1.0] * len(self.encoder_dim)
    else:
        feature_masks = self.get_feature_masks(x)

    chunk_size, left_context_chunks = self.get_chunk_info()

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Not support exporting a model for simulating streaming decoding
        attn_mask = None
    else:
        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

    for i, module in enumerate(self.encoders):
        ds = self.downsampling_factor[i]
        # Pad/truncate channels so x matches this stack's encoder dim.
        x = convert_num_channels(x, self.encoder_dim[i])

        x = module(
            x,
            chunk_size=chunk_size,
            feature_mask=feature_masks[i],
            src_key_padding_mask=(
                None
                if src_key_padding_mask is None
                else src_key_padding_mask[..., ::ds]
            ),
            attn_mask=attn_mask,
        )
        outputs.append(x)

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1]. Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)

    if self.output_downsampling_factor >= 2:
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2
    else:
        lengths = x_lens
    if return_middle_out:
        return x, lengths, outputs
    else:
        return x, lengths
392
def _get_attn_mask(
    self, x: Tensor, chunk_size: int, left_context_chunks: int
) -> Optional[Tensor]:
    """
    Return None if chunk_size == -1, else return attention mask of shape
    (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
    means a masked position.
    Args:
        x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
        chunk_size: chunk size; must be divisible by every downsampling factor.
        left_context_chunks: number of left-context chunks each position may
            attend to; if negative, the left context is unlimited.
    """
    if chunk_size <= 0:
        return None
    assert all(chunk_size % d == 0 for d in self.downsampling_factor)
    if left_context_chunks >= 0:
        num_encoders = len(self.encoder_dim)
        # The left context must cover each stack's causal-conv receptive field.
        assert all(
            chunk_size * left_context_chunks
            >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
            for i in range(num_encoders)
        )
    else:
        # Effectively unlimited left context.
        left_context_chunks = 1000000

    seq_len = x.shape[0]

    # t is frame index, shape (seq_len,)
    t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
    # c is chunk index for each frame, shape (seq_len,)
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        c = t // chunk_size
    else:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            c = t // chunk_size
    src_c = c
    tgt_c = c.unsqueeze(-1)

    # Mask future chunks and chunks beyond the allowed left context.
    attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
    if __name__ == "__main__":
        # Debug printout when this module is run as a script.
        logging.info(f"attn_mask = {attn_mask}")
    return attn_mask
435
+ def _get_full_dim_output(self, outputs: List[Tensor]):
436
+ num_encoders = len(self.encoder_dim)
437
+ assert len(outputs) == num_encoders
438
+ output_dim = max(self.encoder_dim)
439
+ output_pieces = [outputs[-1]]
440
+ cur_dim = self.encoder_dim[-1]
441
+ for i in range(num_encoders - 2, -1, -1):
442
+ d = self.encoder_dim[i]
443
+ if d > cur_dim:
444
+ this_output = outputs[i]
445
+ output_pieces.append(this_output[..., cur_dim:d])
446
+ cur_dim = d
447
+ assert cur_dim == output_dim
448
+ return torch.cat(output_pieces, dim=-1)
449
+
450
def streaming_forward(
    self,
    x: Tensor,
    x_lens: Tensor,
    states: List[Tensor],
    src_key_padding_mask: Tensor,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    """
    Streaming (chunk-by-chunk) version of forward(): consumes cached
    per-layer state and returns updated state alongside the output.

    Args:
        x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
        x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
        states: list of cached tensors of all encoder layers. For layer-i,
            states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
            cached_conv1, cached_conv2).
        src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
    Returns:
        Return a tuple containing 2 tensors:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
            - updated states
    """
    outputs = []
    new_states = []
    # Offset (in layers) of the current stack's slice within `states`.
    layer_offset = 0

    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        ds = self.downsampling_factor[i]
        # Pad/truncate channels so x matches this stack's encoder dim.
        x = convert_num_channels(x, self.encoder_dim[i])

        x, new_layer_states = module.streaming_forward(
            x,
            states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
            left_context_len=self.left_context_frames[0] // ds,
            src_key_padding_mask=src_key_padding_mask[..., ::ds],
        )
        layer_offset += num_layers
        outputs.append(x)
        new_states += new_layer_states

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1]. Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)

    if self.output_downsampling_factor >= 2:
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2
    else:
        lengths = x_lens

    return x, lengths, new_states
517
@torch.jit.export
def get_init_states(
    self,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[Tensor]:
    """Get initial (all-zero) states for streaming decoding.

    A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
    """
    states = []
    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        embed_dim = self.encoder_dim[i]
        ds = self.downsampling_factor[i]
        num_heads = self.num_heads[i]
        key_dim = self.query_head_dim[i] * num_heads
        value_dim = self.value_head_dim[i] * num_heads
        # Left context measured in this stack's (downsampled) frames.
        downsample_left = self.left_context_frames[0] // ds
        nonlin_attn_head_dim = 3 * embed_dim // 4
        conv_left_pad = self.cnn_module_kernel[i] // 2
        for layer in range(num_layers):
            cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
                device
            )
            cached_nonlin_attn = torch.zeros(
                1, batch_size, downsample_left, nonlin_attn_head_dim
            ).to(device)
            cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
                device
            )
            cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
                device
            )
            cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                device
            )
            cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
                device
            )
            states += [
                cached_key,
                cached_nonlin_attn,
                cached_val1,
                cached_val2,
                cached_conv1,
                cached_conv2,
            ]

    return states
570
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
    """Whitening-limit schedule: x at batch 0, rising linearly to ratio * x
    by batch 20000; the default (out-of-schedule) value is x."""
    start = (0.0, x)
    end = (20000.0, ratio * x)
    return ScheduledFloat(start, end, default=x)
574
def _balancer_schedule(min_prob: float) -> ScheduledFloat:
    """Balancer-probability schedule: 0.4 at batch 0, decaying linearly to
    min_prob by batch 8000."""
    points = ((0.0, 0.4), (8000.0, min_prob))
    return ScheduledFloat(*points)
578
class Zipformer2EncoderLayer(nn.Module):
    """
    One Zipformer2 encoder layer: three feed-forward modules, two attention
    modules sharing one set of attention weights, a nonlinear-attention module
    and two convolution modules, combined with learnable bypass connections.

    Args:
        embed_dim: the number of expected features in the input (required).
        pos_dim: dimension of the relative positional embedding.
        num_heads: the number of heads in the multiheadattention models (required).
        query_head_dim: per-head dimension of the attention query/key.
        pos_head_dim: per-head dimension of the positional projection.
        value_head_dim: per-head dimension of the attention value.
        feedforward_dim: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module.
        causal: if True, convolution modules use only left context (streaming).
        attention_skip_rate, conv_skip_rate, ff2_skip_rate, ff3_skip_rate,
        bypass_skip_rate: FloatLike schedules giving the probability of
            stochastically skipping the corresponding sub-module per sequence
            during training (regularization; all inactive in eval mode).
        const_attention_rate: FloatLike schedule giving the probability, during
            training, of replacing the nonlin-attention weights with a constant
            averaging-over-time pattern.

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        conv_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        const_attention_rate: FloatLike = ScheduledFloat(
            (0.0, 0.25), (4000.0, 0.025), default=0
        ),
        ff2_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        ff3_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        bypass_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.5), (4000.0, 0.02), default=0
        ),
    ) -> None:
        super(Zipformer2EncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # self.bypass implements layer skipping as well as bypass; see its default values.
        self.bypass = BypassModule(
            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
        )
        # bypass_mid is bypass used in the middle of the layer.
        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)

        # skip probability for dynamic modules (meaning: anything but feedforward).
        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
        # an additional skip probability that applies to ConvModule to stop it from
        # contributing too much early on.
        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        self.const_attention_rate = copy.deepcopy(const_attention_rate)

        # Shared attention weights, reused by self_attn1, self_attn2 and
        # (first head only) by nonlin_attention.
        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            dropout=0.0,
        )

        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)

        # Three feed-forward modules of increasing size (3/4, 1, 5/4 of
        # feedforward_dim).
        self.feed_forward1 = FeedforwardModule(
            embed_dim, (feedforward_dim * 3) // 4, dropout
        )

        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)

        self.feed_forward3 = FeedforwardModule(
            embed_dim, (feedforward_dim * 5) // 4, dropout
        )

        self.nonlin_attention = NonlinAttention(
            embed_dim, hidden_channels=3 * embed_dim // 4
        )

        self.conv_module1 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        self.conv_module2 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        # TODO: remove it
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        self.norm = BiasNorm(embed_dim)

        self.balancer1 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.2,
            max_abs=4.0,
        )

        # balancer for output of NonlinAttentionModule
        self.balancer_na = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of feedforward2, prevent it from staying too
        # small. give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
        self.balancer_ff2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
            max_abs=2.0,
            prob=0.05,
        )

        self.balancer_ff3 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
            max_abs=4.0,
            prob=0.05,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.balancer2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.1,
            max_abs=4.0,
        )

    def get_sequence_dropout_mask(
        self, x: Tensor, dropout_rate: float
    ) -> Optional[Tensor]:
        """Return a per-sequence keep mask of shape (batch_size, 1), or None
        when sequence dropout is inactive (eval mode, zero rate, or JIT)."""
        if (
            dropout_rate == 0.0
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        ):
            return None
        batch_size = x.shape[1]
        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
        return mask

    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
        """
        Apply sequence-level dropout to x.
        x shape: (seq_len, batch_size, embed_dim)
        """
        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
        if dropout_mask is None:
            return x
        else:
            return x * dropout_mask

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.

        Returns:
            A tensor which has the same shape as src
        """
        src_orig = src

        # dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        self_attn_dropout_mask = self.get_sequence_dropout_mask(
            src, attention_skip_rate
        )

        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        # FIX: was `not self.training`, which made this regularizer dead code
        # (the ScheduledFloat's default is 0 in eval, so the branch could never
        # fire) and would have made eval nondeterministic otherwise.  The
        # const-attention trick is a training-time regularizer.
        elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant. The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
            # only need the mask, can just use the 1st one and expand later
            selected_attn_weights = selected_attn_weights[0:1]
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        self_attn = self.self_attn1(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff2_skip_rate = 0.0
        else:
            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
        )

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn = self.self_attn2(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff3_skip_rate = 0.0
        else:
            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
        )

        src = self.balancer1(src)
        src = self.norm(src)

        src = self.bypass(src_orig, src)

        src = self.balancer2(src)
        src = self.whiten(src)

        return src

    def streaming_forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        cached_key: Tensor,
        cached_nonlin_attn: Tensor,
        cached_val1: Tensor,
        cached_val2: Tensor,
        cached_conv1: Tensor,
        cached_conv2: Tensor,
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Pass the input through the encoder layer in streaming forward mode.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
                (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
            cached_key: cached attention key tensor of left context,
                of shape (left_context_len, batch_size, key_dim)
            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
                (num_heads, batch_size, left_context_len, head_dim)
            cached_val1: cached left context for the first attention module,
                of shape (left_context_len, batch_size, value_dim)
            cached_val2: cached left context for the second attention module,
                of shape (left_context_len, batch_size, value_dim)
            cached_conv1: cached left context for the first convolution module,
                of shape (batch_size, channels, left_pad)
            cached_conv2: cached left context for the second convolution module,
                of shape (batch_size, channels, left_pad)
            left_context_len: number of left context frames.
            src_key_padding_mask: the mask for padding, of shape
                (batch_size, left_context_len + seq_len); True means masked position.
                May be None.

        Returns:
            - x, with the same shape as src
            - updated cached_key
            - updated cached_nonlin_attn
            - updated cached_val1
            - updated cached_val2
            - updated cached_conv1
            - updated cached_conv2
        """
        src_orig = src

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
            src,
            pos_emb=pos_emb,
            cached_key=cached_key,
            left_context_len=left_context_len,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
            src,
            attn_weights[0:1],
            cached_x=cached_nonlin_attn,
            left_context_len=left_context_len,
        )
        src = src + na

        self_attn, cached_val1 = self.self_attn1.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val1,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        # The padding mask includes the cached left context; the conv module
        # only needs the mask for the current chunk.
        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
            src,
            cache=cached_conv1,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward2(src)

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn, cached_val2 = self.self_attn2.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val2,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
            src,
            cache=cached_conv2,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward3(src)

        src = self.norm(src)

        src = self.bypass(src_orig, src)

        return (
            src,
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        )
1027
class Zipformer2Encoder(nn.Module):
    r"""Zipformer2Encoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding
        dropout: dropout rate (passed to the positional-encoding module's config;
            per-layer dropout lives inside encoder_layer).
        warmup_begin, warmup_end: training-batch indices between which the
            per-layer layerdrop rate anneals from initial_layerdrop_rate to
            final_layerdrop_rate (staggered so each layer warms up in turn).

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, dropout_rate=0.15, length_factor=1.0
        )

        # All layers are deep copies of the prototype layer.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end

        # Stagger each layer's layerdrop warmup over an equal share of the
        # [warmup_begin, warmup_end] interval.
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            cur_end = cur_begin + delta
            self.layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            cur_begin = cur_end

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """
        pos_emb = self.encoder_pos(src)
        output = src

        # Feature mask is applied both before and after the layer stack
        # (training-time regularization; skipped under JIT).
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        for i, mod in enumerate(self.layers):
            output = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )

            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                output = output * feature_mask

        return output

    def streaming_forward(
        self,
        src: Tensor,
        states: List[Tensor],
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
                (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: Number of left context frames.
            src_key_padding_mask: the mask for padding, of shape
                (batch_size, left_context_len + seq_len); True means masked position.
                May be None.

        Returns:
            - output, a Tensor with the same shape as src.
            - updated states
        """
        pos_emb = self.encoder_pos(src, left_context_len)
        output = src

        new_states = []
        for i, mod in enumerate(self.layers):
            # Unpack this layer's 6-tensor slice of the cached state.
            (
                cached_key,
                cached_nonlin_attn,
                cached_val1,
                cached_val2,
                cached_conv1,
                cached_conv2,
            ) = states[i * 6 : (i + 1) * 6]
            (
                output,
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ) = mod.streaming_forward(
                output,
                pos_emb,
                cached_key=cached_key,
                cached_nonlin_attn=cached_nonlin_attn,
                cached_val1=cached_val1,
                cached_val2=cached_val2,
                cached_conv1=cached_conv1,
                cached_conv2=cached_conv2,
                left_context_len=left_context_len,
                src_key_padding_mask=src_key_padding_mask,
            )
            new_states += [
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ]

        return output, new_states
1186
class BypassModule(nn.Module):
    """
    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
    layer-skipping. The bypass is limited during early stages of training to be close to
    "straight-through", i.e. to not do the bypass operation much initially, in order to
    force all the modules to learn something.
    """

    def __init__(
        self,
        embed_dim: int,
        skip_rate: FloatLike = 0.0,
        straight_through_rate: FloatLike = 0.0,
        scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
        scale_max: FloatLike = 1.0,
    ):
        super().__init__()
        # Learnable per-channel scale on the non-residual term.
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        # deepcopy so each module owns its own (possibly scheduled) rates.
        self.skip_rate = copy.deepcopy(skip_rate)
        self.straight_through_rate = copy.deepcopy(straight_through_rate)
        self.scale_min = copy.deepcopy(scale_min)
        self.scale_max = copy.deepcopy(scale_max)

    def _get_bypass_scale(self, batch_size: int):
        # returns bypass-scale of shape (num_channels,),
        # or (batch_size, num_channels,). This is actually the
        # scale on the non-residual term, so 0 correponds to bypassing
        # this module.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            # Deterministic path: just the (unclamped) learned scale.
            return self.bypass_scale
        else:
            # Clamp the learned scale into [scale_min, scale_max] (scheduled).
            ans = limit_param_value(
                self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
            )
            skip_rate = float(self.skip_rate)
            if skip_rate != 0.0:
                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
                ans = ans * mask
                # now ans is of shape (batch_size, num_channels), and is zero for sequences
                # on which we have randomly chosen to do layer-skipping.
            straight_through_rate = float(self.straight_through_rate)
            if straight_through_rate != 0.0:
                # Force scale up to 1.0 ("straight through") on a random
                # subset of sequences.
                mask = (
                    torch.rand((batch_size, 1), device=ans.device)
                    < straight_through_rate
                )
                ans = torch.maximum(ans, mask.to(ans.dtype))
            return ans

    def forward(self, src_orig: Tensor, src: Tensor):
        """
        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
        Returns: something with the same shape as src and src_orig
        """
        # Interpolate between the residual input and the module output.
        bypass_scale = self._get_bypass_scale(src.shape[1])
        return src_orig + (src - src_orig) * bypass_scale
1244
class DownsampledZipformer2Encoder(nn.Module):
    r"""
    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the origin input, so that the output has the same shape as the input.
    """

    def __init__(
        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
    ):
        super(DownsampledZipformer2Encoder, self).__init__()
        self.downsample_factor = downsample
        # Learned weighted-average downsampling (see SimpleDownsample).
        self.downsample = SimpleDownsample(dim, downsample, dropout)
        self.num_layers = encoder.num_layers
        self.encoder = encoder
        # Frame-repeating upsampling (see SimpleUpsample).
        self.upsample = SimpleUpsample(dim, downsample)
        # Combines the upsampled output with the original-rate input;
        # straight_through_rate=0 disables the straight-through randomization here.
        self.out_combiner = BypassModule(dim, straight_through_rate=0)

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Downsample, go through encoder, upsample.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: attention chunk size in frames at the input rate; -1 means
                no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                 masked position.  May be None.

        Returns: a Tensor with the same shape as src.
        """
        src_orig = src
        src = self.downsample(src)
        ds = self.downsample_factor
        if attn_mask is not None:
            # Subsample the mask to match the downsampled frame rate.
            # NOTE(review): this strides the first two dims; correct for a 2-D
            # (tgt, src) mask, but for a 3-D (batch, tgt, src) mask it would
            # stride the batch dim — confirm which shape callers pass here.
            attn_mask = attn_mask[::ds, ::ds]

        src = self.encoder(
            src,
            # chunk_size is expressed in input-rate frames; convert to the
            # downsampled rate (-1 // ds stays -1, preserving "no chunking").
            chunk_size=chunk_size // ds,
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
        src = src[: src_orig.shape[0]]

        return self.out_combiner(src_orig, src)

    def streaming_forward(
        self,
        src: Tensor,
        states: List[Tensor],
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        r"""Downsample, go through encoder, upsample, in streaming forward mode.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: Number of left context frames.
            src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
              True means masked position. May be None.

        Returns:
            - output, a Tensor with the same shape as src.
            - updated states
        """
        src_orig = src
        src = self.downsample(src)

        # State management is delegated to the wrapped encoder; this wrapper only
        # changes the frame rate around it.
        src, new_states = self.encoder.streaming_forward(
            src,
            states=states,
            left_context_len=left_context_len,
            src_key_padding_mask=src_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
        src = src[: src_orig.shape[0]]

        return self.out_combiner(src_orig, src), new_states
1338
+
1339
+
1340
class SimpleDownsample(torch.nn.Module):
    """
    Temporal downsampling by a learned weighted average: each output frame is a
    softmax-weighted combination of `downsample` consecutive input frames.
    """

    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
        super(SimpleDownsample, self).__init__()

        # One learnable logit per position within each group of `downsample`
        # frames; their softmax gives the averaging weights.
        self.bias = nn.Parameter(torch.zeros(downsample))

        self.name = None  # will be set from training code
        self.dropout = copy.deepcopy(dropout)

        self.downsample = downsample

    def forward(self, src: Tensor) -> Tensor:
        """
        src: (seq_len, batch_size, in_channels)
        Returns a tensor of shape
        ( (seq_len+downsample-1)//downsample, batch_size, in_channels)
        """
        (num_frames, num_seqs, num_chan) = src.shape
        factor = self.downsample
        out_frames = (num_frames + factor - 1) // factor

        # Right-pad to an exact multiple of `factor` by repeating the last frame.
        num_pad = out_frames * factor - num_frames
        last_frame = src[num_frames - 1 :]
        src = torch.cat((src, last_frame.expand(num_pad, num_seqs, num_chan)), dim=0)
        assert src.shape[0] == out_frames * factor

        # Group consecutive frames, then collapse each group by a weighted sum.
        grouped = src.reshape(out_frames, factor, num_seqs, num_chan)
        weights = self.bias.softmax(dim=0).reshape(factor, 1, 1)
        return (grouped * weights).sum(dim=1)
1382
+
1383
+
1384
class SimpleUpsample(torch.nn.Module):
    """
    A very simple form of temporal upsampling that just repeats every input
    frame `upsample` times consecutively.
    """

    def __init__(self, num_channels: int, upsample: int):
        super(SimpleUpsample, self).__init__()
        self.upsample = upsample

    def forward(self, src: Tensor) -> Tensor:
        """
        src: (seq_len, batch_size, num_channels)
        Returns a tensor of shape
        ( (seq_len*upsample), batch_size, num_channels)
        """
        factor = self.upsample
        (num_frames, num_seqs, num_chan) = src.shape
        # Insert a repeat axis after time, then fold it back into the time axis
        # so each frame appears `factor` times in a row.
        expanded = src.unsqueeze(1).expand(num_frames, factor, num_seqs, num_chan)
        return expanded.reshape(num_frames * factor, num_seqs, num_chan)
1405
+
1406
+
1407
class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module.  This version is "compact" meaning it is able to encode
    the important information about the relative position in a relatively small number of dimensions.
    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
    make very little difference to the embedding.   Such differences were potentially important
    when encoding absolute position, but not important when encoding relative position because there
    is now no need to compare two large offsets with each other.

    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
    using the atan() function, before doing the fourier transform of that fixed interval.  The
    atan() function would compress the "long tails" too small,
    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
    function to compress large offsets to a smaller range before applying atan().
    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)


    Args:
        embed_dim: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
           less weight to small differences of offset near the origin.
    """

    def __init__(
        self,
        embed_dim: int,
        dropout_rate: FloatLike,
        max_len: int = 1000,
        length_factor: float = 1.0,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        assert embed_dim % 2 == 0
        self.dropout = Dropout2(dropout_rate)
        # Cached positional-encoding table; lazily (re)built by extend_pe()
        # whenever a longer sequence is seen.
        self.pe = None
        assert length_factor >= 1.0
        self.length_factor = length_factor
        # Pre-build the table for max_len frames (the tensor only supplies
        # length/dtype/device; its values are ignored).
        self.extend_pe(torch.tensor(0.0).expand(max_len))

    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
        """Reset the positional encodings."""
        T = x.size(0) + left_context_len

        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(0) >= T * 2 - 1:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)

        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
        # for small time offsets but less resolution for large time offsets.
        compression_length = self.embed_dim**0.5
        # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity;
        # but it does so more slowly than T for large absolute values of T.
        # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
        # is important.
        x_compressed = (
            compression_length
            * x.sign()
            * ((x.abs() + compression_length).log() - math.log(compression_length))
        )

        # if self.length_factor == 1.0, then length_scale is chosen so that the
        # FFT can exactly separate points close to the origin (T == 0).  So this
        # part of the formulation is not really heuristic.
        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

        # note for machine implementations: if atan is not available, we can use:
        #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
        x_atan = (x_compressed / length_scale).atan()  # results between -pi and pi

        # Interleave cos/sin over the even/odd channels.
        cosines = (x_atan * freqs).cos()
        sines = (x_atan * freqs).sin()

        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
        pe[:, 0::2] = cosines
        pe[:, 1::2] = sines
        # Overwrite the last (highest-frequency sine) channel with a constant,
        # so one embedding dimension acts as a bias term.
        pe[:, -1] = 1.0  # for bias.

        self.pe = pe.to(dtype=x.dtype)

    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
        """Create positional encoding.

        Args:
            x (Tensor): Input tensor (time, batch, `*`).
            left_context_len: (int): Length of cached left context.

        Returns:
            positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
        """
        self.extend_pe(x, left_context_len)
        x_size_left = x.size(0) + left_context_len
        # length of positive side: x.size(0) + left_context_len
        # length of negative side: x.size(0)
        # Slice the symmetric table around its center to get exactly the
        # offsets needed for this (context, chunk) configuration.
        pos_emb = self.pe[
            self.pe.size(0) // 2
            - x_size_left
            + 1 : self.pe.size(0) // 2  # noqa E203
            + x.size(0),
            :,
        ]
        pos_emb = pos_emb.unsqueeze(0)
        return self.dropout(pos_emb)
1523
+
1524
+
1525
class RelPositionMultiheadAttentionWeights(nn.Module):
    r"""Module that computes multi-head attention weights with relative position encoding.
    Various other modules consume the resulting attention weights: see, for example, the
    SimpleAttention module which allows you to compute conventional attention.

    This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context",
    we have to write up the differences.


    Args:
           embed_dim: number of channels at the input to this module, e.g. 256
             pos_dim: dimension of the positional encoding vectors, e.g. 128.
           num_heads:  number of heads to compute weights for, e.g. 8
     query_head_dim: dimension of the query (and key), per head.  e.g. 24.
       pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
            dropout: dropout probability for attn_output_weights. Default: 0.0.
       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
                     any given call to forward(), in training time.
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        dropout: float = 0.0,
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
        self.dropout = dropout
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.name = None  # will be overwritten in training code; for diagnostics.

        key_head_dim = query_head_dim
        # One projection produces queries, keys and position-queries together.
        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads

        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5 that has been used in previous forms of attention,
        # dividing it between the query and key.   Note: this module is intended
        # to be used with the ScaledAdam optimizer; with most other optimizers,
        # it would be necessary to apply the scaling factor in the forward function.
        self.in_proj = ScaledLinear(
            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
        )

        self.whiten_keys = Whiten(
            num_groups=num_heads,
            whitening_limit=_whitening_schedule(3.0),
            prob=(0.025, 0.25),
            grad_scale=0.025,
        )

        # add a balancer for the keys that runs with very small probability, and
        # tries to enforce that all dimensions have mean around zero.  The
        # weights produced by this module are invariant to adding a constant to
        # the keys, so the derivative of the bias is mathematically zero; but
        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
        # bias because the small numerical roundoff tends to have a non-random
        # sign.  This module is intended to prevent that.  Use a very small
        # probability; that should be sufficient to fix the problem.
        self.balance_keys = Balancer(
            key_head_dim * num_heads,
            channel_dim=-1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=0.0,
            max_abs=100.0,
            prob=0.025,
        )

        # linear transformation for positional encoding.
        self.linear_pos = ScaledLinear(
            pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
        )

        # the following are for diagnostics only, see --print-diagnostics option
        self.copy_pos_query = Identity()
        self.copy_query = Identity()

    def forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
               are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
               saying which positions are allowed to attend to which other positions.
        Returns:
           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
        """
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
        num_heads = self.num_heads

        seq_len, batch_size, _ = x.shape

        query_dim = query_head_dim * num_heads

        # self-attention: split the fused projection into query, key and
        # position-query parts.
        q = x[..., 0:query_dim]
        k = x[..., query_dim : 2 * query_dim]
        # p is the position-encoding query
        p = x[..., 2 * query_dim :]
        assert p.shape[-1] == num_heads * pos_head_dim

        q = self.copy_query(q)  # for diagnostics only, does nothing.
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

        # time1 refers to target, time2 refers to source.
        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

        # Content-content scores (the q·k term).
        attn_scores = torch.matmul(q, k)

        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # We can't put random.random() in the same line
            use_pos_scores = True
        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
            use_pos_scores = True

        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
            seq_len2 = 2 * seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                2, 0, 3, 1
            )
            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            #  [where seq_len2 represents relative position.]
            pos_scores = torch.matmul(p, pos_emb)
            # the following .as_strided() expression converts the last axis of pos_scores from relative
            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
            # not, but let this code define which way round it is supposed to be.
            if torch.jit.is_tracing():
                # Tracing cannot handle as_strided; emulate it with gather.
                (num_heads, batch_size, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
                cols = torch.arange(seq_len)
                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
            else:
                pos_scores = pos_scores.as_strided(
                    (num_heads, batch_size, seq_len, seq_len),
                    (
                        pos_scores.stride(0),
                        pos_scores.stride(1),
                        pos_scores.stride(2) - pos_scores.stride(3),
                        pos_scores.stride(3),
                    ),
                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
                )

            attn_scores = attn_scores + pos_scores

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < 0.1:
            # This is a harder way of limiting the attention scores to not be
            # too large.  It incurs a penalty if any of them has an absolute
            # value greater than 50.0.  this should be outside the normal range
            # of the attention scores.  We use this mechanism instead of, say,
            # something added to the loss function involving the entropy,
            # because once the entropy gets very small gradients through the
            # softmax can become very small, and we'd get zero derivatives.  The
            # choices of 1.0e-04 as the scale on the penalty makes this
            # mechanism vulnerable to the absolute scale of the loss function,
            # but we view this as a failsafe to avoid "implausible" parameter
            # values rather than a regularization method that should be active
            # under normal circumstances.
            attn_scores = penalize_abs_values_gt(
                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
            )

        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool
            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
            # all scores zero.  It's important that this be large enough that exp(-1000)
            # is exactly zero, for reasons related to const_attention_rate, it
            # compares the final weights with zero.
            attn_scores = attn_scores.masked_fill(attn_mask, -1000)

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                batch_size,
                seq_len,
            ), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
                -1000,
            )

        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop by, if we are in
        # automatic mixed precision mode (amp / autocast), by only storing the
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif random.random() < 0.001 and not self.training:
            self._print_attn_entropy(attn_weights)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        return attn_weights

    def streaming_forward(
        self,
        x: Tensor,
        pos_emb: Tensor,
        cached_key: Tensor,
        left_context_len: int,
        key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
            cached_key: cached attention key tensor of left context,
              of shape (left_context_len, batch_size, key_dim)
            left_context_len: number of left context frames.
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
              are True in this mask will be ignored as sources in the attention weighting.

        Returns:
           - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
             interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
           - updated cached attention key tensor of left context.
        """
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
        num_heads = self.num_heads

        seq_len, batch_size, _ = x.shape

        query_dim = query_head_dim * num_heads

        # self-attention: split the fused projection as in forward().
        q = x[..., 0:query_dim]
        k = x[..., query_dim : 2 * query_dim]
        # p is the position-encoding query
        p = x[..., 2 * query_dim :]
        assert p.shape[-1] == num_heads * pos_head_dim

        # Pad cached left contexts
        assert cached_key.shape[0] == left_context_len, (
            cached_key.shape[0],
            left_context_len,
        )
        k = torch.cat([cached_key, k], dim=0)
        # Update cached left contexts
        cached_key = k[-left_context_len:, ...]

        # The length of key
        k_len = k.shape[0]

        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
        k = k.reshape(k_len, batch_size, num_heads, query_head_dim)

        # time1 refers to target, time2 refers to source.
        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

        attn_scores = torch.matmul(q, k)

        # In streaming mode the positional scores are always used (no random
        # skipping, unlike forward()).
        pos_emb = self.linear_pos(pos_emb)
        seq_len2 = 2 * seq_len - 1 + left_context_len
        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
            2, 0, 3, 1
        )
        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

        # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
        #  [where seq_len2 represents relative position.]
        pos_scores = torch.matmul(p, pos_emb)

        if torch.jit.is_tracing():
            # Tracing cannot handle as_strided; emulate it with gather.
            (num_heads, batch_size, time1, n) = pos_scores.shape
            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
            cols = torch.arange(k_len)
            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
            indexes = rows + cols
            pos_scores = pos_scores.reshape(-1, n)
            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
        # the following .as_strided() expression converts the last axis of pos_scores from relative
        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
        # not, but let this code define which way round it is supposed to be.
        else:
            pos_scores = pos_scores.as_strided(
                (num_heads, batch_size, seq_len, k_len),
                (
                    pos_scores.stride(0),
                    pos_scores.stride(1),
                    pos_scores.stride(2) - pos_scores.stride(3),
                    pos_scores.stride(3),
                ),
                storage_offset=pos_scores.stride(3) * (seq_len - 1),
            )

        attn_scores = attn_scores + pos_scores

        assert attn_scores.shape == (
            num_heads,
            batch_size,
            seq_len,
            k_len,
        ), attn_scores.shape

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
                -1000,
            )

        attn_weights = attn_scores.softmax(dim=-1)

        return attn_weights, cached_key

    def _print_attn_entropy(self, attn_weights: Tensor):
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        # Logs the per-head entropy of the attention distributions; diagnostics only.
        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = attn_weights.to(torch.float32)
                attn_weights_entropy = (
                    -((attn_weights + 1.0e-20).log() * attn_weights)
                    .sum(dim=-1)
                    .mean(dim=(1, 2))
                )
                logging.info(
                    f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
                )
1893
+
1894
+
1895
class SelfAttention(nn.Module):
    """
    The simplest possible attention module.  This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
          embed_dim: the input and output embedding dimension
          num_heads: the number of attention heads
          value_head_dim: the value dimension per head
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
        # Projects the input to per-head values.
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)

        self.out_proj = ScaledLinear(
            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """
        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
         attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
          with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
          attn_weights.sum(dim=-1) == 1.
        Returns:
           a tensor with the same shape as x.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, seq_len, value_head_dim)

        # Merge the heads back into a single channel dimension.
        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_val: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
              with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
              attn_weights.sum(dim=-1) == 1.
            cached_val: cached attention value tensor of left context,
              of shape (left_context_len, batch_size, value_dim)
            left_context_len: number of left context frames.

        Returns:
           - attention weighted output, a tensor with the same shape as x.
           - updated cached attention value tensor of left context.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        # The source axis covers the cached left context plus the new chunk.
        seq_len2 = seq_len + left_context_len
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # Pad cached left contexts
        assert cached_val.shape[0] == left_context_len, (
            cached_val.shape[0],
            left_context_len,
        )
        x = torch.cat([cached_val, x], dim=0)
        # Update cached left contexts
        cached_val = x[-left_context_len:, ...]

        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)

        return x, cached_val
2020
+
2021
+
2022
class FeedforwardModule(nn.Module):
    """Pointwise feed-forward block used in the Zipformer2 encoder layers.

    Args:
      embed_dim: input/output embedding dimension.
      feedforward_dim: hidden dimension of the feed-forward layer.
      dropout: dropout rate (possibly a schedule) applied inside out_proj.
    """

    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
        super(FeedforwardModule, self).__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)

        # Keeps the hidden activations in a healthy range during training.
        self.hidden_balancer = Balancer(
            feedforward_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        # SwooshL activation, then dropout, then linear, fused in one module.
        # dropout_shared_dim=0 shares the dropout mask along the time axis.
        self.out_proj = ActivationDropoutAndLinear(
            feedforward_dim,
            embed_dim,
            activation="SwooshL",
            dropout_p=dropout,
            dropout_shared_dim=0,
            bias=True,
            initial_scale=0.1,
        )

        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor):
        """Apply the feed-forward block; returns a tensor of the same shape as x."""
        hidden = self.hidden_balancer(self.in_proj(x))
        # out_proj contains SwooshL activation, then dropout, then linear.
        out = self.out_proj(hidden)
        return self.out_whiten(out)
2063
+
2064
+
2065
class NonlinAttention(nn.Module):
    """Attention-weighted nonlinearity; like ConvolutionModule but refactored so
    that multiplication by attention weights (borrowed from the attention module)
    replaces the actual convolution.  The second nonlinearity — the one after the
    attention mechanism — has been removed.

    Args:
      channels: the number of input/output channels.
      hidden_channels: size of the hidden representation.
    """

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
    ) -> None:
        super().__init__()

        self.hidden_channels = hidden_channels

        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)

        # Balancer that goes before the tanh gate.  min_abs is fairly large
        # because well-trained instances of this module were observed to have
        # pre-gate abs-values starting from about 3, while poorly-trained
        # instances had smaller values.
        self.balancer = Balancer(
            hidden_channels,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
            min_abs=0.5,
            max_abs=5.0,
        )
        self.tanh = nn.Tanh()

        # Identity modules kept only so diagnostics can hook activations.
        self.identity1 = Identity()
        self.identity2 = Identity()
        self.identity3 = Identity()

        self.out_proj = ScaledLinear(
            hidden_channels, channels, bias=True, initial_scale=0.05
        )

        self.whiten1 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.whiten2 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """
        Args:
          x: a Tensor of shape (seq_len, batch_size, num_channels)
          attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        Returns:
          a Tensor with the same shape as x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        # Split into gate (s), signal (x) and post-attention multiplier (y).
        s, x, y = x.chunk(3, dim=-1)

        # The gate goes through a balancer and tanh.
        s = self.balancer(s)
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = self.whiten1(x)
        x = x * s
        x = self.identity1(x)  # diagnostics only, it's the identity.

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        # -> (num_heads, batch_size, seq_len, head_dim)
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        x = torch.matmul(attn_weights, x)
        # still (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        y = self.identity2(y)
        x = x * y
        x = self.identity3(x)

        x = self.out_proj(x)
        x = self.whiten2(x)
        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_x: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          x: a Tensor of shape (seq_len, batch_size, num_channels)
          attn_weights: a Tensor of shape
            (num_heads, batch_size, seq_len, left_context_len + seq_len)
          cached_x: left context, a Tensor of shape
            (num_heads, batch_size, left_context_len, head_dim)
          left_context_len: number of left context frames.
        Returns:
          - a Tensor with the same shape as x
          - updated left context with same shape as cached_x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        # Split into gate (s), signal (x) and post-attention multiplier (y).
        s, x, y = x.chunk(3, dim=-1)

        # Gate: tanh only (balancers/whiteners are training-time no-ops).
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = x * s

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (
            num_heads,
            batch_size,
            seq_len,
            left_context_len + seq_len,
        )

        # -> (num_heads, batch_size, seq_len, head_dim)
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)

        # Prepend cached left context along the time axis.
        assert cached_x.shape[2] == left_context_len, (
            cached_x.shape[2],
            left_context_len,
        )
        x_pad = torch.cat([cached_x, x], dim=2)
        # The updated cache is the trailing left_context_len frames.
        cached_x = x_pad[:, :, -left_context_len:, :]

        x = torch.matmul(attn_weights, x_pad)
        # now (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        x = x * y

        x = self.out_proj(x)
        return x, cached_x
2228
+
2229
+
2230
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Zipformer2 model.

    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py

    Args:
      channels (int): The number of channels of conv layers.
      kernel_size (int): Kernel size of conv layers (must be odd).
      causal (bool): Whether to use the chunk-causal depthwise convolution.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        causal: bool,
    ) -> None:
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # An odd kernel size gives symmetric 'SAME' padding.
        assert (kernel_size - 1) % 2 == 0

        bottleneck_dim = channels
        self.causal = causal

        # The gradients on in_proj are a little noisy, likely to do with the
        # sigmoid in glu.
        self.in_proj = nn.Linear(
            channels,
            2 * bottleneck_dim,
        )

        # After in_proj we put x through a gated linear unit (nn.functional.glu).
        # For most layers the normal rms value of channels of x seems to be in
        # the range 1 to 4, but sometimes, for layer 0, the rms ends up being
        # very large, between 50 and 100 for different channels.  This causes
        # very peaky and sparse derivatives for the sigmoid gating function,
        # which tends to make the loss function not learn effectively.
        # Constraining the rms values to a reasonable range via max_abs=10.0
        # puts the module in a better position to start learning, i.e. to
        # latch onto the correct range.
        self.balancer1 = Balancer(
            bottleneck_dim,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
            max_positive=1.0,
            min_abs=1.5,
            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
        )

        self.activation1 = Identity()  # for diagnostics

        self.sigmoid = nn.Sigmoid()

        self.activation2 = Identity()  # for diagnostics

        assert kernel_size % 2 == 1

        self.depthwise_conv = (
            ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
            if causal
            else nn.Conv1d(
                in_channels=bottleneck_dim,
                out_channels=bottleneck_dim,
                groups=bottleneck_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )
        )

        self.balancer2 = Balancer(
            bottleneck_dim,
            channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
            max_positive=1.0,
            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
            max_abs=10.0,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.out_proj = ActivationDropoutAndLinear(
            bottleneck_dim,
            channels,
            activation="SwooshR",
            dropout_p=0.0,
            initial_scale=0.05,
        )

    def forward(
        self,
        x: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
        chunk_size: int = -1,
    ) -> Tensor:
        """Compute convolution module.

        Args:
          x: Input tensor (#time, batch, channels).
          src_key_padding_mask: the mask for the src keys per batch (optional):
            (batch, #time), contains True in masked positions.
          chunk_size: if >= 0, simulate chunkwise streaming decoding
            (requires causal=True).

        Returns:
          Tensor: Output tensor (#time, batch, channels).
        """
        x = self.in_proj(x)  # (time, batch, 2*channels)

        # Gated linear unit: first half is the signal, second half the gate.
        x, s = x.chunk(2, dim=-1)
        s = self.balancer1(s)
        s = self.sigmoid(s)
        x = self.activation1(x)  # identity.
        x = x * s
        x = self.activation2(x)  # identity.
        # (time, batch, channels)

        # Exchange the temporal dimension and the feature dimension.
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        # Zero out padding frames so the convolution cannot leak across them.
        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        if (
            not torch.jit.is_scripting()
            and not torch.jit.is_tracing()
            and chunk_size >= 0
        ):
            # Exporting a model for simulated streaming decoding is not supported.
            assert (
                self.causal
            ), "Must initialize model with causal=True if you use chunk_size"
            x = self.depthwise_conv(x, chunk_size=chunk_size)
        else:
            x = self.depthwise_conv(x)

        x = self.balancer2(x)
        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.whiten(x)  # (time, batch, channels)
        x = self.out_proj(x)  # (time, batch, channels)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cache: Tensor,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Compute convolution module in streaming forward mode.

        Args:
          x: Input tensor (#time, batch, channels).
          cache: cached left context for depthwise_conv of shape
            (#batch, channels, left_pad)
          src_key_padding_mask: the mask for the src keys per batch (optional):
            (batch, #time), contains True in masked positions.

        Returns:
          - Output tensor (#time, batch, channels).
          - Updated cache (#batch, channels, left_pad)
        """
        x = self.in_proj(x)  # (time, batch, 2*channels)

        # Gated linear unit (balancers/whiteners are skipped at inference).
        x, s = x.chunk(2, dim=2)
        s = self.sigmoid(s)
        x = x * s
        # (time, batch, channels)

        # Exchange the temporal dimension and the feature dimension.
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)

        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.out_proj(x)  # (time, batch, channels)

        return x, cache
2424
+
2425
+
2426
class ScalarMultiply(nn.Module):
    """Multiplies its input by a fixed scalar `scale`."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return x scaled by self.scale."""
        return self.scale * x
2433
+
2434
+
2435
def _test_zipformer_main(causal: bool = False):
    """Smoke-test Zipformer2: a training-mode forward+backward pass and an
    eval-mode forward pass, just to make sure they run.

    Args:
      causal: whether to construct the model in causal (streaming) mode.
    """
    # Fix: batch_size/seq_len were assigned twice (before and after model
    # construction); the duplicates were dead code and have been removed.
    batch_size = 5
    seq_len = 20

    c = Zipformer2(
        encoder_dim=(64, 96),
        encoder_unmasked_dim=(48, 64),
        num_heads=(4, 4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,),
    )
    # Training-mode forward pass and backward pass.
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f[0].sum().backward()
    # Eval-mode forward pass.
    c.eval()
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings
2462
+
2463
+
2464
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    # Keep the test deterministic / lightweight: single-threaded torch.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    # Exercise both the non-causal and causal configurations.
    for causal in (False, True):
        _test_zipformer_main(causal)