init
Browse files- 1284-1180-0027.flac +0 -0
- envirnoment.txt +2 -0
- inference_600m.py +244 -0
- inference_600m.sh +16 -0
- inference_600m_streaming_forward.py +504 -0
- iter-400000-avg-4.pt +3 -0
- model.py +799 -0
- scaling.py +1913 -0
- subsampling.py +406 -0
- utilities.py +90 -0
- zipformer.py +2469 -0
1284-1180-0027.flac
ADDED
|
Binary file (72.5 kB). View file
|
|
|
envirnoment.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch=2.1.1
|
| 2 |
+
lhotse=1.28.0
|
inference_600m.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
from model import MultiKDModel
|
| 5 |
+
from scaling import ScheduledFloat
|
| 6 |
+
from subsampling import Conv2dSubsampling
|
| 7 |
+
from zipformer import Zipformer2
|
| 8 |
+
|
| 9 |
+
from lhotse import Fbank, FbankConfig
|
| 10 |
+
import torchaudio
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
LOG_EPS = math.log(1e-10)
|
| 15 |
+
|
| 16 |
+
class ZipformerConfig:
    """Attribute-style configuration container for the Zipformer encoder.

    All parameters live in an internal ``_config`` dict; attribute access,
    assignment and deletion are forwarded to that dict so the object can be
    used like a simple namespace while still supporting dict export via
    :meth:`to_dict`.
    """

    def __init__(self):
        # Store all parameters in _config; __getattr__/__setattr__ proxy to it.
        self._config = {
            "feature_dim": 128,
            "pos_dim": 48,
            "output_downsampling_factor": 2,
            "downsampling_factor": "1,2,4,8,4,2",
            "num_encoder_layers": "2,2,3,4,3,2",
            "feedforward_dim": "512,768,1024,1536,1024,768",
            "encoder_dim": "192,256,448,768,448,192",
            "encoder_unmasked_dim": "192,192,256,256,256,192",
            "cnn_module_kernel": "31,31,15,15,15,31",
            "num_heads": "4,4,4,8,4,4",
            "causal": True,
        }

    def __getattr__(self, key):
        # Guard against infinite recursion when _config itself has not been
        # set yet (e.g. during unpickling, before __init__ has run):
        # without this, `self._config` below would re-enter __getattr__.
        if key == "_config":
            raise AttributeError(key)
        if key in self._config:
            return self._config[key]
        raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        if key == "_config":
            # The backing dict itself goes through normal attribute storage.
            super().__setattr__(key, value)
        else:
            self._config[key] = value

    def __delattr__(self, key):
        if key in self._config:
            del self._config[key]
        else:
            raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def to_dict(self):
        """Return a shallow copy of the configuration as a plain dict."""
        return dict(self._config)

    def __repr__(self):
        return f"ZipformerConfig({self._config})"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def str2bool(v):
    """argparse-friendly boolean parser.

    Accepts an actual bool unchanged, or (case-insensitively) one of
    yes/true/t/y/1 for True and no/false/f/n/0 for False.

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa

    Raises:
        argparse.ArgumentTypeError: for any unrecognised string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 74 |
+
|
| 75 |
+
def get_parser():
    """Build the command-line argument parser for offline inference."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-version",
        type=str,
        default="600m_uniform_out_ds1",
    )
    parser.add_argument(
        "--causal",
        type=str2bool,
        default=False,
        help="If True, use causal version of model.",
    )
    parser.add_argument(
        "--chunk-size",
        type=str,
        default="16,32,64,-1",
        help=(
            "Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
            " Must be just -1 if --causal=False"
        ),
    )
    parser.add_argument(
        "--left-context-frames",
        type=str,
        default="64,128,256,-1",
        help=(
            "Maximum left-contexts for causal training, measured in frames which will "
            "be converted to a number of chunks. If splitting into chunks, "
            "chunk left-context frames will be chosen randomly from this list; else not relevant."
        ),
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--audio",
        type=str,
        required=True,
        help="The path to the audio",
    )

    return parser
|
| 124 |
+
|
| 125 |
+
def _to_int_tuple(s: str):
|
| 126 |
+
return tuple(map(int, s.split(",")))
|
| 127 |
+
|
| 128 |
+
def get_encoder_embed(params) -> nn.Module:
    """Build the Conv2dSubsampling front-end that feeds the encoder.

    Maps 128-bin fbank features to the first encoder stack's dimension,
    with dropout controlled by a ScheduledFloat schedule.
    """
    out_dim = _to_int_tuple(params.encoder_dim)[0]
    dropout_schedule = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
    return Conv2dSubsampling(
        in_channels=128,
        out_channels=out_dim,
        dropout=dropout_schedule,
    )
|
| 135 |
+
|
| 136 |
+
def get_encoder_model(params) -> nn.Module:
    """Instantiate the Zipformer2 encoder from the given config object.

    Per-stack hyper-parameters are given as comma-separated strings on
    `params` and converted to int tuples here.
    """
    ints = _to_int_tuple  # local alias: every per-stack option is a CSV string
    return Zipformer2(
        output_downsampling_factor=params.output_downsampling_factor,
        downsampling_factor=ints(params.downsampling_factor),
        num_encoder_layers=ints(params.num_encoder_layers),
        encoder_dim=ints(params.encoder_dim),
        encoder_unmasked_dim=ints(params.encoder_unmasked_dim),
        query_head_dim=ints("32"),
        pos_head_dim=ints("4"),
        value_head_dim=ints("12"),
        pos_dim=params.pos_dim,
        num_heads=ints(params.num_heads),
        feedforward_dim=ints(params.feedforward_dim),
        cnn_module_kernel=ints(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
        chunk_size=ints(params.chunk_size),
        left_context_frames=ints(params.left_context_frames),
    )
|
| 157 |
+
|
| 158 |
+
def get_params(args):
    """Build the model configuration for the requested ``--model-version``.

    Starts from the ZipformerConfig defaults, copies the streaming options
    from the parsed args, then overrides the architecture hyper-parameters
    for the 600M models.  The two supported versions are identical except
    for the output downsampling factor.

    Raises:
        ValueError: if ``args.model_version`` is not recognised.
    """
    params = ZipformerConfig()
    params.chunk_size = args.chunk_size
    params.left_context_frames = args.left_context_frames

    model_version = args.model_version
    # The only difference between the two supported 600M variants is the
    # temporal downsampling applied at the encoder output.
    output_ds_by_version = {
        "600m_uniform_out_ds1": 1,
        "600m_uniform_out_ds2": 2,
    }
    if model_version not in output_ds_by_version:
        raise ValueError(f"Unknown model version: {model_version!r}")

    params.output_downsampling_factor = output_ds_by_version[model_version]
    params.downsampling_factor = "1,2,4,8,4,2,1"
    params.num_encoder_layers = "1,2,3,4,1,1,1"
    params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
    params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
    params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
    params.cnn_module_kernel = "31,31,15,15,15,31,31"
    params.num_heads = "8,8,8,8,8,8,8"
    return params
|
| 185 |
+
|
| 186 |
+
def get_model(model_version) -> nn.Module:
    """Construct the MultiKD encoder model.

    NOTE(review): despite its name, ``model_version`` is the full parsed
    argparse namespace — it is forwarded to get_params, which also reads
    chunk_size / left_context_frames from it.  Confirm before renaming.
    """
    params = get_params(model_version)
    embed = get_encoder_embed(params)
    enc = get_encoder_model(params)
    print(params)

    return MultiKDModel(
        encoder_embed=embed,
        encoder=enc,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        num_codebooks=0,
    )
|
| 202 |
+
|
| 203 |
+
def main(args):
    """Run offline (non-streaming) encoder inference on a single audio file.

    Loads the checkpoint given by ``--ckpt-path``, extracts 128-bin fbank
    features from ``--audio`` (must be 16 kHz) and prints the encoder output.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build and load the model.  map_location="cpu" lets checkpoints that
    # were saved on GPU load on a CPU-only machine; the state dict is moved
    # to `device` by the model.to(device) above.
    model = get_model(args)
    model.to(device)
    info = model.load_state_dict(
        torch.load(args.ckpt_path, map_location="cpu")["model"], strict=False
    )
    print(info)
    model.eval()

    # fbank extractor (128 mel bins, matching the model's feature_dim)
    extractor = Fbank(FbankConfig(num_mel_bins=128))

    # Load audio; the checkpoint was trained on 16 kHz input.
    audio, fs = torchaudio.load(args.audio)
    assert fs == 16000
    audios = audio.squeeze()
    feature = [extractor.extract(audios, sampling_rate=fs)]
    feature_lens = [f.size(0) for f in feature]

    feature = torch.nn.utils.rnn.pad_sequence(
        feature, batch_first=True, padding_value=LOG_EPS
    ).to(device)
    feature_lens = torch.tensor(feature_lens, device=device)

    # Batch inference; no_grad avoids building the autograd graph.
    with torch.no_grad():
        encoder_out, encoder_out_lens = model.forward_encoder(
            feature,
            feature_lens,
        )
    print(encoder_out)
    print(encoder_out.shape)
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
    # Parse command-line arguments and run offline inference.
    main(get_parser().parse_args())
|
inference_600m.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash

# Run offline inference with the 600M causal Zipformer encoder on a sample
# utterance.  Expects iter-400000-avg-4.pt and 1284-1180-0027.flac in the
# current directory.

# Make project modules importable by inference_600m.py.
export PYTHONPATH=./../../../:$PYTHONPATH

model_version=600m_uniform_out_ds1
causal=1                  # use the causal (streaming-capable) model variant
left_context_frames=256   # max left context, in encoder frames
chunk_size=8              # chunk size at the encoder frame rate (50 Hz)

python inference_600m.py \
    --model-version $model_version \
    --ckpt-path iter-400000-avg-4.pt \
    --causal $causal \
    --left-context-frames $left_context_frames \
    --chunk-size $chunk_size \
    --audio 1284-1180-0027.flac
|
inference_600m_streaming_forward.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from model import MultiKDModel
|
| 6 |
+
from scaling import ScheduledFloat
|
| 7 |
+
from subsampling import Conv2dSubsampling
|
| 8 |
+
from zipformer import Zipformer2
|
| 9 |
+
|
| 10 |
+
from lhotse import Fbank, FbankConfig
|
| 11 |
+
import torchaudio
|
| 12 |
+
import torch
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
from utilities import make_pad_mask, str2bool, ZipformerConfig
|
| 17 |
+
|
| 18 |
+
LOG_EPS = math.log(1e-10)
|
| 19 |
+
|
| 20 |
+
def get_parser():
    """Build the command-line argument parser for streaming inference."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-version",
        type=str,
        default="600m_uniform_out_ds1",
    )
    parser.add_argument(
        "--causal",
        type=str2bool,
        default=False,
        help="If True, use causal version of model.",
    )
    parser.add_argument(
        "--chunk-size",
        type=str,
        default="16,32,64,-1",
        help=(
            "Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
            " Must be just -1 if --causal=False"
        ),
    )
    parser.add_argument(
        "--left-context-frames",
        type=str,
        default="64,128,256,-1",
        help=(
            "Maximum left-contexts for causal training, measured in frames which will "
            "be converted to a number of chunks. If splitting into chunks, "
            "chunk left-context frames will be chosen randomly from this list; else not relevant."
        ),
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--audio",
        type=str,
        required=True,
        help="The path to the audio",
    )

    return parser
|
| 69 |
+
|
| 70 |
+
def _to_int_tuple(s: str):
|
| 71 |
+
return tuple(map(int, s.split(",")))
|
| 72 |
+
|
| 73 |
+
def get_encoder_embed(params) -> nn.Module:
    """Build the Conv2dSubsampling front-end that feeds the encoder.

    Maps 128-bin fbank features to the first encoder stack's dimension,
    with dropout controlled by a ScheduledFloat schedule.
    """
    out_dim = _to_int_tuple(params.encoder_dim)[0]
    dropout_schedule = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
    return Conv2dSubsampling(
        in_channels=128,
        out_channels=out_dim,
        dropout=dropout_schedule,
    )
|
| 80 |
+
|
| 81 |
+
def get_encoder_model(params) -> nn.Module:
    """Instantiate the Zipformer2 encoder from the given config object.

    Per-stack hyper-parameters are given as comma-separated strings on
    `params` and converted to int tuples here.
    """
    ints = _to_int_tuple  # local alias: every per-stack option is a CSV string
    return Zipformer2(
        output_downsampling_factor=params.output_downsampling_factor,
        downsampling_factor=ints(params.downsampling_factor),
        num_encoder_layers=ints(params.num_encoder_layers),
        encoder_dim=ints(params.encoder_dim),
        encoder_unmasked_dim=ints(params.encoder_unmasked_dim),
        query_head_dim=ints("32"),
        pos_head_dim=ints("4"),
        value_head_dim=ints("12"),
        pos_dim=params.pos_dim,
        num_heads=ints(params.num_heads),
        feedforward_dim=ints(params.feedforward_dim),
        cnn_module_kernel=ints(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
        chunk_size=ints(params.chunk_size),
        left_context_frames=ints(params.left_context_frames),
    )
|
| 102 |
+
|
| 103 |
+
def get_params(args):
    """Build the model configuration for the requested ``--model-version``.

    Starts from the ZipformerConfig defaults, copies the streaming options
    from the parsed args, then overrides the architecture hyper-parameters
    for the 600M models.  The two supported versions are identical except
    for the output downsampling factor.

    Raises:
        ValueError: if ``args.model_version`` is not recognised.
    """
    params = ZipformerConfig()
    params.chunk_size = args.chunk_size
    params.left_context_frames = args.left_context_frames

    model_version = args.model_version
    # The only difference between the two supported 600M variants is the
    # temporal downsampling applied at the encoder output.
    output_ds_by_version = {
        "600m_uniform_out_ds1": 1,
        "600m_uniform_out_ds2": 2,
    }
    if model_version not in output_ds_by_version:
        raise ValueError(f"Unknown model version: {model_version!r}")

    params.output_downsampling_factor = output_ds_by_version[model_version]
    params.downsampling_factor = "1,2,4,8,4,2,1"
    params.num_encoder_layers = "1,2,3,4,1,1,1"
    params.feedforward_dim = "3840,3840,3840,3840,3840,3840,3840"
    params.encoder_dim = "1280,1280,1280,1280,1280,1280,1280"
    params.encoder_unmasked_dim = "768,768,768,768,768,768,768"
    params.cnn_module_kernel = "31,31,15,15,15,31,31"
    params.num_heads = "8,8,8,8,8,8,8"
    return params
|
| 130 |
+
|
| 131 |
+
def get_model(model_version) -> nn.Module:
    """Construct the MultiKD encoder model.

    NOTE(review): despite its name, ``model_version`` is the full parsed
    argparse namespace — it is forwarded to get_params, which also reads
    chunk_size / left_context_frames from it.  Confirm before renaming.
    """
    params = get_params(model_version)
    embed = get_encoder_embed(params)
    enc = get_encoder_model(params)
    print(params)

    return MultiKDModel(
        encoder_embed=embed,
        encoder=enc,
        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        num_codebooks=0,
    )
|
| 147 |
+
|
| 148 |
+
def get_init_states(
    model: nn.Module,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
    """
    Returns a list of cached tensors of all encoder layers. For layer-i,
    states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
    cached_val2, cached_conv1, cached_conv2).
    states[-2] is the cached left padding for the ConvNeXt module,
    of shape (batch_size, num_channels, left_pad, num_freqs).
    states[-1] is processed_lens of shape (batch,), which records the number
    of processed frames (at 50hz frame rate, after encoder_embed) for each
    sample in the batch.
    """
    # Per-layer caches from the encoder itself.
    states = model.encoder.get_init_states(batch_size, device)

    # Left-padding cache for the convolutional front-end.
    states.append(model.encoder_embed.get_init_states(batch_size, device))

    # Frame counter starts at zero for every utterance.
    states.append(torch.zeros(batch_size, dtype=torch.int32, device=device))

    return states
|
| 170 |
+
|
| 171 |
+
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Stack per-utterance zipformer states into a single batched state.

    Args:
      state_list:
        Element n is the state list for utterance n: six cached tensors per
        encoder layer (cached_key, cached_nonlin_attn, cached_val1,
        cached_val2, cached_conv1, cached_conv2), then the cached left
        padding of the ConvNeXt module of shape
        (batch_size, num_channels, left_pad, num_freqs), and finally
        processed_lens of shape (batch,) — the number of processed frames
        (at 50hz frame rate, after encoder_embed) for each sample.

    Note:
      It is the inverse of :func:`unstack_states`.
    """
    batch_size = len(state_list)
    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
    tot_num_layers = (len(state_list[0]) - 2) // 6

    # Concatenation dim for each of the six per-layer caches: the attention
    # caches (key, nonlin_attn, val1, val2) carry batch on dim 1, the two
    # conv caches carry batch on dim 0.
    cat_dims = (1, 1, 1, 1, 0, 0)

    batch_states = []
    for layer in range(tot_num_layers):
        base = layer * 6
        for slot, dim in enumerate(cat_dims):
            batch_states.append(
                torch.cat(
                    [state_list[i][base + slot] for i in range(batch_size)],
                    dim=dim,
                )
            )

    # Cached left padding of the conv front-end: batch is dim 0.
    batch_states.append(
        torch.cat([state_list[i][-2] for i in range(batch_size)], dim=0)
    )

    # processed_lens: one entry per utterance.
    batch_states.append(
        torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
    )

    return batch_states
|
| 240 |
+
|
| 241 |
+
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
    """Split a batched zipformer state into per-utterance state lists.

    Note:
      It is the inverse of :func:`stack_states`.

    Args:
      batch_states: A list of cached tensors of all encoder layers. For
        layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn,
        cached_val1, cached_val2, cached_conv1, cached_conv2).
        batch_states[-2] is the cached left padding for the ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs).
        batch_states[-1] is processed_lens of shape (batch,), recording the
        number of processed frames (at 50hz frame rate, after encoder_embed)
        for each sample in the batch.

    Returns:
      A list of lists; element i is the internal state of the zipformer
      model for utterance i alone.
    """
    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
    tot_num_layers = (len(batch_states) - 2) // 6

    # processed_lens has one entry per utterance, so its length is the batch.
    processed_lens = batch_states[-1]
    batch_size = processed_lens.shape[0]

    state_list = [[] for _ in range(batch_size)]

    # Chunk dim for each per-layer cache mirrors the cat dim in stack_states:
    # attention caches hold batch on dim 1, conv caches on dim 0.
    chunk_dims = (1, 1, 1, 1, 0, 0)
    for layer in range(tot_num_layers):
        base = layer * 6
        per_slot = [
            batch_states[base + slot].chunk(chunks=batch_size, dim=dim)
            for slot, dim in enumerate(chunk_dims)
        ]
        for i in range(batch_size):
            state_list[i] += [chunks[i] for chunks in per_slot]

    # Cached left padding of the conv front-end (batch on dim 0).
    for i, pad in enumerate(batch_states[-2].chunk(chunks=batch_size, dim=0)):
        state_list[i].append(pad)

    # Per-utterance processed_lens.
    for i, plen in enumerate(batch_states[-1].chunk(chunks=batch_size, dim=0)):
        state_list[i].append(plen)

    return state_list
|
| 313 |
+
|
| 314 |
+
def streaming_forward(
    features: Tensor,
    feature_lens: Tensor,
    model: nn.Module,
    states: List[Tensor],
    chunk_size: int,
    left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    """
    Run one streaming step of the encoder on a single chunk of features.

    Args:
      features: fbank features for this chunk, (batch, time, feature_dim).
      feature_lens: valid feature frames per utterance, (batch,).
      model: model exposing encoder_embed.streaming_forward and
        encoder.streaming_forward.
      states: cached states as laid out by get_init_states / stack_states;
        states[:-2] are per-layer caches, states[-2] is the conv front-end's
        left padding, states[-1] is processed_lens.
      chunk_size: expected chunk length (in encoder frames) after subsampling.
      left_context_len: length of the cached left context, in encoder frames.

    Returns encoder outputs, output lengths, and updated states.
    """
    # The last two entries of `states` belong to the front-end, not to the
    # encoder layers; see get_init_states for the layout.
    cached_embed_left_pad = states[-2]
    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
        x=features,
        x_lens=feature_lens,
        cached_left_pad=cached_embed_left_pad,
    )
    # After subsampling, each call must produce exactly one chunk.
    assert x.size(1) == chunk_size, (x.size(1), chunk_size)

    src_key_padding_mask = make_pad_mask(x_lens)

    # processed_mask is used to mask out initial states
    processed_mask = torch.arange(left_context_len, device=x.device).expand(
        x.size(0), left_context_len
    )
    processed_lens = states[-1]  # (batch,)
    # (batch, left_context_size)
    # Positions of the left-context cache not yet filled (index >= number of
    # frames processed so far) are masked; flip(1) puts the most recently
    # processed frames at the right edge, adjacent to the current chunk.
    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
    # Update processed lengths
    new_processed_lens = processed_lens + x_lens

    # (batch, left_context_size + chunk_size)
    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
    encoder_states = states[:-2]
    (
        encoder_out,
        encoder_out_lens,
        new_encoder_states,
    ) = model.encoder.streaming_forward(
        x=x,
        x_lens=x_lens,
        states=encoder_states,
        src_key_padding_mask=src_key_padding_mask,
    )
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

    # Re-assemble the state list in the same layout we received it.
    new_states = new_encoder_states + [
        new_cached_embed_left_pad,
        new_processed_lens,
    ]
    return encoder_out, encoder_out_lens, new_states
|
| 367 |
+
|
| 368 |
+
def chunk_forward(
    audio: torch.Tensor,
    model: torch.nn.Module,
    feature_dim: int = 128,
    chunk_size: int = 8,
    left_context_frames: int = 256,
):
    """Run the encoder over ``audio`` chunk by chunk, emulating streaming inference.

    Args:
      audio: waveform tensor of shape (1, num_samples); assumed 16 kHz.
      model: a model exposing ``encoder_embed.streaming_forward`` and
        ``encoder.streaming_forward`` (used via the module-level
        ``streaming_forward`` helper).
      feature_dim: number of mel bins for the fbank extractor.
      chunk_size: chunk size at the encoder (subsampled) frame rate; at the
        fbank level each chunk covers ``2 * chunk_size`` frames.
      left_context_frames: number of left-context frames kept in the
        streaming states.

    Returns:
      ``(encoder_outs, encoder_out_lens)``: concatenated encoder outputs of
      shape (1, T, C) and the accumulated output lengths.
    """
    # Perform chunk by chunk forward for the encoder. Each chunk is conditioned on the current chunk and left context (maintained by the states)
    # At each step, we take a chunk of audio and forward the encoder
    # For the first chunk, we wait until the accumulated audio duration to reach (buffer + chunk_duration), the buffer
    # is necessary for the convolution subsampling module in the encoder.
    # After the first chunk, we perform normal chunk-by-chunk inference when the accumulated audio reaches chunk_duration
    # An example of Buffer=2 frames, chunk=5 frames, the latency for the first chunk is 7 frames (as we need to accumulate 7 frames
    # for decoding), the rest chunks have latency of 5 frames.
    # Each time we feed (5 + 2) frames to the encoder, and then shift 5 frames
    # Chunk 1: AAAAAAA
    # Chunk 2: AAAAAAA
    # Chunk 3: AAAAAAA

    # NOTE: params.chunk_size is the chunk_size regarding to the input of the zipformer encoder, so at fbank level, the chunk size
    # is 2 * params.chunk_size

    # fbank extractor
    extractor = Fbank(FbankConfig(num_mel_bins=feature_dim))

    device = next(model.parameters()).device

    chunk_size = int(chunk_size)
    # 2 fbank frames per encoder frame, 160 samples (10 ms @ 16 kHz) per fbank frame
    chunk_size_samples = int(chunk_size * 2 * 160) # chunk size represented in audio samples of 16kHz sampling rate
    left_context_len = int(left_context_frames)
    pad_length = 7 + 2 * 3 # buffer required by encoder_embed module (i.e. convolution subsampling)
    pad_length_samples = (7 + 2 * 3) * 160

    # intialize states, to be maintained during chunk-wise forward
    initial_states = get_init_states(model=model, batch_size=1, device=device)

    # start forward chunk by chunk
    encoder_outs = []
    encoder_out_lens = 0
    states = initial_states

    num_chunk = 0
    num_processed_samples = 0 # audio samples

    # the actual loop performing the chunk-wise inference of the encoder
    while True:
        # prepare the input for processing current chunk
        # compute fbank for the current chunk
        # NOTE: each window includes the extra lookahead buffer required by
        # the convolution subsampling, but we only advance by chunk_size_samples.
        audio_chunk = audio[:, num_processed_samples: num_processed_samples + (chunk_size_samples + pad_length_samples)]
        features = extractor.extract(audio_chunk, sampling_rate=16000)
        features = features.to(device)
        # NOTE(review): assumes extractor.extract returns a (T, num_mels)
        # tensor for the single input channel — confirm against lhotse Fbank.
        feature_lens = features.shape[0]

        feature_lens = torch.tensor([feature_lens], device=device) # shape: (1)
        features = features.unsqueeze(0) # shape: (1,T,num_mels)

        # the audio chunk could be shorter than the expected length, for example in the last two chunks
        # pad the chunk so that the input shape is (chunk_size + buffer)
        tail_length = chunk_size * 2 + 7 + 2 * 3 # each prepared chunk should have this length
        if features.size(1) < tail_length:
            # pad_length here shadows the module-buffer constant above; it is
            # the number of missing fbank frames for this (short) chunk.
            pad_length = tail_length - features.size(1)
            feature_lens += pad_length
            features = torch.nn.functional.pad(
                features,
                (0, 0, 0, pad_length),
                mode="constant",
                value=LOG_EPS,
            )

        # wrap the single stream into a batch of size 1
        states = stack_states([states])

        # forward current chunk in batch=1
        encoder_out, encoder_out_len, new_states = streaming_forward(
            features=features,
            feature_lens=feature_lens,
            model=model,
            states=states,
            chunk_size=chunk_size,
            left_context_len=left_context_len,
        )

        encoder_outs.append(encoder_out)
        encoder_out_lens += encoder_out_len

        # update the states
        states = unstack_states(new_states)[0]

        num_chunk += 1
        num_processed_samples += chunk_size_samples

        if num_processed_samples > audio.shape[1]:
            print(f"Audio is exhausted.")
            break

    encoder_outs = torch.cat(encoder_outs, dim=1) # shape: (1,T,C)

    return encoder_outs, encoder_out_lens
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def main(args):
    """Load a checkpoint and an audio file, then run chunk-wise streaming
    encoder inference and print the resulting encoder output.

    Args:
      args: parsed CLI arguments; uses ``args.ckpt_path``, ``args.audio``,
        ``args.chunk_size`` and ``args.left_context_frames``.
    """
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # load model
    model = get_model(args)
    model.to(device)

    # NOTE(review): strict=False tolerates missing/unexpected keys — the
    # mismatch report is printed below; verify it is as expected.
    info = model.load_state_dict(
        torch.load(args.ckpt_path)["model"], strict=False
    )
    print(info)
    model.eval()

    # load audio; only 16 kHz audio is supported (see assert)
    audio, fs = torchaudio.load(args.audio)
    assert fs == 16000

    encoder_out, encoder_out_lens = chunk_forward(
        audio=audio, # shape (1, num_samples)
        model=model,
        feature_dim=128,
        chunk_size=args.chunk_size,
        left_context_frames=args.left_context_frames,
    )


    print(encoder_out)
    print(encoder_out.shape)
    # torch.save(encoder_out, "streaming_forward_encoder_out_no_k2.pt")
| 499 |
+
|
| 500 |
+
if __name__=="__main__":
    # Script entry point: parse CLI arguments (checkpoint path, audio file,
    # chunk size, left context) and run the streaming forward pass.
    parser = get_parser()
    args = parser.parse_args()

    main(args)
|
iter-400000-avg-4.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ab1b41eb7d85e83e01bf44249dfbac1c56c0f644aedc7f790268aaff0e287e4
|
| 3 |
+
size 2398204810
|
model.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
| 2 |
+
# Wei Kang,
|
| 3 |
+
# Zengwei Yao)
|
| 4 |
+
#
|
| 5 |
+
# Copyright 2025 University of Cambridge (authors: Xiaoyu Yang)
|
| 6 |
+
#
|
| 7 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
from typing import Optional, Tuple
|
| 23 |
+
import random
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Create a boolean mask marking padded positions.

    Args:
      lengths:
        A 1-D tensor of per-sequence lengths, shape (batch,).
      max_len:
        Minimum number of columns in the mask; the actual width is
        ``max(max_len, lengths.max())``.

    Returns:
      A 2-D bool tensor of shape (batch, width) where padded positions
      (column index >= sequence length) are True and valid positions
      are False.

      Example: ``make_pad_mask(torch.tensor([1, 3, 2, 5]))`` yields
      rows with 1, 3, 2 and 5 leading False values respectively.
    """
    assert lengths.ndim == 1, lengths.ndim
    width = max(max_len, lengths.max())
    batch = lengths.size(0)
    # Column indices 0..width-1, replicated over the batch dimension.
    col_index = torch.arange(0, width, device=lengths.device)
    col_index = col_index.unsqueeze(0).expand(batch, width)
    # A position is padding iff its index is past the sequence length.
    return col_index >= lengths.unsqueeze(-1)
|
| 57 |
+
|
| 58 |
+
class MultiKDModel(nn.Module):
    """Multi-task knowledge-distillation (KD) pre-training model.

    Wraps a convolutional subsampling frontend (``encoder_embed``) and an
    encoder, optionally masks the input fbank features (wav2vec2-style or
    block masking), and computes two optional losses on the encoder output:
    a codebook (MVQ) distillation loss and an audio-tagging BCE loss.
    """

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_dim: int,
        num_codebooks: int=8,
        distillation_layer: int=9,
        distillation_delta: int=0,
        teacher_frame_ratio: int = 2,
        interpolate_teacher: bool = False,
        n_mels: int = 128,
        num_events: int = 527,
        mask_mode: str = "w2v2",
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: float = 0.0,
        min_masks: int = 2,
        mask_channel_prob: float = 0.0,
        mask_channel_length: int = 10,
        mask_channel_selection: str = "static",
        mask_channel_other: float = 0.0,
        loss_only_mask: bool = False,
    ):
        """A model that performs MVQ KD pre-training .

        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. Its accepts
            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          num_codebooks:
            The number of codebooks used in the target
          distillation_layer:
            Use which layer to do MVQ pre-training
          distillation_delta:
            How many frames to delay the alignment between the model and the target frames.
            Should be zero for non-streaming models, and a positive number for streaming models
          teacher_frame_ratio:
            The frame rate ratio between the target and the model output
          interpolate_teacher:
            If True, interpolate the teacher codebook indexes to the student
            frame rate instead of concatenating successive teacher frames.
          n_mels:
            Number of mel bins of the input fbank (size of the mask embedding).
          num_events:
            Number of audio-tagging classes (AudioSet has 527).
          mask_mode:
            The masking mode.
            w2v2: the wav2vec2 style of masking, allows overlap
            custom: no overlap, therefore bigger masking ratio
          mask_prob:
            The probability of selecting choosing one frame as the start index
          mask_length:
            The length of each mask
          mask_selection:
            How to determine the length of the mask, see ``compute_mask_indices''
          loss_only_mask:
            If True, accumulate the codebook loss only over masked frames.
        """
        super().__init__()

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.distillation_layer = distillation_layer
        # the frame ratio between the teacher and student
        # if larger than one, we are basically having more than one set of
        # codebooks for each frame
        self.num_codebooks= num_codebooks
        self.teacher_frame_ratio = teacher_frame_ratio
        self.interpolate_teacher = interpolate_teacher
        self.distillation_delta = distillation_delta

        if num_codebooks > 0:
            # Imported lazily so the model can be used without the
            # multi_quantization package when num_codebooks == 0.
            from multi_quantization.prediction import JointCodebookLoss
            self.codebook_loss_net = JointCodebookLoss(
                predictor_channels=encoder_dim,
                num_codebooks=num_codebooks * self.teacher_frame_ratio,
                is_joint=False,
                reduction="none",
            )
        else:
            self.codebook_loss_net = None

        self.audio_tagging_proj = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(encoder_dim, num_events),
        ) # 527 classes

        # masking related
        assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}"
        self.mask_mode = mask_mode

        # Learnable embedding written over masked fbank frames.
        self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_())
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.min_masks = min_masks

        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other

        self.loss_only_mask = loss_only_mask

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        codebook_indexes: Optional[torch.Tensor] = None,
        at_targets: Optional[torch.Tensor] = None,
        mask: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          codebook_indexes:
            Codebook indexes of teacher embeddings
          at_targets:
            Multi-hot audio-tagging targets of shape (N, num_events).
          mask:
            If we perform w2v2 style of masking over the fbank frames

        Returns:
          A tuple ``(codebook_loss, at_loss)``; either element is None when
          its targets were not provided. When ``loss_only_mask`` is active the
          codebook loss is summed over time to shape (B,), otherwise it is
          per-frame with shape (B, T).
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert codebook_indexes is not None or at_targets is not None

        # apply masking
        if self.training and mask:
            padding_mask = make_pad_mask(x_lens)

            # apply masking to the fbank features
            # (clone so the caller's features are not modified in place)
            x, mask_indices = self.apply_mask(
                x.clone(),
                padding_mask=padding_mask
            ) # (N,T,C), (N,T)
        else:
            mask_indices = None

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        if codebook_indexes is not None and self.codebook_loss_net is not None:
            codebook_loss = self.forward_codebook_loss(
                encoder_out, encoder_out_lens, codebook_indexes, reduction="none"
            )
            if self.loss_only_mask and mask_indices is not None:
                # downsample the mask
                # NOTE(review): the factor 4 assumes the encoder output frame
                # rate is 1/4 of the fbank rate — confirm against encoder_embed.
                mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5
                assert mask_indices.size(1) >= codebook_loss.size(1)
                mask_indices = mask_indices[:, :codebook_loss.size(1)].float()
                codebook_loss = codebook_loss * mask_indices
                codebook_loss = codebook_loss.sum(dim=1) # (B,)
        else:
            codebook_loss = None

        if at_targets is not None:
            at_loss = self.forward_audio_tagging(encoder_out, encoder_out_lens, at_targets, return_logits=False)
        else:
            at_loss = None

        return codebook_loss, at_loss

    def forward_codebook_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        codebook_indexes: torch.Tensor,
        reduction: str = "sum",
    ):
        """Compute the MVQ codebook-prediction loss.

        Args:
          encoder_out: student encoder output, shape (N, T, C).
          encoder_out_lens: valid lengths of ``encoder_out``, shape (N,).
          codebook_indexes: teacher codebook indexes; time-aligned to the
            student frame rate here if needed.
          reduction: "sum" -> per-utterance loss (B,); "none" -> per-frame
            loss (B, T). Both are normalized by the number of codebooks.
        """
        # align the encoder features with the codebook indexes
        if self.interpolate_teacher:
            codebook_indexes = self.interpolate_codebook_indexes(
                encoder_out, codebook_indexes
            )
        else:
            if codebook_indexes.shape[1] != encoder_out.shape[1]:
                # align the codebook indexes to the frame rate of the student encoder out
                codebook_indexes = self.concat_successive_codebook_indexes(
                    encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio
                )

        # the delta is associated with the frame-rate of the encoder
        # so a bigger delta maybe necessary for 50Hz student encoder
        if self.distillation_delta > 0:
            # shift: frame t of the encoder predicts teacher frame t - delta
            codebook_indexes = codebook_indexes[:,:-self.distillation_delta, :]
            encoder_out = encoder_out[:, self.distillation_delta:, :]
            truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta)
            # -100 is presumably the ignore index of JointCodebookLoss — confirm
            codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100)

        N,T,_ = encoder_out.shape
        codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes)
        codebook_loss = codebook_loss.reshape(N,T,-1)
        num_cb = codebook_loss.size(-1)
        # normalize the loss by the number of codebooks
        if reduction == "sum":
            codebook_loss = codebook_loss.sum(dim=(1,2)) / num_cb # (B,)
        elif reduction == "none":
            codebook_loss = codebook_loss.sum(dim=2) / num_cb # (B,T)
        else:
            raise NotImplementedError()

        return codebook_loss

    def forward_audio_tagging(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        return_logits: bool = False,
    ):
        """Mean-pool frame logits over valid frames; return logits or BCE loss.

        Args:
          encoder_out: encoder output, shape (N, T, C).
          encoder_out_lens: valid lengths, shape (N,).
          target: multi-hot targets of shape (N, num_events); required unless
            ``return_logits`` is True.
          return_logits: if True, return the pooled logits instead of the loss.
        """
        # target: (N, num_events)
        logits = self.audio_tagging_proj(encoder_out) # (N, T, num_classes)
        padding_mask = make_pad_mask(encoder_out_lens) # (N,T)
        # zero out padded frames so they do not contribute to the pooled sum
        logits[padding_mask] = 0
        logits = logits.sum(dim=1)
        # divide by the number of valid frames -> mean over valid frames
        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) # (N, num_events)
        if return_logits:
            return logits

        at_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")

        return at_loss

    def apply_mask(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply mask according to the mask_mode, return the masked features and the masked positions

        Args:
            x (torch.Tensor): The input fbank features
            padding_mask (torch.Tensor, optional): The padding mask

        Returns:
            The masked fbank feature and the masked_indices, with masked positions as 1
        """
        # apply mask to the fbank features, two modes applicable
        if self.mask_mode == "w2v2":
            x, masked_indices = self.apply_mask_w2v2(x, padding_mask)
        elif self.mask_mode == "block":
            x, masked_indices = self.apply_mask_block(x, padding_mask)
        else:
            raise NotImplementedError()

        # occasionally log the effective masking ratio (~3% of calls)
        if random.random() > 0.97:
            logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked")
        return x, masked_indices


    def apply_mask_block(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ):
        """Non-overlapping block masking; see ``compute_mask_indices_block``."""
        B,T,C = x.shape
        assert self.mask_prob > 0.0

        mask_indices = compute_mask_indices_block(
            shape=(B,T),
            padding_mask=padding_mask,
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
            min_masks=self.min_masks,
        ).to(x.device)

        # write the learnable mask embedding into the masked frames (in place)
        x = index_put(x, mask_indices.bool(), self.mask_emb)

        return x, mask_indices

    def apply_mask_w2v2(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ):
        # this function is modified from fairseq: https://github.com/facebookresearch/fairseq/blob/bedb259bf34a9fc22073c13a1cee23192fa70ef3/fairseq/models/wav2vec/wav2vec2.py#L429
        # The masked indices have value 1
        B, T, C = x.shape

        # we mask channel first, then mask timestamps
        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            # expand the (B, C) channel mask across all T frames
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            # occasionally log the channel-masking ratio (~2% of calls)
            if random.random() > 0.98:
                logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked")
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                mask_type=self.mask_selection,
                mask_other=self.mask_other,
                min_masks=2, # fixed
                no_overlap=False, # False
                min_space=1, # 1
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
            mask_indices = mask_indices.float()
        else:
            mask_indices = None

        return x, mask_indices

    @staticmethod
    def interpolate_codebook_indexes(middle_layer_output, codebook_indexes):
        # This function addresses the case where the teacher has a lower frame rate
        # than the student model
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape # C should be 256

        # nearest-value interpolation along time via F.interpolate (requires
        # float input, hence the round-trip through float)
        codebook_indexes = codebook_indexes.permute(0,2,1).float() # (N,C,T)
        codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected)
        codebook_indexes = codebook_indexes.permute(0,2,1).int() # (N,T,C)

        assert codebook_indexes.shape[1] == middle_layer_output.shape[1]
        return codebook_indexes

    @staticmethod
    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2):
        # Output rate of hubert is 50 frames per second,
        # while that of current encoder is 25.
        # Following code handling two issues:
        # 1.
        # Roughly speaking, to generate another frame output,
        # hubert needes extra two frames,
        # while current encoder needs extra four frames.
        # Suppose there are only extra three frames provided,
        # hubert will generate another frame while current encoder does nothing.
        # 2.
        # codebook loss is a frame-wise loss, to enalbe 25 frames studnet output
        # learns from 50 frames teacher output, two successive frames of teacher model
        # output is concatenated together.
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape # C should be 256

        # Handling issue 1.
        if T >= t_expected * ratio:
            codebook_indexes = codebook_indexes[:, : t_expected * ratio, :]
        else:
            assert t_expected * ratio - T <= 5, (T, t_expected, ratio)
            diff = t_expected * ratio - T
            # pad the missing teacher frames with the ignore index (-100)
            codebook_indexes = torch.cat(
                [
                    codebook_indexes,
                    torch.full((N,diff,C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype)
                ],
                dim=1,
            )
        assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio

        # Handling issue 2.
        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
+
|
| 471 |
+
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at the positions selected by the
    boolean mask ``indices`` (in place) and return the tensor."""
    return tensor.index_put_((indices,), value)
+
|
| 475 |
+
def compute_mask_indices_block(
    shape,
    padding_mask,
    mask_prob: float = 0.5,
    mask_length: int = 10,
    min_masks: int = 2,
):
    """Compute non-overlapping block masks.

    Each utterance's valid region is split into consecutive segments of
    ``mask_length`` frames (trailing frames that do not fill a whole segment
    are never masked) and each segment is independently masked with
    probability ``mask_prob``.

    Args:
      shape: (B, T) — batch size and number of frames.
      padding_mask: optional (B, T) bool tensor with True at padded frames;
        padded frames are never masked.
      mask_prob: per-segment masking probability.
      mask_length: number of frames per segment.
      min_masks: desired minimum number of masked segments per utterance;
        the draw is repeated until it is met (clamped to the number of
        available segments so short utterances cannot hang).

    Returns:
      A float tensor of shape (B, T) with 1.0 at masked positions.
    """
    # self-implemented mask, no overlap
    B, T = shape
    mask_indices = []
    for i in range(B):
        if padding_mask is not None:
            num_segments = (T - padding_mask[i].sum()) // mask_length  # discard the last few frames
        else:
            num_segments = T // mask_length
        # Fix: if fewer segments exist than min_masks, the resampling loop
        # below could never terminate; clamp the requirement instead.
        required = min(min_masks, int(num_segments))
        segment_mask = torch.rand(num_segments) < mask_prob
        while segment_mask.sum() < required:
            segment_mask = torch.rand(num_segments) < mask_prob
        # expand each segment decision to mask_length frames
        segment_mask_expanded = segment_mask.unsqueeze(-1).expand(num_segments, mask_length)
        segment_mask_expanded = segment_mask_expanded.reshape(-1).float()
        if segment_mask_expanded.size(0) < T:
            # trailing partial segment and padded frames stay unmasked
            pad = T - segment_mask_expanded.size(0)
            segment_mask_expanded = torch.cat([segment_mask_expanded, torch.zeros(pad)])
        mask_indices.append(segment_mask_expanded)

    mask_indices = torch.stack(mask_indices)
    return mask_indices
+
|
| 503 |
+
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example

    Returns:
        A (bsz, timesteps) boolean numpy array; True marks masked positions.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            # Deterministic per-(seed, epoch, utterance) masking.
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            # Number of real (non-padded) frames for this example.
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
            hard_max = sz // mask_length
            num_mask = min(hard_max, num_mask)  # prevent whole sequence being masked
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # np.random.Generator has no `randint`; `integers` (with the
            # same exclusive upper bound) is the equivalent call.
            lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span of `length` inside [s, e), record its
                # indices, and return the remaining free sub-intervals that
                # are still large enough to hold `keep_length` more frames.
                # Generator API: `integers` replaces legacy `randint`
                # (upper bound exclusive in both).
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,  # `np.int` was removed in NumPy 1.24; builtin int is correct
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                # Choose a free interval with probability proportional to
                # its usable size.
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                # Restrict start positions so the shortest span fits.
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # Expand each span start to all of its covered indices.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            # Fixed broken f-string placeholder (was `mask_idc[mask_idc]`).
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            # Drop random masked positions down to the common target.
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            # Add random extra masked positions up to the common target.
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
|
| 696 |
+
|
| 697 |
+
def _test_w2v2_channel_mask():
    """Print the average channel-masking ratio of compute_mask_indices
    for a few (mask_channel_prob, mask_channel_length) configurations.

    Removed a leftover ``import pdb; pdb.set_trace()`` that halted the
    loop on every iteration.
    """
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    configs = [(0.25, 15), (0.25, 20), (0.5, 15)]
    # configs = [(0.2, 20), (0.3, 20), (0.4, 20),]
    for config in configs:
        mask_channel_prob, mask_channel_length = config
        ratios = []
        for _ in range(20):
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                mask_channel_prob,
                mask_channel_length,
                "static",
                0.0,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            # Broadcast the per-channel mask over the time axis, as the
            # model would when applying channel masking.
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            ratio = mask_channel_indices.sum() / mask_channel_indices.numel()
            ratios.append(ratio)
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
|
| 730 |
+
|
| 731 |
+
def _test_w2v2_mask():
    """Print the average time-masking ratio of compute_mask_indices for a
    few (mask_prob, mask_length) configurations.

    Removed dead code: the initial mask_prob/mask_length assignments and a
    configs-building loop whose result was immediately overwritten.
    """
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    # configs = [(0.65, 10), (0.01, 40), (0.1, 40), (0.2, 40), (0.2, 20), (0.35, 10), (0.35, 20), (0.25, 20)]
    configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)]
    for config in configs:
        mask_prob, mask_length = config
        ratios = []
        for _ in range(20):
            mask_indices = compute_mask_indices(
                (B, T),
                None,
                mask_prob,
                mask_length,
                mask_type="static",
                mask_other=0.0,
                min_masks=2,
                no_overlap=False,  # False
                min_space=1,  # 1
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices)
            ratio = mask_indices.sum() / mask_indices.numel()
            ratios.append(ratio)
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
|
| 767 |
+
|
| 768 |
+
def _test_custom_mask():
    """Print the average masking ratio of compute_mask_indices_block with a
    randomly jittered mask length.

    Fixes: removed a leftover ``import pdb; pdb.set_trace()`` that halted
    every iteration, and anchored the mask-length jitter at the configured
    value instead of the previous sample (the original re-centered the
    candidate range on the last draw, a random walk that could drift to
    non-positive lengths).
    """
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)]
    for config in configs:
        mask_prob, base_mask_length = config
        mask_length = base_mask_length
        ratios = []
        for _ in range(20):
            # Jitter: base length +/- 10 in steps of 2, sampled uniformly.
            all_possible_mask_lengths = [
                base_mask_length + k * 2 for k in range(-5, 6)
            ]
            mask_length = random.sample(all_possible_mask_lengths, 1)[0]
            assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}"

            mask_indices = compute_mask_indices_block(
                shape=(B, T),
                padding_mask=None,
                mask_prob=mask_prob,
                mask_length=mask_length,
                min_masks=2,
            )
            ratio = mask_indices.sum() / mask_indices.numel()
            ratios.append(ratio)
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        print(f"Averaged masking ratio: {avg_ratio}")
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
if __name__ == "__main__":
    # Only the channel-mask sanity check runs by default; enable the
    # others manually when needed.
    _test_w2v2_channel_mask()
    # _test_w2v2_mask()
    # _test_custom_mask()
|
scaling.py
ADDED
|
@@ -0,0 +1,1913 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import random
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
# import k2
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from torch import Tensor
|
| 27 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
    """ONNX-exportable equivalent of torch.logaddexp.

    Uses the numerically stable identity
    log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)).
    """
    larger = torch.max(x, y)
    gap = torch.abs(x - y)
    return larger + torch.log1p(torch.exp(-gap))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
| 37 |
+
# 14 is not supported. Please feel free to request support or submit
|
| 38 |
+
# a pull request on PyTorch GitHub.
|
| 39 |
+
#
|
| 40 |
+
# The following function is to solve the above error when exporting
|
| 41 |
+
# models to ONNX via torch.jit.trace()
|
| 42 |
+
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
    """Compute log(exp(x) + exp(y)) in a way that also works under ONNX
    export via torch.jit.trace().

    torch.logaddexp() works for both torch.jit.script() and
    torch.jit.trace(), but exporting the logaddexp operator to ONNX opset
    14 is unsupported and raises a RuntimeError, so we dispatch to
    logaddexp_onnx() in that case.

    Caution(fangjun): torch.jit.is_scripting() must be checked before
    torch.onnx.is_in_onnx_export(); the other order causes errors for
    torch.jit.script().  We cannot use torch.jit.is_tracing() here as it
    also matches torch.onnx.export().
    """
    if torch.jit.is_scripting():
        return torch.logaddexp(x, y)
    if torch.onnx.is_in_onnx_export():
        return logaddexp_onnx(x, y)
    # Plain eager mode and torch.jit.trace().
    return torch.logaddexp(x, y)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class PiecewiseLinear(object):
    """
    A piecewise linear map from float to float, given as a nonempty list of
    (x, y) knots with strictly increasing x values.  Inputs below the first
    knot or above the last knot map to the first / last y value
    respectively.
    """

    def __init__(self, *args):
        assert len(args) >= 1, len(args)
        if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
            # Copy construction from another PiecewiseLinear.
            self.pairs = list(args[0].pairs)
        else:
            self.pairs = [(float(x), float(y)) for x, y in args]
        for x, y in self.pairs:
            assert isinstance(x, (float, int)), type(x)
            assert isinstance(y, (float, int)), type(y)

        # The x values must be strictly increasing.
        for i, (left, right) in enumerate(zip(self.pairs, self.pairs[1:])):
            assert right[0] > left[0], (
                i,
                left,
                right,
            )

    def __str__(self):
        # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
        return f"PiecewiseLinear({str(self.pairs)[1:-1]})"

    def __call__(self, x):
        # Clamp to the endpoint values outside the knot range.
        if x <= self.pairs[0][0]:
            return self.pairs[0][1]
        if x >= self.pairs[-1][0]:
            return self.pairs[-1][1]
        prev_x, prev_y = self.pairs[0]
        for next_x, next_y in self.pairs[1:]:
            if prev_x <= x <= next_x:
                # Linear interpolation on the containing segment.
                frac = (x - prev_x) / (next_x - prev_x)
                return prev_y + (next_y - prev_y) * frac
            prev_x, prev_y = next_x, next_y
        assert False

    def __mul__(self, alpha):
        # Scale all y values by alpha.
        scaled = [(x, y * alpha) for x, y in self.pairs]
        return PiecewiseLinear(*scaled)

    def __add__(self, x):
        if isinstance(x, (float, int)):
            shifted = [(p[0], p[1] + x) for p in self.pairs]
            return PiecewiseLinear(*shifted)
        s, x = self.get_common_basis(x)
        summed = [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
        return PiecewiseLinear(*summed)

    def max(self, x):
        # Pointwise maximum with a constant or another PiecewiseLinear.
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        knots = [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
        return PiecewiseLinear(*knots)

    def min(self, x):
        # Pointwise minimum with a constant or another PiecewiseLinear.
        if isinstance(x, (float, int)):
            x = PiecewiseLinear((0, x))
        s, x = self.get_common_basis(x, include_crossings=True)
        knots = [(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
        return PiecewiseLinear(*knots)

    def __eq__(self, other):
        return self.pairs == other.pairs

    def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
        """
        Returns (self_mod, p_mod): piecewise linear functions equivalent to
        self and p but sharing one common set of x values.

        p: the other piecewise linear function
        include_crossings: if true, also insert x positions where the two
            functions cross.
        """
        assert isinstance(p, PiecewiseLinear), type(p)

        # Union of knot x-values, sorted without repetition.
        x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
        y_vals1 = [self(x) for x in x_vals]
        y_vals2 = [p(x) for x in x_vals]

        if include_crossings:
            extra_x_vals = []
            for i in range(len(x_vals) - 1):
                if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
                    # The two lines potentially cross in this sub-segment;
                    # `pos` in [0, 1] is the relative x position of the
                    # crossing (0 at x_vals[i], 1 at x_vals[i+1]).
                    diff_cur = abs(y_vals1[i] - y_vals2[i])
                    diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
                    pos = diff_cur / (diff_cur + diff_next)
                    extra_x_vals.append(
                        x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
                    )
            if len(extra_x_vals) > 0:
                x_vals = sorted(set(x_vals + extra_x_vals))
                y_vals1 = [self(x) for x in x_vals]
                y_vals2 = [p(x) for x in x_vals]
        return (
            PiecewiseLinear(*zip(x_vals, y_vals1)),
            PiecewiseLinear(*zip(x_vals, y_vals2)),
        )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ScheduledFloat(torch.nn.Module):
    """
    A float-valued hyperparameter scheduled on the training batch count.

    This object is a torch.nn.Module only because we want it to show up in
    [top_level module].modules(); it does not have a working forward().
    You are supposed to cast it to float, as in
    float(parent_module.whatever), and use it as something like a dropout
    prob.

    Its value is a piecewise linear function of self.batch_count, specified
    as (x, y) pairs in sorted order on x, where x is the batch index.  For
    batch indices before the first x or after the last x, the first / last
    y value is used.

    Example:
        self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)

    `default` is used when self.batch_count is not set, the module is not
    in training mode, or we are in torch.jit scripting/tracing mode.
    """

    def __init__(self, *args, default: float = 0.0):
        super().__init__()
        # self.batch_count and self.name are written to by the training loop.
        self.batch_count = None
        self.name = None
        self.default = default
        self.schedule = PiecewiseLinear(*args)

    def extra_repr(self) -> str:
        return (
            f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
        )

    def __float__(self):
        count = self.batch_count
        use_default = (
            count is None
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        )
        if use_default:
            return float(self.default)
        ans = self.schedule(count)
        # Occasionally log the current scheduled value for debugging.
        if random.random() < 0.0002:
            logging.info(
                f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
            )
        return ans

    def __add__(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule + x, default=self.default)
        return ScheduledFloat(
            self.schedule + x.schedule, default=self.default + x.default
        )

    def max(self, x):
        if isinstance(x, (float, int)):
            return ScheduledFloat(self.schedule.max(x), default=self.default)
        return ScheduledFloat(
            self.schedule.max(x.schedule), default=max(self.default, x.default)
        )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Type alias for config values that may be a plain float or a ScheduledFloat
# (whose effective value depends on the training batch count).
FloatLike = Union[float, ScheduledFloat]
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.

    Elements with absolute value >= min_abs are cast directly.  Elements with
    absolute value < min_abs are replaced by +-min_abs (matching the sign of
    x) with probability |x| / min_abs, and by 0.0 otherwise; this preserves
    the expected value of those small elements.
    """
    if x.dtype == torch.float16:
        # Already half precision; nothing to do.
        return x
    magnitude = x.abs()
    too_small = magnitude < min_abs
    # Bernoulli draw: keep +-min_abs with probability |x| / min_abs, else 0.0.
    dithered = x.sign() * min_abs * (torch.rand_like(x) * min_abs < magnitude)
    return torch.where(too_small, dithered, x).to(torch.float16)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CutoffEstimator:
    """
    Estimates cutoffs of an arbitrary numerical quantity such that a specified
    proportion of items will be above the cutoff on average.

    p is the proportion of items that should be above the cutoff.
    """

    def __init__(self, p: float):
        # target proportion of items above the cutoff
        self.p = p
        # running totals: items seen, and items that exceeded the cutoff
        self.count = 0
        self.count_above = 0
        # current cutoff estimate (starts at 0)
        self.cutoff = 0

    def __call__(self, x: float) -> bool:
        """Returns true if x is above the cutoff (and updates the estimate)."""
        above = x > self.cutoff
        self.count += 1
        self.count_above += int(above)
        observed_p = self.count_above / self.count
        error = observed_p - self.p
        # Move the cutoff only when doing so pushes the observed proportion
        # toward the target: raise it after an "above" when too many items are
        # above; lower it after a "below" when too few are above.  The step
        # size is the size of the proportion error.
        if (error > 0) == above:
            step = abs(error)
            self.cutoff = x * step + self.cutoff * (1 - step)
        return above
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # Under autocast, softmax on a float16 input returns float32 because
        # (presumably) the op does not support float16; cast back so the
        # saved/returned tensor matches the autocast dtype.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.get_autocast_gpu_dtype())
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            # Compute the softmax backward in float32 for accuracy.
            grad32 = ans_grad.to(torch.float32)
            ans32 = ans.to(torch.float32)
            prod = grad32 * ans32
            x_grad = prod - ans32 * prod.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def softmax(x: Tensor, dim: int):
    """Softmax along `dim`; routes through SoftmaxFunction when gradients are
    required (for better half-precision backward accuracy), and falls back to
    the plain torch op otherwise or under jit scripting/tracing."""
    use_plain = (
        not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing()
    )
    if use_plain:
        return x.softmax(dim=dim)
    return SoftmaxFunction.apply(x, dim)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
    """
    A no-op in the forward pass; in the backward pass it adds an extra
    gradient term that penalizes the proportion of the (mean-removed)
    variance of x explained by the single direction `direction` (with
    per-frame coefficients `coeffs`).  The extra term is rescaled so its norm
    is `grad_scale` times the norm of the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        # Save detached copies; the penalty gradient is recomputed in backward.
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            # Flatten so the channel dim is last: (num_frames, num_channels).
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction.  This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        # Rescale the penalty gradient so its norm is grad_scale times the
        # norm of the incoming gradient (1.0e-20 guards against divide-by-zero).
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class BiasNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return x * scales
    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
    # it can just store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        bias: Tensor,
        log_scale: Tensor,
        channel_dim: int,
        store_output_for_backprop: bool,
    ) -> Tensor:
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.store_output_for_backprop = store_output_for_backprop
        ctx.channel_dim = channel_dim
        # Unsqueeze bias so it broadcasts against x along channel_dim.
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        scales = (
            torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
        ) * log_scale.exp()
        ans = x * scales
        # Save either the output or the input (whichever the caller expects to
        # be needed anyway), plus what is required to recompute `scales`.
        ctx.save_for_backward(
            ans.detach() if store_output_for_backprop else x,
            scales.detach(),
            bias.detach(),
            log_scale.detach(),
        )
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tensor:
        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
        if ctx.store_output_for_backprop:
            # We stored the output; recover the input by undoing the scaling.
            x = ans_or_x / scales
        else:
            x = ans_or_x
        x = x.detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # recompute scales from x, bias and log_scale.
            scales = (
                torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
            ) * log_scale.exp()
            ans = x * scales
            ans.backward(gradient=ans_grad)
        return x.grad, bias.grad.flatten(), log_scale.grad, None, None
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class BiasNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    Instead, we give the BiasNorm a trainable bias that it can use when
    computing the scale for normalization.  We also give it a (scalar)
    trainable scale on the output.


    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
       log_scale: the initial log-scale that we multiply the output by; this
         is learnable.
       log_scale_min: FloatLike, minimum allowed value of log_scale
       log_scale_max: FloatLike, maximum allowed value of log_scale
       store_output_for_backprop: only possibly affects memory use; recommend
         to set to True if you think the output of this module is more likely
         than the input of this module to be required to be stored for the
         backprop.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        log_scale: float = 1.0,
        log_scale_min: float = -1.5,
        log_scale_max: float = 1.5,
        store_output_for_backprop: bool = False,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # Scalar learnable log-scale on the output.
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
        # Per-channel learnable bias used when computing the normalization scale.
        self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4))

        self.log_scale_min = log_scale_min
        self.log_scale_max = log_scale_max

        self.store_output_for_backprop = store_output_for_backprop

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Inline the computation; the custom autograd function below is
            # not used under scripting/tracing.
            dim = self.channel_dim
            if dim < 0:
                dim += x.ndim
            bias = self.bias
            for _ in range(dim + 1, x.ndim):
                bias = bias.unsqueeze(-1)
            scales = (
                torch.mean((x - bias) ** 2, dim=dim, keepdim=True) ** -0.5
            ) * self.log_scale.exp()
            return x * scales

        clamped_log_scale = limit_param_value(
            self.log_scale,
            min=float(self.log_scale_min),
            max=float(self.log_scale_max),
            training=self.training,
        )

        return BiasNormFunction.apply(
            x,
            self.bias,
            clamped_log_scale,
            self.channel_dim,
            self.store_output_for_backprop,
        )
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Constructs an nn.Linear with rescaled initial parameters, giving an easy
    way to control the default initial magnitude of the module's output.

    Args:
       Accepts the standard args and kwargs that nn.Linear accepts,
       e.g. in_features, out_features, bias=False.

       initial_scale: multiplies the default initial weight and sets the
          range of the uniformly re-initialized bias.  Override it to
          increase or decrease the initial output magnitude; alternatively,
          re-initialize the parameters yourself afterwards.
    """
    linear = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        linear.weight[:] *= initial_scale
        if linear.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(linear.bias, -bound, bound)
    return linear
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
       Accepts the standard args and kwargs that nn.Conv1d accepts
       e.g. in_channels, out_channels, kernel_size, bias=False.

       initial_scale: you can override this if you want to increase
          or decrease the initial magnitude of the module's output
          (it multiplies the default initial weight, and sets the range
          of the uniformly re-initialized bias).
          Another option, if you want to do something like this, is
          to re-initialize the parameters.
    """
    ans = nn.Conv1d(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
    """
    Behaves like a constructor of a modified version of nn.Conv2d
    that gives an easy way to set the default initial parameter scale.

    Args:
       Accepts the standard args and kwargs that nn.Conv2d accepts
       e.g. in_channels, out_channels, kernel_size, bias=False, but:
       NO PADDING-RELATED ARGS.

       initial_scale: you can override this if you want to increase
          or decrease the initial magnitude of the module's output
          (it multiplies the default initial weight, and sets the range
          of the uniformly re-initialized bias).
          Another option, if you want to do something like this, is
          to re-initialize the parameters.
    """
    ans = nn.Conv2d(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
    """
    Behaves like a depthwise 1d convolution, except that it is causal in
    a chunkwise way, as if we had a block-triangular attention mask.
    The chunk size is provided at test time (it should probably be
    kept in sync with the attention mask).

    This has a little more than twice the parameters of a conventional
    depthwise conv1d module: we implement it by having one
    depthwise convolution, of half the width, that is causal (via
    right-padding); and one depthwise convolution that is applied only
    within chunks, that we multiply by a scaling factor which depends
    on the position within the chunk.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        initial_scale: float = 1.0,
        bias: bool = True,
    ):
        super().__init__()
        # Odd kernel size is required so the chunkwise conv's padding
        # (kernel_size // 2) is symmetric and preserves length.
        assert kernel_size % 2 == 1

        half_kernel_size = (kernel_size + 1) // 2
        # will pad manually, on one side.
        self.causal_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=half_kernel_size,
            padding=0,
            bias=True,
        )

        self.chunkwise_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

        # first row is correction factors added to the scale near the left edge of the chunk,
        # second row is correction factors added to the scale near the right edge of the chunk,
        # both of these are added to a default scale of 1.0.
        self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
        self.kernel_size = kernel_size

        with torch.no_grad():
            self.causal_conv.weight[:] *= initial_scale
            self.chunkwise_conv.weight[:] *= initial_scale
            if bias:
                torch.nn.init.uniform_(
                    self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
                )

    def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
        """Forward function.

        Args:
           x: a Tensor of shape (batch_size, channels, seq_len)
           chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
        """
        (batch_size, num_channels, seq_len) = x.shape

        # half_kernel_size = self.kernel_size + 1 // 2
        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        if chunk_size < 0 or chunk_size > seq_len:
            chunk_size = seq_len
        # right_pad rounds the sequence length up to a multiple of chunk_size.
        right_pad = -seq_len % chunk_size

        x = torch.nn.functional.pad(x, (left_pad, right_pad))

        # The causal branch sees only the left-padded portion (no right pad),
        # so its output has exactly seq_len frames.
        x_causal = self.causal_conv(x[..., : left_pad + seq_len])
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        x_chunk = x[..., left_pad:]
        num_chunks = x_chunk.shape[2] // chunk_size
        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
        # Fold chunks into the batch dim so each chunk is convolved independently:
        # (batch_size * num_chunks, num_channels, chunk_size).
        x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
            batch_size * num_chunks, num_channels, chunk_size
        )
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        chunk_scale = self._get_chunk_scale(chunk_size)

        x_chunk = x_chunk * chunk_scale
        # Undo the chunking and drop the right padding.
        x_chunk = x_chunk.reshape(
            batch_size, num_chunks, num_channels, chunk_size
        ).permute(0, 2, 1, 3)
        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
            ..., :seq_len
        ]

        return x_chunk + x_causal

    def _get_chunk_scale(self, chunk_size: int):
        """Returns tensor of shape (num_channels, chunk_size) that will be used to
        scale the output of self.chunkwise_conv."""
        left_edge = self.chunkwise_conv_scale[0]
        right_edge = self.chunkwise_conv_scale[1]
        if chunk_size < self.kernel_size:
            # Chunk shorter than the learned edge corrections: truncate them.
            left_edge = left_edge[:, :chunk_size]
            right_edge = right_edge[:, -chunk_size:]
        else:
            # Zero-pad the middle of the chunk, where no correction is applied.
            t = chunk_size - self.kernel_size
            channels = left_edge.shape[0]
            pad = torch.zeros(
                channels, t, device=left_edge.device, dtype=left_edge.dtype
            )
            left_edge = torch.cat((left_edge, pad), dim=-1)
            right_edge = torch.cat((pad, right_edge), dim=-1)
        return 1.0 + (left_edge + right_edge)

    def streaming_forward(
        self,
        x: Tensor,
        cache: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Streaming Forward function.

        Args:
           x: a Tensor of shape (batch_size, channels, seq_len)
           cache: cached left context of shape (batch_size, channels, left_pad)
        """
        (batch_size, num_channels, seq_len) = x.shape

        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        # Pad cache
        assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
        x = torch.cat([cache, x], dim=2)
        # Update cache
        cache = x[..., -left_pad:]

        x_causal = self.causal_conv(x)
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        x_chunk = x[..., left_pad:]
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        # In streaming mode the whole segment is treated as one chunk.
        chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
        x_chunk = x_chunk * chunk_scale

        return x_chunk + x_causal, cache
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
class BalancerFunction(torch.autograd.Function):
    """
    A no-op in the forward pass; in the backward pass it adds to the incoming
    gradient a term that pushes each channel's (mean / stddev) toward
    [min_mean, max_mean] and its RMS toward [min_rms, max_rms].  The penalty
    gradient is normalized to RMS `grad_scale` and modulated elementwise by
    |x_grad|.  See the Balancer module, which wraps this.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min_mean: float,
        max_mean: float,
        min_rms: float,
        max_rms: float,
        grad_scale: float,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        ctx.save_for_backward(x)
        ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        (x,) = ctx.saved_tensors
        (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

        # On any failure we log and fall through, returning x_grad unmodified;
        # the balancing is best-effort.
        try:
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    # Compute the penalty and its gradient in float32.
                    x = x.to(torch.float32)
                    x = x.detach()
                    x.requires_grad = True
                    mean_dims = [i for i in range(x.ndim) if i != channel_dim]
                    uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
                    mean = x.mean(dim=mean_dims, keepdim=True)
                    stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
                    rms = uncentered_var.clamp(min=1.0e-20).sqrt()

                    m = mean / stddev
                    # part of loss that relates to mean / stddev
                    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()

                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
                    # m_loss are violated we fix the RMS loss first.
                    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
                    r_loss = (rms_clamped / rms).log().abs()

                    loss = m_loss + r_loss

                    loss.backward(gradient=torch.ones_like(loss))
                    loss_grad = x.grad
                    # Normalize the penalty gradient to have per-channel RMS of 1
                    # (clamped to avoid division by zero), then apply grad_scale.
                    loss_grad_rms = (
                        (loss_grad**2)
                        .mean(dim=mean_dims, keepdim=True)
                        .sqrt()
                        .clamp(min=1.0e-20)
                    )

                    loss_grad = loss_grad * (grad_scale / loss_grad_rms)

                    x_grad_float = x_grad.to(torch.float32)
                    # scale each element of loss_grad by the absolute value of the corresponding
                    # element of x_grad, which we view as a noisy estimate of its magnitude for that
                    # (frame and dimension).  later we can consider factored versions.
                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                    x_grad = x_grad_mod.to(x_grad.dtype)
        except Exception as e:
            logging.info(
                f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
            )

        return x_grad, None, None, None, None, None, None
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
class Balancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       min_abs: the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       grad_scale: the scale of the penalty gradient added in the backward
           pass (the penalty gradient is normalized to this RMS value in
           BalancerFunction).
       prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward().  This is done randomly to prevent all layers
           from doing it at the same time.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: FloatLike = 0.05,
        max_positive: FloatLike = 0.95,
        min_abs: FloatLike = 0.2,
        max_abs: FloatLike = 100.0,
        grad_scale: FloatLike = 0.04,
        prob: Optional[FloatLike] = None,
    ):
        super().__init__()

        if prob is None:
            # Default: apply often early in training, less often later.
            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
        self.prob = prob
        # 5% of the time we will return and do nothing because memory usage is
        # too high.
        self.mem_cutoff = CutoffEstimator(0.05)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        # Skip entirely when no grad is needed, under scripting, or when CUDA
        # memory use is above the adaptively-estimated cutoff.
        if (
            torch.jit.is_scripting()
            or not x.requires_grad
            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
        ):
            return _no_op(x)

        prob = float(self.prob)
        if random.random() < prob:
            # The following inner-functions convert from the way we historically specified
            # these limitations, as limits on the absolute value and the proportion of positive
            # values, to limits on the RMS value and the (mean / stddev).
            def _abs_to_rms(x):
                # for normally distributed data, if the expected absolute value is x, the
                # expected rms value will be sqrt(pi/2) * x.
                return 1.25331413732 * x

            def _proportion_positive_to_mean(x):
                def _atanh(x):
                    eps = 1.0e-10
                    # eps is to prevent crashes if x is exactly 0 or 1.
                    # we'll just end up returning a fairly large value.
                    return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0

                def _approx_inverse_erf(x):
                    # 1 / (sqrt(pi) * ln(2)),
                    # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
                    # this approximation is extremely crude and gets progressively worse for
                    # x very close to -1 or +1, but we mostly care about the "middle" region
                    # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
                    # and math.erf(0.0407316414078772) = 0.045935330944660666,
                    # which is pretty close to 0.05.
                    return 0.8139535143 * _atanh(x)

                # first convert x from the range 0..1 to the range -1..1 which the error
                # function returns
                x = -1 + (2 * x)
                return _approx_inverse_erf(x)

            min_mean = _proportion_positive_to_mean(float(self.min_positive))
            max_mean = _proportion_positive_to_mean(float(self.max_positive))
            min_rms = _abs_to_rms(float(self.min_abs))
            max_rms = _abs_to_rms(float(self.max_abs))
            grad_scale = float(self.grad_scale)

            assert x.shape[self.channel_dim] == self.num_channels

            return BalancerFunction.apply(
                x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
            )
        else:
            return _no_op(x)
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def penalize_abs_values_gt(
    x: Tensor, limit: float, penalty: float, name: Optional[str] = None
) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training.  For the way we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.

    Args:
        x: the tensor whose large absolute values should be penalized.
        limit: the absolute-value threshold above which the penalty applies.
        penalty: the scale on the gradient contributed by the excess.
        name: optional tag used for randomly printed debug info in with_loss().
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit.  (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away).  The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss, name)
    # you must use x for something, or this will be ineffective.
    return x
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
| 957 |
+
if x.ndim == 2:
|
| 958 |
+
return x.diag()
|
| 959 |
+
else:
|
| 960 |
+
(batch, dim, dim) = x.shape
|
| 961 |
+
x = x.reshape(batch, dim * dim)
|
| 962 |
+
x = x[:, :: dim + 1]
|
| 963 |
+
assert x.shape == (batch, dim)
|
| 964 |
+
return x
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
def _whitening_metric(x: Tensor, num_groups: int):
    """
    Computes the "whitening metric": a value that equals 1.0 if all the
    eigenvalues of the centered feature covariance are the same within each
    group's covariance matrix and also between groups, and is greater than
    1.0 otherwise.

    Args:
      x: a Tensor of shape (*, num_channels)
      num_groups: the number of groups of channels, a number >=1 that divides num_channels
    Returns:
      a scalar Tensor >= 1.0, equal to 1.0 iff the data is "perfectly white".
    """
    assert x.dtype != torch.float16
    x = x.reshape(-1, x.shape[-1])
    num_frames, num_channels = x.shape
    assert num_channels % num_groups == 0
    group_size = num_channels // num_groups
    # shape: (num_groups, num_frames, group_size)
    x = x.reshape(num_frames, num_groups, group_size).transpose(0, 1)
    # Use the *centered* covariance.  Experience has been that when "messing
    # with the gradients" like this, trying to move the mean around easily
    # causes instability, so the mean is subtracted rather than constrained.
    x = x - x.mean(dim=1, keepdim=True)
    # covar: (num_groups, group_size, group_size)
    covar = torch.matmul(x.transpose(1, 2), x)
    mean_diag = _diag(covar).mean()
    # Equivalent to _diag(torch.matmul(covar, covar)).mean(): the mean trace
    # of each covariance matrix multiplied by itself.
    sq_mean_diag = (covar ** 2).sum() / (num_groups * group_size)
    # >= 1.0; the larger it is, the less "white" the data was.
    return sq_mean_diag / (mean_diag ** 2 + 1.0e-20)
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
    """Identity in forward; in backward, adds a rescaled gradient of the
    "whitening metric" of the input whenever that metric exceeds the limit
    configured on the supplied module (see the Whiten class)."""

    @staticmethod
    def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
        # Save the input and the owning Whiten module; the module carries the
        # configuration (num_groups, whitening_limit, grad_scale, prob bounds)
        # and has its `prob` attribute mutated in backward().
        ctx.save_for_backward(x)
        ctx.module = module
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        w = ctx.module

        try:
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    # Recompute the metric in float32 on a detached copy so we
                    # can get d(metric)/dx without touching the main graph.
                    x_detached = x_orig.to(torch.float32).detach()
                    x_detached.requires_grad = True

                    metric = _whitening_metric(x_detached, w.num_groups)

                    # Occasional debug logging (always when run as a script).
                    if random.random() < 0.005 or __name__ == "__main__":
                        logging.info(
                            f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
                            f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
                        )

                    if metric < float(w.whitening_limit):
                        # Constraint satisfied: back off (apply less often)
                        # and pass the gradient through unchanged.
                        w.prob = w.min_prob
                        return x_grad, None
                    else:
                        # Constraint violated: apply more often, and add the
                        # metric's gradient, rescaled so that its norm equals
                        # grad_scale times the norm of the incoming gradient.
                        w.prob = w.max_prob
                        metric.backward()
                        penalty_grad = x_detached.grad
                        scale = float(w.grad_scale) * (
                            x_grad.to(torch.float32).norm()
                            / (penalty_grad.norm() + 1.0e-20)
                        )
                        penalty_grad = penalty_grad * scale
                        return x_grad + penalty_grad.to(x_grad.dtype), None
        except Exception as e:
            # Best-effort: numerical failures in the penalty computation
            # should not kill training; log and pass the gradient through.
            logging.info(
                f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
            )
            return x_grad, None
|
| 1047 |
+
|
| 1048 |
+
|
| 1049 |
+
class Whiten(nn.Module):
    def __init__(
        self,
        num_groups: int,
        whitening_limit: FloatLike,
        prob: Union[float, Tuple[float, float]],
        grad_scale: FloatLike,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening.  We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
         whitening_limit: a value greater than 1.0, that dictates how much
           freedom we have to violate the constraints.  1.0 would mean perfectly
           white, with exactly the same trace across groups; larger values
           give more freedom.  E.g. 2.0.
         prob: the probability with which we apply the gradient modification
           (also affects the grad scale).  May be supplied as a float,
           or as a pair (min_prob, max_prob)

          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert float(whitening_limit) >= 1
        assert float(grad_scale) >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        self.grad_scale = grad_scale

        # prob may be a single float or a (min, max) pair;
        # WhiteningPenaltyFunction.backward moves self.prob between these
        # bounds depending on whether the whitening constraint is satisfied.
        if isinstance(prob, float):
            prob = (prob, prob)
        (self.min_prob, self.max_prob) = prob
        assert 0 < self.min_prob <= self.max_prob <= 1
        self.prob = self.max_prob
        self.name = None  # will be set in training loop

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
          x: the input of shape (*, num_channels)

        Returns:
            x, unmodified.   You should make sure
            you use the returned value, or the graph will be freed
            and nothing will happen in backprop.
        """
        grad_scale = float(self.grad_scale)
        # Skip the custom function when it could have no effect (no grad
        # needed, zero scale) or when the random draw says not to apply it.
        if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
            return _no_op(x)
        else:
            return WhiteningPenaltyFunction.apply(x, self)
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
class WithLoss(torch.autograd.Function):
    """Pass-through for x that behaves, in backprop, as if y.sum() had been
    added to the loss: y receives an all-ones gradient."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, name: str):
        ctx.y_shape = y.shape
        # occasionally log the auxiliary loss, when a name tag was supplied.
        if random.random() < 0.002 and name is not None:
            loss_sum = y.sum().item()
            logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ones = torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device)
        return ans_grad, ones, None
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
def with_loss(x, y, name):
    """Return x unchanged while arranging, via WithLoss, for y.sum() to be
    added to the total loss in backprop."""
    return WithLoss.apply(x, y, name)
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
class ScaleGradFunction(torch.autograd.Function):
    """Identity in forward; multiplies the gradient by `alpha` in backward."""

    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        return ctx.alpha * grad, None
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
def scale_grad(x: Tensor, alpha: float):
    """Return x unchanged, with its backward gradient scaled by `alpha`."""
    return ScaleGradFunction.apply(x, alpha)
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
class ScaleGrad(nn.Module):
    """Scales the gradient (not the activations) by the constant `alpha`
    during training; a no-op in eval mode and under scripting/tracing."""

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        eager = not (torch.jit.is_scripting() or torch.jit.is_tracing())
        if eager and self.training:
            return scale_grad(x, self.alpha)
        return x
|
| 1161 |
+
|
| 1162 |
+
|
| 1163 |
+
class LimitParamValue(torch.autograd.Function):
    """Identity in forward; in backward, flips the sign of any gradient
    component that would push x further outside [min, max]."""

    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float):
        ctx.save_for_backward(x)
        assert max >= min
        ctx.min = min
        ctx.max = max
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x,) = ctx.saved_tensors
        # Below ctx.min, force grads negative: gradient descent then
        # *increases* x.  Above ctx.max, force grads positive: gradient
        # descent then *decreases* x.  The two masks are disjoint because
        # they require opposite gradient signs.
        too_low = torch.logical_and(x_grad > 0, x < ctx.min)
        too_high = torch.logical_and(x_grad < 0, x > ctx.max)
        flip = torch.where(torch.logical_or(too_low, too_high), -1.0, 1.0)
        return x_grad * flip, None, None
|
| 1184 |
+
|
| 1185 |
+
|
| 1186 |
+
def limit_param_value(
    x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
):
    """Encourage (typically) an nn.Parameter to keep its elements mostly
    inside [min, max] by modifying gradients in backprop (LimitParamValue).

    The modification is applied only in training and, to save a little time,
    only with probability `prob` per call.
    """
    apply_now = training and random.random() < prob
    return LimitParamValue.apply(x, min, max) if apply_now else x
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
def _no_op(x: Tensor) -> Tensor:
|
| 1201 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1202 |
+
return x
|
| 1203 |
+
else:
|
| 1204 |
+
# a no-op function that will have a node in the autograd graph,
|
| 1205 |
+
# to avoid certain bugs relating to backward hooks
|
| 1206 |
+
return x.chunk(1, dim=-1)[0]
|
| 1207 |
+
|
| 1208 |
+
|
| 1209 |
+
class Identity(torch.nn.Module):
    """Identity module; its forward goes through _no_op() so the autograd
    graph gets a node (see _no_op)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return _no_op(x)
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)

    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) =  x * s'(x) + s(x).
                      =  x * s(x) * (1-s(x)) + s(x).
                     = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    The derivative is cached quantized to uint8 (with random dithering so the
    quantization is expectation-preserving), saving memory during training.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        # compute in float32 for accuracy.
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s

            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.044 so it's a lower bound.
            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.044
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        if x.dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # The same constants as used in the forward pass.  (FIX: backward
        # previously used floor = -0.043637 while forward quantized with
        # floor = -0.044, contradicting the "same constants" intent and
        # adding a small systematic bias to the dequantized derivative.)
        floor = -0.044
        ceil = 1.2

        # dequantize the stored derivative and apply the chain rule.
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
class DoubleSwish(torch.nn.Module):
    """Double-swish activation, x * sigmoid(x - 1): a close approximation to
    Swish(Swish(x)).  Uses the memory-efficient custom autograd function
    except under scripting/tracing."""

    def forward(self, x: Tensor) -> Tensor:
        scripted = torch.jit.is_scripting() or torch.jit.is_tracing()
        if not scripted:
            return DoubleSwishFunction.apply(x)
        return x * torch.sigmoid(x - 1.0)
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
|
| 1292 |
+
# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
class Dropout2(nn.Module):
    """Standard dropout whose rate `p` may be a schedule (FloatLike); the
    rate is re-evaluated with float() on every forward call."""

    def __init__(self, p: FloatLike):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        rate = float(self.p)
        return torch.nn.functional.dropout(x, p=rate, training=self.training)
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
class MulForDropout3(torch.autograd.Function):
    # Computes (x * y * alpha) where alpha is a float and y is a zero-or-one
    # mask that does not require grad.  Only the result is cached for
    # backward, which saves memory versus caching x and y separately.
    @staticmethod
    @custom_fwd
    def forward(ctx, x, y, alpha):
        assert not y.requires_grad
        out = x * y * alpha
        ctx.save_for_backward(out)
        ctx.alpha = alpha
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad):
        (out,) = ctx.saved_tensors
        # where out == 0 the mask was 0, so no gradient flows there.
        return ctx.alpha * ans_grad * (out != 0), None, None
|
| 1319 |
+
|
| 1320 |
+
|
| 1321 |
+
# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
|
| 1322 |
+
# and it lets you choose one dimension to share the dropout mask over
|
| 1323 |
+
# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
# and it lets you choose one dimension to share the dropout mask over
class Dropout3(nn.Module):
    """Dropout whose rate `p` may be a schedule (FloatLike) and whose mask is
    shared over one chosen dimension (e.g. time), which makes the mask
    smaller and the op more memory efficient."""

    def __init__(self, p: FloatLike, shared_dim: int):
        super().__init__()
        self.p = p
        self.shared_dim = shared_dim

    def forward(self, x: Tensor) -> Tensor:
        p = float(self.p)
        if p == 0 or not self.training:
            return _no_op(x)
        mask_shape = list(x.shape)
        # broadcast the mask along the shared dimension.
        mask_shape[self.shared_dim] = 1
        keep = torch.rand(*mask_shape, device=x.device) > p
        return MulForDropout3.apply(x, keep, 1.0 / (1 - p))
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
class SwooshLFunction(torch.autograd.Function):
    """
    swoosh_l(x) =  log(1 + exp(x-4)) - 0.08*x - 0.035

    Memory-efficient autograd: the derivative (sigmoid(x-4) + coeff, which
    lies in (coeff, 1 + coeff)) is quantized to uint8 with random dithering,
    so only one byte per element is cached for the backward pass.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        # compute in float32 for numerical accuracy.
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        coeff = -0.08

        # Let autograd compute the derivative for us, on a detached copy,
        # in float32 and outside autocast.
        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035

                if not requires_grad:
                    return y

                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                # derivative range: sigmoid(x-4) is in (0, 1), so the
                # derivative is in (coeff, 1 + coeff); + 0.005 gives headroom.
                floor = coeff
                ceil = 1.0 + coeff + 0.005

                # scale to [0, 255); the added uniform noise plus the flooring
                # uint8 cast makes the quantization expectation-preserving.
                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                    grad
                )
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0

                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
                # NOTE(review): x was converted to float32 and detached above,
                # so x.dtype == float16 can presumably never hold here and only
                # the is_autocast_enabled() test fires -- confirm intent.
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.get_autocast_gpu_dtype())
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.

        coeff = -0.08
        floor = coeff
        ceil = 1.0 + coeff + 0.005
        # dequantize the stored derivative and apply the chain rule.
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
|
| 1395 |
+
|
| 1396 |
+
|
| 1397 |
+
class SwooshL(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation: log(1 + exp(x-4)) - 0.08*x - 0.035.

        This build always uses the pure-PyTorch SwooshLFunction (memory
        efficient custom autograd).  The jit- and k2-specific branches that
        previously followed the return statement were unreachable dead code
        and have been removed.
        """
        return SwooshLFunction.apply(x)
|
| 1409 |
+
|
| 1410 |
+
|
| 1411 |
+
class SwooshLOnnx(torch.nn.Module):
    """ONNX-exportable Swoosh-L: log(1 + exp(x-4)) - 0.08*x - 0.035, computed
    with the ONNX-friendly logaddexp_onnx helper and no custom autograd."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        shifted = x - 4.0
        return logaddexp_onnx(zero, shifted) - 0.08 * x - 0.035
|
| 1416 |
+
|
| 1417 |
+
|
| 1418 |
+
class SwooshRFunction(torch.autograd.Function):
    """
    swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687

    derivatives are between -0.08 and 0.92.

    Memory-efficient autograd: the derivative is quantized to uint8 with
    random dithering, so only one byte per element is cached for backward.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad

        # compute in float32 for numerical accuracy.
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        # Let autograd compute the derivative for us, on a detached copy,
        # in float32 and outside autocast.
        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

                if not requires_grad:
                    return y
                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                # derivative = sigmoid(x-1) - 0.08, in (-0.08, 0.92);
                # 0.925 gives a little headroom at the top.
                floor = -0.08
                ceil = 0.925

                # scale to [0, 255); the added uniform noise plus the flooring
                # uint8 cast makes the quantization expectation-preserving.
                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                    grad
                )
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0

                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
                # NOTE(review): x was converted to float32 and detached above,
                # so x.dtype == float16 can presumably never hold here and only
                # the is_autocast_enabled() test fires -- confirm intent.
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.get_autocast_gpu_dtype())
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.08
        ceil = 0.925
        # dequantize the stored derivative and apply the chain rule.
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
|
| 1470 |
+
|
| 1471 |
+
|
| 1472 |
+
class SwooshR(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation: log(1 + exp(x-1)) - 0.08*x - 0.313261687.

        This build always uses the pure-PyTorch SwooshRFunction (memory
        efficient custom autograd).  The `if True:` block and k2-specific
        branches that previously followed the return statement were
        unreachable dead code and have been removed.
        """
        return SwooshRFunction.apply(x)
|
| 1485 |
+
|
| 1486 |
+
|
| 1487 |
+
class SwooshROnnx(torch.nn.Module):
    """ONNX-exportable Swoosh-R: log(1 + exp(x-1)) - 0.08*x - 0.313261687,
    computed with the ONNX-friendly logaddexp_onnx helper."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        shifted = x - 1.0
        return logaddexp_onnx(zero, shifted) - 0.08 * x - 0.313261687
|
| 1492 |
+
|
| 1493 |
+
|
| 1494 |
+
# simple version of SwooshL that does not redefine the backprop, used in
|
| 1495 |
+
# ActivationDropoutAndLinearFunction.
|
| 1496 |
+
# simple version of SwooshL that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshLForward(x: Tensor):
    """Plain Swoosh-L: log(1 + exp(x-4)) - 0.08*x - 0.035."""
    shifted = x - 4.0
    softplus = (1.0 + shifted.exp()).log().to(x.dtype)
    # exp() overflows to inf for large inputs; there log(1 + e^t) == t to
    # within float precision, so substitute the shifted input.
    softplus = torch.where(softplus == float("inf"), shifted, softplus)
    return softplus - 0.08 * x - 0.035
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
# simple version of SwooshR that does not redefine the backprop, used in
|
| 1504 |
+
# ActivationDropoutAndLinearFunction.
|
| 1505 |
+
# simple version of SwooshR that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshRForward(x: Tensor):
    """Plain Swoosh-R: log(1 + exp(x-1)) - 0.08*x - 0.313261687."""
    shifted = x - 1.0
    softplus = (1.0 + shifted.exp()).log().to(x.dtype)
    # exp() overflows to inf for large inputs; there log(1 + e^t) == t to
    # within float precision, so substitute the shifted input.
    softplus = torch.where(softplus == float("inf"), shifted, softplus)
    return softplus - 0.08 * x - 0.313261687
|
| 1510 |
+
|
| 1511 |
+
|
| 1512 |
+
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
    """Fused activation (SwooshL/SwooshR via k2) + dropout + linear.

    Caches only the pre-activation input (plus the dropout mask) and
    recomputes the activation in backward, saving memory versus composing
    the three modules."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        activation: str,
        dropout_p: float,
        dropout_shared_dim: Optional[int],
    ):
        if dropout_p != 0.0:
            dropout_shape = list(x.shape)
            # share the mask along dropout_shared_dim so it stays small;
            if dropout_shared_dim is not None:
                dropout_shape[dropout_shared_dim] = 1
            # else it won't be very memory efficient.
            # mask values are 0 or 1/(1-p), folding the rescale into the mask.
            dropout_mask = (1.0 / (1.0 - dropout_p)) * (
                torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
            )
        else:
            dropout_mask = None

        ctx.save_for_backward(x, weight, bias, dropout_mask)

        ctx.activation = activation

        forward_activation_dict = {
            "SwooshL": k2.swoosh_l_forward,
            "SwooshR": k2.swoosh_r_forward,
        }
        # it will raise a KeyError if this fails. This will be an error. We let it
        # propagate to the user.
        activation_func = forward_activation_dict[activation]
        x = activation_func(x)
        if dropout_mask is not None:
            x = x * dropout_mask
        x = torch.nn.functional.linear(x, weight, bias)
        return x

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad: Tensor):
        saved = ctx.saved_tensors
        (x, weight, bias, dropout_mask) = saved

        forward_and_deriv_activation_dict = {
            "SwooshL": k2.swoosh_l_forward_and_deriv,
            "SwooshR": k2.swoosh_r_forward_and_deriv,
        }
        # the following lines a KeyError if the activation is unrecognized.
        # This will be an error. We let it propagate to the user.
        func = forward_and_deriv_activation_dict[ctx.activation]

        # recompute the activation output and its derivative from the cached
        # pre-activation input -- this is where the memory saving comes from.
        y, func_deriv = func(x)
        if dropout_mask is not None:
            y = y * dropout_mask
        # now compute derivative of y w.r.t. weight and bias..
        # y: (..., in_channels), ans_grad: (..., out_channels),
        (out_channels, in_channels) = weight.shape

        in_channels = y.shape[-1]
        g = ans_grad.reshape(-1, out_channels)
        weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
        y_deriv = torch.matmul(ans_grad, weight)
        bias_deriv = None if bias is None else g.sum(dim=0)
        x_deriv = y_deriv * func_deriv
        if dropout_mask is not None:
            # order versus func_deriv does not matter
            x_deriv = x_deriv * dropout_mask

        return x_deriv, weight_deriv, bias_deriv, None, None, None
|
| 1584 |
+
|
| 1585 |
+
|
| 1586 |
+
class ActivationDropoutAndLinear(torch.nn.Module):
    """
    This merges an activation function followed by dropout and then a nn.Linear module;
    it does so in a memory efficient way so that it only stores the input to the whole
    module.  If activation == SwooshL and dropout_shared_dim != None, this will be
    equivalent to:
        nn.Sequential(SwooshL(),
                      Dropout3(dropout_p, shared_dim=dropout_shared_dim),
                      ScaledLinear(in_channels, out_channels, bias=bias,
                                   initial_scale=initial_scale))
    If dropout_shared_dim is None, the dropout would be equivalent to
    Dropout2(dropout_p).  Note: Dropout3 will be more memory efficient as the dropout
    mask is smaller.

    NOTE: in this inference-oriented build, forward() always takes the simple
    eager path (activation followed by linear); dropout_p is NOT applied, and
    the fused k2-based ActivationDropoutAndLinearFunction is not used.

    Args:
       in_channels: number of input channels, e.g. 256
       out_channels: number of output channels, e.g. 256
       bias: if true, have a bias
       activation: the activation function, for now just support SwooshL.
       dropout_p: the dropout probability or schedule (happens after nonlinearity).
       dropout_shared_dim: the dimension, if any, across which the dropout mask is
            shared (e.g. the time dimension).  If None, this may be less memory
            efficient if there are modules before this one that cache the input
            for their backprop (e.g. Balancer or Whiten).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
        activation: str = "SwooshL",
        dropout_p: FloatLike = 0.0,
        dropout_shared_dim: Optional[int] = -1,
        initial_scale: float = 1.0,
    ):
        super().__init__()
        # create a temporary module of nn.Linear that we'll steal the
        # weights and bias from
        l = ScaledLinear(
            in_channels, out_channels, bias=bias, initial_scale=initial_scale
        )

        self.weight = l.weight
        # register_parameter properly handles making it a parameter when l.bias
        # is None. I think there is some reason for doing it this way rather
        # than just setting it to None but I don't know what it is, maybe
        # something to do with exporting the module..
        self.register_parameter("bias", l.bias)

        self.activation = activation
        self.dropout_p = dropout_p
        self.dropout_shared_dim = dropout_shared_dim

    def forward(self, x: Tensor):
        # Always use the simple eager path.  (The previous code wrapped this
        # branch in `if True:`, which left the fused
        # ActivationDropoutAndLinearFunction call after it unreachable; that
        # dead code has been removed.)  Note: dropout is not applied here.
        if self.activation == "SwooshL":
            x = SwooshLForward(x)
        elif self.activation == "SwooshR":
            x = SwooshRForward(x)
        else:
            assert False, self.activation
        return torch.nn.functional.linear(x, self.weight, self.bias)
|
| 1659 |
+
|
| 1660 |
+
|
| 1661 |
+
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    """Adapt the last dimension of x to `num_channels`: truncate if it is
    too large, zero-pad if it is too small."""
    cur_channels = x.shape[-1]
    if num_channels <= cur_channels:
        return x[..., :num_channels]
    pad_shape = list(x.shape)
    pad_shape[-1] = num_channels - cur_channels
    padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat((x, padding), dim=-1)
|
| 1669 |
+
|
| 1670 |
+
|
| 1671 |
+
def _test_whiten():
    """Self-test for Whiten: gradients should be left (nearly) unchanged for
    almost-white input, and modified when one direction strongly dominates."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        feats = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        # Inject a rank-1 component whose strength is `proportion`.
        feats += proportion * direction * coeffs

        feats.requires_grad = True

        whitener = Whiten(
            1, 5.0, prob=1.0, grad_scale=0.1  # num_groups # whitening_limit,
        )  # grad_scale

        # Repeated forward passes let the module's internal statistics settle.
        for _ in range(4):
            out = whitener(feats)

        out_grad = torch.randn_like(feats)
        out.backward(gradient=out_grad)

        if proportion < 0.2:
            # Nearly-white input: gradient should pass through unchanged.
            assert torch.allclose(feats.grad, out_grad)
        elif proportion > 1.0:
            # Strongly non-white input: whitening term must alter the gradient.
            assert not torch.allclose(feats.grad, out_grad)
|
| 1695 |
+
|
| 1696 |
+
|
| 1697 |
+
def _test_balancer_sign():
    """Exercise Balancer's sign-proportion constraints on random +/-1 data;
    prints input, incoming grad, and modified grad for manual inspection."""
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i contains +/-1 samples where P(+1) == probs[i].
    signs = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    signs = signs.detach()
    signs.requires_grad = True
    balancer = Balancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        min_abs=0.0,
        prob=1.0,
    )

    incoming_grad = torch.sign(torch.randn(probs.numel(), N))

    out = balancer(signs)
    out.backward(gradient=incoming_grad)
    print("_test_balancer_sign: x = ", signs)
    print("_test_balancer_sign: y grad = ", incoming_grad)
    print("_test_balancer_sign: x grad = ", signs.grad)
|
| 1719 |
+
|
| 1720 |
+
|
| 1721 |
+
def _test_balancer_magnitude():
    """Exercise Balancer's magnitude constraints (min_abs/max_abs) on data
    whose per-channel magnitude sweeps from 0 to 1; prints grads for
    manual inspection."""
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i holds random signs scaled to magnitude magnitudes[i].
    data = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    data = data.detach()
    data.requires_grad = True
    balancer = Balancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        min_abs=0.2,
        max_abs=0.7,
        prob=1.0,
    )

    incoming_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    out = balancer(data)
    out.backward(gradient=incoming_grad)
    print("_test_balancer_magnitude: x = ", data)
    print("_test_balancer_magnitude: y grad = ", incoming_grad)
    print("_test_balancer_magnitude: x grad = ", data.grad)
|
| 1744 |
+
|
| 1745 |
+
|
| 1746 |
+
def _test_double_swish_deriv():
    """Numerically check DoubleSwish gradients, then run one large forward
    pass as a module self-test."""
    inp = torch.randn(10, 12, dtype=torch.double) * 3.0
    inp.requires_grad = True
    act = DoubleSwish()

    # The derivative is stored quantized in one byte over [-0.043637, 1.2],
    # so allow one quantization step of error.
    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(act, inp, atol=tol)

    # for self-test.
    inp = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    inp.requires_grad = True
    _ = act(inp)
|
| 1758 |
+
|
| 1759 |
+
|
| 1760 |
+
def _test_swooshl_deriv():
    """Numerically check SwooshL gradients, then run one large forward pass
    as a module self-test."""
    inp = torch.randn(10, 12, dtype=torch.double) * 3.0
    inp.requires_grad = True
    act = SwooshL()

    # Derivative is quantized to one byte; allow one step of error.
    tol = 1.0 / 255.0
    torch.autograd.gradcheck(act, inp, atol=tol, eps=0.01)

    # for self-test.
    inp = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    inp.requires_grad = True
    _ = act(inp)
|
| 1772 |
+
|
| 1773 |
+
|
| 1774 |
+
def _test_swooshr_deriv():
    """Numerically check SwooshR gradients, then run one large forward pass
    as a module self-test."""
    inp = torch.randn(10, 12, dtype=torch.double) * 3.0
    inp.requires_grad = True
    act = SwooshR()

    # Derivative is quantized to one byte; allow one step of error.
    tol = 1.0 / 255.0
    torch.autograd.gradcheck(act, inp, atol=tol, eps=0.01)

    # for self-test.
    inp = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    inp.requires_grad = True
    _ = act(inp)
|
| 1786 |
+
|
| 1787 |
+
|
| 1788 |
+
def _test_softmax():
    """Check that the custom `softmax` produces the same gradients as
    torch's built-in softmax."""
    ref = torch.randn(2, 10, dtype=torch.float64)
    custom = ref.clone()
    ref.requires_grad = True
    custom.requires_grad = True
    # Backprop the sum of the first column through both implementations.
    ref.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", ref.grad)
    softmax(custom, dim=1)[:, 0].sum().backward()
    print("b grad = ", custom.grad)
    assert torch.allclose(ref.grad, custom.grad)
|
| 1798 |
+
|
| 1799 |
+
|
| 1800 |
+
def _test_piecewise_linear():
    """Sanity checks for PiecewiseLinear: constant case, interpolation with
    clamping at the ends, and the max / min / + combination operators."""
    p = PiecewiseLinear((0, 10.0))
    # A single knot means a constant function everywhere.
    for x in [-100, 0, 100]:
        assert p(x) == 10.0
    p = PiecewiseLinear((0, 10.0), (1, 0.0))
    # Linear between knots; clamped to the end values outside [0, 1].
    for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
        print("x, y = ", x, y)
        assert p(x) == y, (x, p(x), y)

    q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
    x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
    # Each combined function must agree pointwise with combining the values.
    for combined, expected in (
        (p.max(q), lambda x: max(p(x), q(x))),
        (p.min(q), lambda x: min(p(x), q(x))),
        (p + q, lambda x: p(x) + q(x)),
    ):
        for x in x_vals:
            assert abs(expected(x) - combined(x)) < 0.001
|
| 1826 |
+
|
| 1827 |
+
|
| 1828 |
+
def _test_activation_dropout_and_linear():
    """Check that ActivationDropoutAndLinear matches the equivalent
    nn.Sequential(activation, Dropout3, ScaledLinear) pipeline: same forward
    output and (approximately) same gradients for weight, bias and input.
    """
    in_channels = 20
    out_channels = 30

    for bias in [True, False]:
        # actually we don't test for dropout_p != 0.0 because forward functions will give
        # different answers. This is because we are using the k2 implementation of
        # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn()
        # internally, messing up the random state.
        for dropout_p in [0.0]:
            for activation in ["SwooshL", "SwooshR"]:
                # Reference pipeline built from separate modules.
                m1 = nn.Sequential(
                    SwooshL() if activation == "SwooshL" else SwooshR(),
                    Dropout3(p=dropout_p, shared_dim=-1),
                    ScaledLinear(
                        in_channels, out_channels, bias=bias, initial_scale=0.5
                    ),
                )
                # Fused module under test, configured identically.
                m2 = ActivationDropoutAndLinear(
                    in_channels,
                    out_channels,
                    bias=bias,
                    initial_scale=0.5,
                    activation=activation,
                    dropout_p=dropout_p,
                )
                # Copy weights so both modules compute the same function.
                with torch.no_grad():
                    m2.weight[:] = m1[2].weight
                    if bias:
                        m2.bias[:] = m1[2].bias
                # make sure forward gives same result.
                x1 = torch.randn(10, in_channels)
                x1.requires_grad = True

                # TEMP.
                assert torch.allclose(
                    SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
                )

                x2 = x1.clone().detach()
                x2.requires_grad = True
                seed = 10
                # Reset the RNG before each forward so any internal
                # randomness is identical for m1 and m2.
                torch.manual_seed(seed)
                y1 = m1(x1)
                y_grad = torch.randn_like(y1)
                y1.backward(gradient=y_grad)
                torch.manual_seed(seed)
                y2 = m2(x2)
                y2.backward(gradient=y_grad)

                print(
                    f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
                )
                print("y1 = ", y1)
                print("y2 = ", y2)
                assert torch.allclose(y1, y2, atol=0.02)
                assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
                if bias:
                    assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
                print("x1.grad = ", x1.grad)
                print("x2.grad = ", x2.grad)

                def isclose(a, b):
                    # return true if cosine similarity is > 0.9.
                    return (a * b).sum() > 0.9 * (
                        (a**2).sum() * (b**2).sum()
                    ).sqrt()

                # the SwooshL() implementation has a noisy gradient due to 1-byte
                # storage of it.
                assert isclose(x1.grad, x2.grad)
|
| 1899 |
+
|
| 1900 |
+
|
| 1901 |
+
if __name__ == "__main__":
    # Run all self-tests. Single-threaded execution keeps results
    # deterministic and avoids thread-pool startup overhead.
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_piecewise_linear()
    _test_softmax()
    _test_whiten()
    _test_balancer_sign()
    _test_balancer_magnitude()
    _test_double_swish_deriv()
    _test_swooshr_deriv()
    _test_swooshl_deriv()
    _test_activation_dropout_and_linear()
|
subsampling.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
|
| 3 |
+
# Zengwei Yao)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
from typing import Tuple
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import Tensor, nn
|
| 24 |
+
from scaling import (
|
| 25 |
+
Balancer,
|
| 26 |
+
BiasNorm,
|
| 27 |
+
Dropout3,
|
| 28 |
+
FloatLike,
|
| 29 |
+
Optional,
|
| 30 |
+
ScaledConv2d,
|
| 31 |
+
ScaleGrad,
|
| 32 |
+
ScheduledFloat,
|
| 33 |
+
SwooshL,
|
| 34 |
+
SwooshR,
|
| 35 |
+
Whiten,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf

    Depthwise conv -> 1x1 expansion -> SwooshL -> 1x1 projection, wrapped in a
    residual connection, with per-sample layer-drop during training.
    Input/output layout is (N, C, H, W) = (batch, channels, frames, freqs).
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        # "Same" padding for the depthwise conv, per (time, freq) axis.
        self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            # Layer-drop probability is scheduled over training steps:
            # 0.2 at step 0 decaying to 0.015 by step 20000.
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

        # 1x1 expansion to hidden_channels.
        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        # 1x1 projection back to `channels`; small initial_scale keeps the
        # residual branch close to zero at the start of training.
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor) -> Tensor:
        # In scripting/tracing or eval mode, no layer-drop: plain forward.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            # Per-sample boolean mask: True keeps the residual branch,
            # False drops this layer for that sample.
            mask = (
                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                > layerdrop_rate
            )
        else:
            mask = None
        # turns out this caching idea does not work with --world-size > 1
        # return caching_eval(self.forward_internal, x, mask)
        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
    ) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            # Zero out the residual branch for dropped samples.
            x = x * layer_skip_mask

        x = bypass + x
        x = self.out_balancer(x)

        # Whitening is a training-time regularizer; skip it when no grad is
        # needed (inference), since it only affects the backward pass.
        if x.requires_grad:
            x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
            x = self.out_whiten(x)
            x = x.transpose(1, 3)  # (N, C, H, W)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Streaming variant: applies the depthwise conv causally in time, using
        cached left context instead of symmetric padding on the time axis.

        Args:
          x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
          cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

        Returns:
          - The returned value has the same shape as x.
          - Updated cached_left_pad.
        """
        padding = self.padding

        # The length without right padding for depth-wise conv
        T = x.size(2) - padding[0]

        # Residual uses only the frames that will have full conv context.
        bypass = x[:, :, :T, :]

        # Pad left side
        assert cached_left_pad.size(2) == padding[0], (
            cached_left_pad.size(2),
            padding[0],
        )
        x = torch.cat([cached_left_pad, x], dim=2)
        # Update cached left padding: the last padding[0] frames before the
        # right-padding region become the left context for the next chunk.
        cached_left_pad = x[:, :, T : padding[0] + T, :]

        # depthwise_conv applied manually with no time padding (left context
        # comes from the cache); frequency axis keeps "same" padding.
        x = torch.nn.functional.conv2d(
            x,
            weight=self.depthwise_conv.weight,
            bias=self.depthwise_conv.bias,
            padding=(0, padding[1]),
            groups=self.depthwise_conv.groups,
        )
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        # NOTE(review): unlike forward_internal, out_balancer / out_whiten
        # are not applied here — both are training-time-only modifiers.
        x = bypass + x
        return x, cached_left_pad
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >=7, in_channels >=7
          out_channels
            Output dim. The output shape is (N, (T-3)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer1_channels:
            Number of channels in layer2
          bottleneck:
            bottleneck dimension for 1d squeeze-excite
        """
        assert in_channels >= 7
        super().__init__()

        # The ScaleGrad module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # exceeding the range of fp16 when using automatic mixed precision (amp)
        # training. (The second one is necessary to stop its bias from getting
        # a too-large gradient).

        # Three conv stages: freq is halved twice (stride-2 then (1,2));
        # time is reduced by the stride-2 conv plus the unpadded kernels,
        # giving overall T -> (T-7)//2.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        # (in_channels-3)//4
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
        # use a larger than normal grad_scale on this whitening module; there is
        # only one such module, so there is not a concern about adding together
        # many copies of this extra gradient term.
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
        # gradients.
        # NOTE(review): the 0.1 scaling described above is not actually applied
        # in this code path — the comment appears stale; confirm against the
        # training recipe before relying on it.
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
        b, c, t, f = x.size()

        # Flatten channel and frequency dims into one feature dim per frame.
        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2
            assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Streaming variant: routes the ConvNeXt layer through its cached-left-pad
        streaming path, so output lengths lose 3 extra frames of right context.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

        # T' = (T-7)//2
        x = self.conv(x)

        # T' = (T-7)//2-3
        x, cached_left_pad = self.convnext.streaming_forward(
            x, cached_left_pad=cached_left_pad
        )

        # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, T', out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, T', odim)
        # NOTE(review): unlike forward(), out_whiten and dropout are skipped
        # here (both are training-time-only), only out_norm is applied.
        x = self.out_norm(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert self.convnext.padding[0] == 3
            # The ConvNeXt module needs 3 frames of right padding after subsampling
            x_lens = (x_lens - 7) // 2 - 3
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # The ConvNeXt module needs 3 frames of right padding after subsampling
                assert self.convnext.padding[0] == 3
                x_lens = (x_lens - 7) // 2 - 3

        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.padding[0]
        freq = self.out_width
        channels = self.layer3_channels
        # Zero-initialized left context for the first streaming chunk.
        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
            device
        )

        return cached_embed_left_pad
|
utilities.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class ZipformerConfig:
    """Attribute-style container for Zipformer model hyper-parameters.

    All parameters live in the internal ``_config`` dict; attribute access,
    assignment and deletion are forwarded to that dict, so callers can write
    e.g. ``cfg.feature_dim`` or ``cfg.causal = False``.
    """

    def __init__(self):
        # All hyper-parameters are stored in this single dict; the dunder
        # methods below expose its entries as attributes.
        self._config = {
            "feature_dim": 128,
            "pos_dim": 48,
            "output_downsampling_factor": 2,
            "downsampling_factor": "1,2,4,8,4,2",
            "num_encoder_layers": "2,2,3,4,3,2",
            "feedforward_dim": "512,768,1024,1536,1024,768",
            "encoder_dim": "192,256,448,768,448,192",
            "encoder_unmasked_dim": "192,192,256,256,256,192",
            "cnn_module_kernel": "31,31,15,15,15,31",
            "num_heads": "4,4,4,8,4,4",
            "causal": True,
        }

    def __getattr__(self, key):
        """Fall back to the _config dict for unknown attributes.

        Raises:
          AttributeError: if the key is not a configured parameter.
        """
        # Bug fix: guard against infinite recursion. __getattr__ is invoked
        # when normal lookup fails; if _config itself is missing (e.g. during
        # copy/unpickling, or ZipformerConfig.__new__), the self._config
        # access below would re-enter __getattr__ forever.
        if key == "_config":
            raise AttributeError(key)
        if key in self._config:
            return self._config[key]
        raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        # _config itself is a real instance attribute; everything else is
        # stored inside the dict.
        if key == "_config":
            super().__setattr__(key, value)
        else:
            self._config[key] = value

    def __delattr__(self, key):
        if key in self._config:
            del self._config[key]
        else:
            raise AttributeError(f"'ZipformerConfig' object has no attribute '{key}'")

    def to_dict(self):
        """Return a shallow copy of the configuration as a plain dict."""
        return dict(self._config)

    def __repr__(self):
        return f"ZipformerConfig({self._config})"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    # Pass actual booleans straight through.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # Mask width: the requested max_len, or the longest sequence if larger.
    num_cols = max(max_len, lengths.max())
    num_rows = lengths.size(0)
    # Column indices broadcast against per-row lengths: position >= length
    # marks padding.
    col_ids = torch.arange(0, num_cols, device=lengths.device)
    position_grid = col_ids.unsqueeze(0).expand(num_rows, num_cols)

    return position_grid >= lengths.unsqueeze(-1)
|
zipformer.py
ADDED
|
@@ -0,0 +1,2469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
|
| 3 |
+
# Zengwei Yao)
|
| 4 |
+
#
|
| 5 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
import copy
|
| 20 |
+
import math
|
| 21 |
+
import warnings
|
| 22 |
+
from typing import List, Optional, Tuple, Union
|
| 23 |
+
import logging
|
| 24 |
+
import torch
|
| 25 |
+
import random
|
| 26 |
+
from scaling import (
|
| 27 |
+
Balancer,
|
| 28 |
+
BiasNorm,
|
| 29 |
+
Dropout2,
|
| 30 |
+
ChunkCausalDepthwiseConv1d,
|
| 31 |
+
ActivationDropoutAndLinear,
|
| 32 |
+
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
|
| 33 |
+
Whiten,
|
| 34 |
+
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
|
| 35 |
+
penalize_abs_values_gt,
|
| 36 |
+
softmax,
|
| 37 |
+
ScheduledFloat,
|
| 38 |
+
FloatLike,
|
| 39 |
+
limit_param_value,
|
| 40 |
+
convert_num_channels,
|
| 41 |
+
)
|
| 42 |
+
from torch import Tensor, nn
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class EncoderInterface(nn.Module):
    """Abstract base class for speech encoders; subclasses override forward()."""

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder.

        Args:
          x:
            A tensor of shape (batch_size, input_seq_len, num_features)
            containing the input features.
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames
            in `x` before padding.
        Returns:
          A tuple of two tensors:
            - encoder_out, shape (batch_size, out_seq_len, output_dim),
              containing unnormalized probabilities, i.e. the output of a
              linear layer.
            - encoder_out_lens, shape (batch_size,), the number of frames
              in `encoder_out` before padding.
        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement it in a subclass")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Zipformer2(EncoderInterface):
|
| 69 |
+
"""
|
| 70 |
+
Args:
|
| 71 |
+
|
| 72 |
+
Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
|
| 73 |
+
as downsampling_factor if they are single ints or one-element tuples. The length of
|
| 74 |
+
downsampling_factor defines the number of stacks.
|
| 75 |
+
|
| 76 |
+
output_downsampling_factor (int): how much to downsample at the output. Note:
|
| 77 |
+
we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
|
| 78 |
+
You should probably leave this at 2.
|
| 79 |
+
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
|
| 80 |
+
Note: this is in addition to the downsampling factor of 2 that is applied in
|
| 81 |
+
the frontend (self.encoder_embed).
|
| 82 |
+
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
|
| 83 |
+
encoder stack.
|
| 84 |
+
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
| 85 |
+
encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
|
| 86 |
+
the encoder stacks for purposes of per-frame dropout (recommend 256 for
|
| 87 |
+
now).
|
| 88 |
+
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
| 89 |
+
head: per stack, if a tuple..
|
| 90 |
+
pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
|
| 91 |
+
attention head
|
| 92 |
+
value_head_dim (int or Tuple[int]): dimension of value in each attention head
|
| 93 |
+
num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
|
| 94 |
+
Must be at least 4.
|
| 95 |
+
feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
|
| 96 |
+
cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
|
| 97 |
+
|
| 98 |
+
pos_dim (int): the dimension of each positional-encoding vector prior to projection,
|
| 99 |
+
e.g. 128.
|
| 100 |
+
|
| 101 |
+
dropout (float): dropout rate
|
| 102 |
+
warmup_batches (float): number of batches to warm up over; this controls
|
| 103 |
+
dropout of encoder layers.
|
| 104 |
+
causal (bool): if True, support chunkwise causal convolution. This should
|
| 105 |
+
not hurt WER as no modeling power is lost, but the convolution modules will be
|
| 106 |
+
slightly slower and use more memory. Enables use of the chunk_size and
|
| 107 |
+
left_context_chunks options in forward(), which simulates streaming
|
| 108 |
+
decoding.
|
| 109 |
+
chunk_size: (list of int): only set this to other than [-1] if causal;
|
| 110 |
+
the chunk size will be randomly chosen from this list. -1 means no chunking.
|
| 111 |
+
left_context_frames: (list of int): determines the number of left-
|
| 112 |
+
context chunks for causal training; will be rounded to a number of
|
| 113 |
+
chunks. Must not be less than cnn_module_kernel (after factoring in
|
| 114 |
+
rounding and downsampling); an error will be thrown if this is violated.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(
    self,
    output_downsampling_factor: int = 2,
    downsampling_factor: Tuple[int] = (2, 4),
    encoder_dim: Union[int, Tuple[int]] = 384,
    num_encoder_layers: Union[int, Tuple[int]] = 4,
    encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
    query_head_dim: Union[int, Tuple[int]] = 24,
    pos_head_dim: Union[int, Tuple[int]] = 4,
    value_head_dim: Union[int, Tuple[int]] = 12,
    num_heads: Union[int, Tuple[int]] = 8,
    feedforward_dim: Union[int, Tuple[int]] = 1536,
    cnn_module_kernel: Union[int, Tuple[int]] = 31,
    pos_dim: int = 192,
    dropout: FloatLike = None,  # see code below for default
    warmup_batches: float = 4000.0,
    causal: bool = False,
    chunk_size: Tuple[int] = [-1],
    left_context_frames: Tuple[int] = [-1],
) -> None:
    """Construct the encoder stacks.  See the class docstring for the
    meaning of each argument; scalar (or 1-tuple) arguments are broadcast
    to one value per stack, where the number of stacks is
    len(downsampling_factor)."""
    super(Zipformer2, self).__init__()

    # Default dropout is scheduled: starts at 0.3 and decays to 0.1
    # by batch 20000.
    if dropout is None:
        dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

    def _to_tuple(x):
        """Converts a single int or a 1-tuple of an int to a tuple with the same length
        as downsampling_factor"""
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * len(downsampling_factor)
        else:
            assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
        return x

    # Broadcast every per-stack hyperparameter to a tuple of length
    # num_encoders (= len(downsampling_factor)).
    self.output_downsampling_factor = output_downsampling_factor  # int
    self.downsampling_factor = downsampling_factor  # tuple
    self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
    self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
        encoder_unmasked_dim
    )  # tuple
    num_encoder_layers = _to_tuple(num_encoder_layers)
    self.num_encoder_layers = num_encoder_layers
    self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
    self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
    pos_head_dim = _to_tuple(pos_head_dim)
    self.num_heads = num_heads = _to_tuple(num_heads)
    feedforward_dim = _to_tuple(feedforward_dim)
    self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

    self.causal = causal
    self.chunk_size = chunk_size
    self.left_context_frames = left_context_frames

    # The "unmasked" (always-kept) dim cannot exceed the stack's full dim.
    for u, d in zip(encoder_unmasked_dim, encoder_dim):
        assert u <= d

    # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
    encoders = []

    num_encoders = len(downsampling_factor)
    for i in range(num_encoders):
        encoder_layer = Zipformer2EncoderLayer(
            embed_dim=encoder_dim[i],
            pos_dim=pos_dim,
            num_heads=num_heads[i],
            query_head_dim=query_head_dim[i],
            pos_head_dim=pos_head_dim[i],
            value_head_dim=value_head_dim[i],
            feedforward_dim=feedforward_dim[i],
            dropout=dropout,
            cnn_module_kernel=cnn_module_kernel[i],
            causal=causal,
        )

        # For the segment of the warmup period, we let the Conv2dSubsampling
        # layer learn something.  Then we start to warm up the other encoders;
        # the warmup windows are staggered so stack i warms up in the
        # (i+1)-th slice of warmup_batches.
        encoder = Zipformer2Encoder(
            encoder_layer,
            num_encoder_layers[i],
            pos_dim=pos_dim,
            dropout=dropout,
            warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
            warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
            final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
        )

        # Wrap with temporal downsampling/upsampling if this stack runs at
        # a reduced frame rate.
        if downsampling_factor[i] != 1:
            encoder = DownsampledZipformer2Encoder(
                encoder,
                dim=encoder_dim[i],
                downsample=downsampling_factor[i],
                dropout=dropout,
            )

        encoders.append(encoder)

    self.encoders = nn.ModuleList(encoders)

    # Final output downsampling (forward() asserts this factor is 2 when
    # it is used).
    if output_downsampling_factor >= 2:
        self.downsample_output = SimpleDownsample(
            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
        )
    else:
        self.downsample_output = None
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
    """Return per-encoder feature-dropout masks.

    In eval mode this is just ``[1.0] * num_encoders``.  In training mode
    it returns one randomized mask per encoder stack: on a random subset
    of sequences the mask zeroes out all encoder dims above the stack's
    "unmasked" dim, so those sequences effectively run with a smaller
    encoder dim.  The masks are generated here (rather than per stack) so
    they agree all the way up the encoder stack.

    Args:
      x: the embeddings (needed only for shape, dtype and device), of
         shape (1, batch_size, encoder_dims0).
    """
    n_stacks = len(self.encoder_dim)
    if not self.training:
        return [1.0] * n_stacks

    (_num_frames, batch_size, embed_dim0) = x.shape

    assert self.encoder_dim[0] == embed_dim0, (
        self.encoder_dim[0],
        embed_dim0,
    )

    drop_prob = 0.125

    # keep_a: (1, batch_size, 1); each sequence survives with prob 1 - drop_prob.
    keep_a = (
        torch.rand(1, batch_size, 1, device=x.device) > drop_prob
    ).to(x.dtype)

    # keep_b masks additional sequences, about twice the number overall.
    keep_b = torch.logical_and(
        keep_a,
        (
            torch.rand(1, batch_size, 1, device=x.device)
            > drop_prob
        ).to(x.dtype),
    )

    # keep: (1, batch_size, 2) -- column 0 for the middle band, 1 for the top.
    keep = torch.cat((keep_a, keep_b), dim=-1)

    masks = []
    for i in range(n_stacks):
        dim = self.encoder_dim[i]
        stack_mask = torch.ones(
            1, batch_size, dim, dtype=x.dtype, device=x.device
        )
        lo = self.encoder_unmasked_dim[i]
        hi = lo + (dim - lo) // 2

        # Dims below `lo` are never masked; [lo, hi) uses keep_a, [hi, dim)
        # uses the more aggressive keep_b.
        stack_mask[:, :, lo:hi] *= keep[..., 0:1]
        stack_mask[:, :, hi:] *= keep[..., 1:2]

        masks.append(stack_mask)

    return masks
|
| 286 |
+
|
| 287 |
+
def get_chunk_info(self) -> Tuple[int, int]:
    """Pick (chunk_size, left_context_chunks) for this forward pass.

    Non-causal models always get (-1, -1), i.e. no chunking.  Otherwise
    both values are sampled from the configured lists; under torch.jit
    scripting/tracing the lists must contain exactly one element, which
    is used deterministically.
    """
    if not self.causal:
        return -1, -1

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        assert len(self.chunk_size) == 1, self.chunk_size
        chosen_chunk = self.chunk_size[0]
    else:
        chosen_chunk = random.choice(self.chunk_size)

    if chosen_chunk == -1:
        # No chunking at all -> left context is irrelevant.
        return chosen_chunk, -1

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        assert len(self.left_context_frames) == 1, self.left_context_frames
        left_frames = self.left_context_frames[0]
    else:
        left_frames = random.choice(self.left_context_frames)
    # Note: in Python, -1 // n == -1 for n > 0
    ctx_chunks = left_frames // chosen_chunk
    if ctx_chunks == 0:
        ctx_chunks = 1
    return chosen_chunk, ctx_chunks
|
| 314 |
+
|
| 315 |
+
def forward(
    self,
    x: Tensor,
    x_lens: Tensor,
    src_key_padding_mask: Optional[Tensor] = None,
    return_middle_out: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Args:
      x:
        The input tensor. Its shape is (seq_len, batch_size, feature_dim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in
        `x` before padding.
      src_key_padding_mask:
        The mask for padding, of shape (batch_size, seq_len); True means
        masked position. May be None.
      return_middle_out:
        If True, also return the list of per-stack encoder outputs.
        NOTE: in that case the actual return value is a 3-tuple, wider
        than the nominal Tuple[Tensor, Tensor] annotation.
    Returns:
      Return a tuple containing 2 tensors:
        - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
        - lengths, a tensor of shape (batch_size,) containing the number
          of frames in `embeddings` before padding.
      When return_middle_out is True, a third element is appended: the list
      of outputs from each encoder stack (before combining dimensions).
    """
    outputs = []
    # Feature masking (per-frame dropout of high encoder dims) is a
    # training-time regularizer; disabled under scripting/tracing.
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        feature_masks = [1.0] * len(self.encoder_dim)
    else:
        feature_masks = self.get_feature_masks(x)

    chunk_size, left_context_chunks = self.get_chunk_info()

    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Not support exporting a model for simulating streaming decoding
        attn_mask = None
    else:
        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

    for i, module in enumerate(self.encoders):
        ds = self.downsampling_factor[i]
        # Pad/truncate channels so x matches this stack's embedding dim.
        x = convert_num_channels(x, self.encoder_dim[i])

        x = module(
            x,
            chunk_size=chunk_size,
            feature_mask=feature_masks[i],
            # Subsample the padding mask to this stack's frame rate.
            src_key_padding_mask=(
                None
                if src_key_padding_mask is None
                else src_key_padding_mask[..., ::ds]
            ),
            attn_mask=attn_mask,
        )
        outputs.append(x)

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1]. Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)

    if self.output_downsampling_factor >= 2:
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            # NOTE(review): the catch_warnings block presumably silences a
            # tensor floor-division deprecation warning -- confirm which
            # torch versions emit it.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2
    else:
        lengths = x_lens
    if return_middle_out:
        return x, lengths, outputs
    else:
        return x, lengths
|
| 391 |
+
|
| 392 |
+
def _get_attn_mask(
    self, x: Tensor, chunk_size: int, left_context_chunks: int
) -> Optional[Tensor]:
    """
    Return None if chunk_size == -1, else return attention mask of shape
    (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
    means a masked position.
    Args:
      x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
      chunk_size: chunk size (in frames at this rate); must divide every
        element of self.downsampling_factor.  <= 0 means no chunking, i.e.
        full (non-streaming) attention.
      left_context_chunks: number of whole previous chunks each chunk may
        attend to; < 0 means unlimited left context.
    """
    if chunk_size <= 0:
        return None
    assert all(chunk_size % d == 0 for d in self.downsampling_factor)
    if left_context_chunks >= 0:
        num_encoders = len(self.encoder_dim)
        # The visible left context must at least cover the causal-convolution
        # left padding of every encoder stack, measured at the full frame rate.
        assert all(
            chunk_size * left_context_chunks
            >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
            for i in range(num_encoders)
        )
    else:
        # effectively unlimited left context
        left_context_chunks = 1000000

    seq_len = x.shape[0]

    # t is frame index, shape (seq_len,)
    t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
    # c is chunk index for each frame, shape (seq_len,)
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        c = t // chunk_size
    else:
        # suppress the integer floor-division deprecation warning
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            c = t // chunk_size
    src_c = c
    tgt_c = c.unsqueeze(-1)

    # Mask (True) any source position in a *later* chunk than the target,
    # or more than left_context_chunks chunks in the past.
    attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
    if __name__ == "__main__":
        # debug output, only when this module is run as a script
        logging.info(f"attn_mask = {attn_mask}")
    return attn_mask
def _get_full_dim_output(self, outputs: List[Tensor]):
    """
    Assemble a single output of width max(self.encoder_dim) from the
    per-stack outputs.  Every channel is taken from the most recent
    (last-listed) stack that has that channel present: start from
    outputs[-1] and walk backwards, appending only the channel ranges not
    yet covered.
    """
    num_encoders = len(self.encoder_dim)
    assert len(outputs) == num_encoders
    full_dim = max(self.encoder_dim)

    pieces = [outputs[-1]]
    dims_covered = self.encoder_dim[-1]
    # Walk from the second-to-last stack back to the first; a wider earlier
    # stack contributes just its extra channels.
    for idx in reversed(range(num_encoders - 1)):
        width = self.encoder_dim[idx]
        if width > dims_covered:
            pieces.append(outputs[idx][..., dims_covered:width])
            dims_covered = width
    assert dims_covered == full_dim
    return torch.cat(pieces, dim=-1)
def streaming_forward(
    self,
    x: Tensor,
    x_lens: Tensor,
    states: List[Tensor],
    src_key_padding_mask: Tensor,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    """
    Streaming (chunk-by-chunk) forward pass through all encoder stacks.

    Args:
      x:
        The input tensor. Its shape is (seq_len, batch_size, feature_dim).
      x_lens:
        A tensor of shape (batch_size,) containing the number of frames in
        `x` before padding.
      states: list of cached tensors of all encoder layers. For layer-i,
        states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
        cached_conv1, cached_conv2).
      src_key_padding_mask:
        The mask for padding, of shape (batch_size, seq_len); True means
        masked position. May be None.
    Returns:
      Return a tuple containing:
        - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
        - lengths, a tensor of shape (batch_size,) containing the number
          of frames in `embeddings` before padding.
        - updated states
    """
    outputs = []
    new_states = []
    # index (over all flattened layers) of the first layer of the current stack
    layer_offset = 0

    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        ds = self.downsampling_factor[i]
        # Stacks have different widths; adjust channel count to this stack's dim.
        x = convert_num_channels(x, self.encoder_dim[i])

        x, new_layer_states = module.streaming_forward(
            x,
            # each layer owns 6 consecutive state tensors
            states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
            # left context measured at this stack's (downsampled) frame rate
            left_context_len=self.left_context_frames[0] // ds,
            src_key_padding_mask=src_key_padding_mask[..., ::ds],
        )
        layer_offset += num_layers
        outputs.append(x)
        new_states += new_layer_states

    # if the last output has the largest dimension, x will be unchanged,
    # it will be the same as outputs[-1]. Otherwise it will be concatenated
    # from different pieces of 'outputs', taking each dimension from the
    # most recent output that has it present.
    x = self._get_full_dim_output(outputs)

    if self.output_downsampling_factor >= 2:
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            # suppress the integer floor-division deprecation warning
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2
    else:
        lengths = x_lens

    return x, lengths, new_states
@torch.jit.export
def get_init_states(
    self,
    batch_size: int = 1,
    device: torch.device = torch.device("cpu"),
) -> List[Tensor]:
    """Get initial (all-zero) streaming states.

    Returns a list of cached tensors of all encoder layers. For layer-i,
    states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
    cached_val2, cached_conv1, cached_conv2).
    """
    states: List[Tensor] = []
    for i, module in enumerate(self.encoders):
        num_layers = module.num_layers
        embed_dim = self.encoder_dim[i]
        ds = self.downsampling_factor[i]
        num_heads = self.num_heads[i]
        key_dim = self.query_head_dim[i] * num_heads
        value_dim = self.value_head_dim[i] * num_heads
        # left-context length at this stack's (downsampled) frame rate
        downsample_left = self.left_context_frames[0] // ds
        nonlin_attn_head_dim = 3 * embed_dim // 4
        # causal convolution caches kernel//2 frames of left padding
        conv_left_pad = self.cnn_module_kernel[i] // 2
        for _ in range(num_layers):
            states += [
                # cached_key: (left_context_len, batch_size, key_dim)
                torch.zeros(downsample_left, batch_size, key_dim, device=device),
                # cached_nonlin_attn: (1, batch_size, left_context_len, head_dim)
                torch.zeros(
                    1, batch_size, downsample_left, nonlin_attn_head_dim, device=device
                ),
                # cached_val1 / cached_val2: (left_context_len, batch_size, value_dim)
                torch.zeros(downsample_left, batch_size, value_dim, device=device),
                torch.zeros(downsample_left, batch_size, value_dim, device=device),
                # cached_conv1 / cached_conv2: (batch_size, channels, left_pad)
                torch.zeros(batch_size, embed_dim, conv_left_pad, device=device),
                torch.zeros(batch_size, embed_dim, conv_left_pad, device=device),
            ]

    return states
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
    """Whitening-limit schedule: starts at x and ramps linearly to
    ratio * x by training batch 20000 (default value: x)."""
    end_value = ratio * x
    return ScheduledFloat((0.0, x), (20000.0, end_value), default=x)
def _balancer_schedule(min_prob: float) -> ScheduledFloat:
    """Balancer-probability schedule: decays from 0.4 down to min_prob
    over the first 8000 training batches."""
    start_point = (0.0, 0.4)
    end_point = (8000.0, min_prob)
    return ScheduledFloat(start_point, end_point)
class Zipformer2EncoderLayer(nn.Module):
    """
    A single Zipformer2 encoder layer: two self-attention modules sharing one set
    of attention weights, a nonlin-attention module, two convolution modules and
    three feedforward modules, with Balancer/Whiten regularizers and learnable
    bypass connections.

    Args:
        embed_dim: the number of expected features in the input (required).
        pos_dim: dimension of the relative positional encoding.
        num_heads: the number of heads in the multiheadattention models (required).
        query_head_dim: per-head dimension of the attention query/key.
        pos_head_dim: per-head dimension of the positional-encoding projection.
        value_head_dim: per-head dimension of the attention value.
        feedforward_dim: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module.
        causal: if True, the convolution modules are causal (streaming-capable).
        attention_skip_rate, conv_skip_rate, ff2_skip_rate, ff3_skip_rate,
        bypass_skip_rate: scheduled per-sequence dropout rates applied to the
            corresponding sub-modules; active only in training mode.
        const_attention_rate: scheduled probability (training only) of replacing
            the nonlin-attention weights with a constant averaging mask.

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        conv_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        const_attention_rate: FloatLike = ScheduledFloat(
            (0.0, 0.25), (4000.0, 0.025), default=0
        ),
        ff2_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        ff3_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        bypass_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.5), (4000.0, 0.02), default=0
        ),
    ) -> None:
        super(Zipformer2EncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # self.bypass implements layer skipping as well as bypass; see its default values.
        self.bypass = BypassModule(
            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
        )
        # bypass_mid is bypass used in the middle of the layer.
        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)

        # skip probability for dynamic modules (meaning: anything but feedforward).
        # deepcopy so each layer's ScheduledFloat can be retuned independently
        # (e.g. by Zipformer2Encoder's per-layer layerdrop schedule).
        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
        # an additional skip probability that applies to ConvModule to stop it from
        # contributing too much early on.
        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        self.const_attention_rate = copy.deepcopy(const_attention_rate)

        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            dropout=0.0,
        )

        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)

        # three feedforward modules of increasing width: 3/4, 1, 5/4 of feedforward_dim
        self.feed_forward1 = FeedforwardModule(
            embed_dim, (feedforward_dim * 3) // 4, dropout
        )

        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)

        self.feed_forward3 = FeedforwardModule(
            embed_dim, (feedforward_dim * 5) // 4, dropout
        )

        self.nonlin_attention = NonlinAttention(
            embed_dim, hidden_channels=3 * embed_dim // 4
        )

        self.conv_module1 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        self.conv_module2 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        # TODO: remove it (unused; kept for checkpoint compatibility)
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        self.norm = BiasNorm(embed_dim)

        self.balancer1 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.2,
            max_abs=4.0,
        )

        # balancer for output of NonlinAttentionModule
        self.balancer_na = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of feedforward2, prevent it from staying too
        # small. give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
        self.balancer_ff2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
            max_abs=2.0,
            prob=0.05,
        )

        self.balancer_ff3 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
            max_abs=4.0,
            prob=0.05,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.balancer2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.1,
            max_abs=4.0,
        )

    def get_sequence_dropout_mask(
        self, x: Tensor, dropout_rate: float
    ) -> Optional[Tensor]:
        """
        Return a per-sequence (not per-frame) dropout mask of shape
        (batch_size, 1) in x's dtype, or None when dropout is inactive
        (rate 0, eval mode, or scripting/tracing).  The mask is NOT
        rescaled by 1/(1-p): dropped sequences contribute 0.
        """
        if (
            dropout_rate == 0.0
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        ):
            return None
        batch_size = x.shape[1]
        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
        return mask

    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
        """
        Apply sequence-level dropout to x.
        x shape: (seq_len, batch_size, embed_dim)
        """
        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
        if dropout_mask is None:
            return x
        else:
            return x * dropout_mask

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.

        Returns:
           A tensor which has the same shape as src
        """
        src_orig = src

        # dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        self_attn_dropout_mask = self.get_sequence_dropout_mask(
            src, attention_skip_rate
        )

        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        # FIX: this previously tested `not self.training`, which disabled the
        # const-attention regularization during training (its whole purpose,
        # given the batch-count-based schedule) and made inference
        # nondeterministic.  It must fire only in training mode.
        elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant. The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
            # only need the mask, can just use the 1st head and expand later
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        self_attn = self.self_attn1(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff2_skip_rate = 0.0
        else:
            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
        )

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn = self.self_attn2(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff3_skip_rate = 0.0
        else:
            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
        )

        src = self.balancer1(src)
        src = self.norm(src)

        src = self.bypass(src_orig, src)

        src = self.balancer2(src)
        src = self.whiten(src)

        return src

    def streaming_forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        cached_key: Tensor,
        cached_nonlin_attn: Tensor,
        cached_val1: Tensor,
        cached_val2: Tensor,
        cached_conv1: Tensor,
        cached_conv2: Tensor,
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Pass the input through the encoder layer in streaming forward mode.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
              (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
            cached_key: cached attention key tensor of left context,
              of shape (left_context_len, batch_size, key_dim)
            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
              (num_heads, batch_size, left_context_len, head_dim)
            cached_val1: cached left context for the first attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_val2: cached left context for the second attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_conv1: cached left context for the first convolution module,
              of shape (batch_size, channels, left_pad)
            cached_conv2: cached left context for the second convolution module,
              of shape (batch_size, channels, left_pad)
            left_context_len: number of left context frames.
            src_key_padding_mask: the mask for padding, of shape
              (batch_size, left_context_len + seq_len); True means masked position.
              May be None.

        Returns:
            - x, with the same shape as src
            - updated cached_key
            - updated cached_nonlin_attn
            - updated cached_val1
            - updated cached_val2
            - updated cached_conv1
            - updated cached_conv2
        """
        src_orig = src

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
            src,
            pos_emb=pos_emb,
            cached_key=cached_key,
            left_context_len=left_context_len,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
            src,
            attn_weights[0:1],
            cached_x=cached_nonlin_attn,
            left_context_len=left_context_len,
        )
        src = src + na

        self_attn, cached_val1 = self.self_attn1.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val1,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        # conv modules see only the current-chunk part of the padding mask
        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
            src,
            cache=cached_conv1,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward2(src)

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn, cached_val2 = self.self_attn2.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val2,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
            src,
            cache=cached_conv2,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward3(src)

        src = self.norm(src)

        src = self.bypass(src_orig, src)

        return (
            src,
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        )
class Zipformer2Encoder(nn.Module):
    r"""Zipformer2Encoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding
        dropout: dropout rate (currently passed to the positional encoding only
            via a fixed 0.15; see __init__).
        warmup_begin, warmup_end: training-batch indices between which the
            per-layer layerdrop rate is annealed from initial_layerdrop_rate
            to final_layerdrop_rate (staggered across layers).

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, dropout_rate=0.15, length_factor=1.0
        )

        # N independent copies of the prototype layer
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end

        # Stagger each layer's layerdrop warmup over an equal slice of
        # [warmup_begin, warmup_end), so layers "switch on" one after another.
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            cur_end = cur_begin + delta
            self.layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            cur_begin = cur_end

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                 masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """
        pos_emb = self.encoder_pos(src)
        output = src

        # feature_mask is a training-time regularization; skipped when
        # scripting/tracing for deployment.
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        for i, mod in enumerate(self.layers):
            output = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        return output

    def streaming_forward(
        self,
        src: Tensor,
        states: List[Tensor],
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: Number of left context frames.
            src_key_padding_mask:  the mask for padding, of shape
              (batch_size, left_context_len + seq_len); True means masked position.
              May be None.

        Returns:
          - output, a Tensor with the same shape as src.
          - updated states
        """
        # positional encoding must also cover the cached left context
        pos_emb = self.encoder_pos(src, left_context_len)
        output = src

        new_states = []
        for i, mod in enumerate(self.layers):
            # unpack the 6 per-layer state tensors, run the layer, and
            # collect the updated states in the same flat order
            (
                cached_key,
                cached_nonlin_attn,
                cached_val1,
                cached_val2,
                cached_conv1,
                cached_conv2,
            ) = states[i * 6 : (i + 1) * 6]
            (
                output,
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ) = mod.streaming_forward(
                output,
                pos_emb,
                cached_key=cached_key,
                cached_nonlin_attn=cached_nonlin_attn,
                cached_val1=cached_val1,
                cached_val2=cached_val2,
                cached_conv1=cached_conv1,
                cached_conv2=cached_conv2,
                left_context_len=left_context_len,
                src_key_padding_mask=src_key_padding_mask,
            )
            new_states += [
                new_cached_key,
                new_cached_nonlin_attn,
                new_cached_val1,
                new_cached_val2,
                new_cached_conv1,
                new_cached_conv2,
            ]

        return output, new_states
class BypassModule(nn.Module):
    """
    Implements a learnable per-channel bypass scale together with randomized
    per-sequence layer skipping.  During early training the scale is limited
    to stay close to 1.0 (i.e. close to "straight-through", not bypassing
    much), which forces the bypassed modules to learn something useful.
    """

    def __init__(
        self,
        embed_dim: int,
        skip_rate: FloatLike = 0.0,
        straight_through_rate: FloatLike = 0.0,
        scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
        scale_max: FloatLike = 1.0,
    ):
        super().__init__()
        # one learnable scale per channel, initialized halfway between
        # full bypass (0.0) and straight-through (1.0).
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
        # deep-copy the (possibly scheduled) float-like hyperparameters so
        # each module instance owns an independent schedule.
        self.skip_rate = copy.deepcopy(skip_rate)
        self.straight_through_rate = copy.deepcopy(straight_through_rate)
        self.scale_min = copy.deepcopy(scale_min)
        self.scale_max = copy.deepcopy(scale_max)

    def _get_bypass_scale(self, batch_size: int):
        """Return the scale on the non-residual term: shape (num_channels,)
        or (batch_size, num_channels).  0 corresponds to bypassing this
        module entirely, 1 to straight-through."""
        # In inference / jit modes, return the raw parameter unchanged.
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.bypass_scale

        scale = limit_param_value(
            self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
        )
        skip_prob = float(self.skip_rate)
        if skip_prob != 0.0:
            # zero out the scale for randomly chosen sequences: those
            # sequences skip this layer entirely.
            keep_mask = torch.rand((batch_size, 1), device=scale.device) > skip_prob
            scale = scale * keep_mask
            # scale is now (batch_size, num_channels); zero rows mark
            # layer-skipped sequences.
        st_prob = float(self.straight_through_rate)
        if st_prob != 0.0:
            # force randomly chosen sequences to scale 1.0 (straight-through).
            st_mask = torch.rand((batch_size, 1), device=scale.device) < st_prob
            scale = torch.maximum(scale, st_mask.to(scale.dtype))
        return scale

    def forward(self, src_orig: Tensor, src: Tensor):
        """
        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
        Returns: a tensor with the same shape, interpolating between src_orig
        (bypass) and src (straight-through) per channel.
        """
        scale = self._get_bypass_scale(src.shape[1])
        return src_orig + (src - src_orig) * scale
|
| 1242 |
+
|
| 1243 |
+
|
| 1244 |
+
class DownsampledZipformer2Encoder(nn.Module):
    r"""A Zipformer2 encoder evaluated at a reduced frame rate.

    The input is convolutionally downsampled, passed through the wrapped
    encoder, upsampled back to the original rate, and combined with the
    original input via a learnable bypass, so the output has the same shape
    as the input.
    """

    def __init__(
        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
    ):
        super().__init__()
        self.downsample_factor = downsample
        self.downsample = SimpleDownsample(dim, downsample, dropout)
        self.num_layers = encoder.num_layers
        self.encoder = encoder
        self.upsample = SimpleUpsample(dim, downsample)
        self.out_combiner = BypassModule(dim, straight_through_rate=0)

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Downsample, run the inner encoder, upsample, then bypass-combine.

        Args:
            src: the sequence to the encoder (required): shape
                (seq_len, batch_size, embedding_dim).
            feature_mask: something that broadcasts with src, multiplied into
                `src` at every layer; if a Tensor, likely of shape
                (seq_len, batch_size, embedding_dim).
            attn_mask: the attention mask, of shape
                (batch_size, seq_len, seq_len) or (seq_len, seq_len);
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape
                (batch_size, seq_len); True means masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """
        residual = src
        ds = self.downsample_factor
        out = self.downsample(src)
        if attn_mask is not None:
            # subsample the mask to match the reduced frame rate
            attn_mask = attn_mask[::ds, ::ds]

        out = self.encoder(
            out,
            chunk_size=chunk_size // ds,
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        # upsample, then drop any extra frames beyond the original length
        out = self.upsample(out)[: residual.shape[0]]

        return self.out_combiner(residual, out)

    def streaming_forward(
        self,
        src: Tensor,
        states: List[Tensor],
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        r"""Streaming version of forward(): downsample, encode, upsample.

        Args:
            src: the sequence to the encoder (required): shape
                (seq_len, batch_size, embedding_dim).
            states: list of cached tensors of N encoder layers. For layer i,
                states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn,
                cached_val1, cached_val2, cached_conv1, cached_conv2).
            left_context_len: number of left-context frames.
            src_key_padding_mask: the mask for padding, of shape
                (batch_size, left_context_len + seq_len); True means masked
                position. May be None.

        Returns:
            - output, a Tensor with the same shape as src.
            - updated states
        """
        residual = src
        out, new_states = self.encoder.streaming_forward(
            self.downsample(src),
            states=states,
            left_context_len=left_context_len,
            src_key_padding_mask=src_key_padding_mask,
        )
        # upsample, then drop any extra frames beyond the original length
        out = self.upsample(out)[: residual.shape[0]]

        return self.out_combiner(residual, out), new_states
|
| 1338 |
+
|
| 1339 |
+
|
| 1340 |
+
class SimpleDownsample(torch.nn.Module):
    """
    Downsamples along the time axis by taking a learned softmax-weighted sum
    over each group of `downsample` consecutive frames.
    """

    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
        super().__init__()

        # one learnable logit per position within a group of `downsample`
        # frames; softmaxed into combination weights in forward().
        self.bias = nn.Parameter(torch.zeros(downsample))

        self.name = None  # will be set from training code
        self.dropout = copy.deepcopy(dropout)

        self.downsample = downsample

    def forward(self, src: Tensor) -> Tensor:
        """
        src: (seq_len, batch_size, in_channels)
        Returns a tensor of shape
        ((seq_len + downsample - 1) // downsample, batch_size, in_channels)
        """
        seq_len, batch_size, in_channels = src.shape
        ds = self.downsample
        d_seq_len = (seq_len + ds - 1) // ds

        # Right-pad src (repeating its last frame) so the length is an exact
        # multiple of the downsampling factor.
        pad = d_seq_len * ds - seq_len
        tail = src[-1:].expand(pad, batch_size, in_channels)
        src = torch.cat((src, tail), dim=0)
        assert src.shape[0] == d_seq_len * ds

        # group consecutive frames: (d_seq_len, ds, batch, channels)
        src = src.reshape(d_seq_len, ds, batch_size, in_channels)

        # weights: (ds, 1, 1), broadcast over batch and channels
        weights = self.bias.softmax(dim=0).reshape(ds, 1, 1)

        # weighted sum over each group of ds frames
        return (src * weights).sum(dim=1)
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
class SimpleUpsample(torch.nn.Module):
    """
    A very simple form of upsampling along the time axis: each input frame is
    repeated `upsample` times consecutively.
    """

    def __init__(self, num_channels: int, upsample: int):
        super().__init__()
        self.upsample = upsample

    def forward(self, src: Tensor) -> Tensor:
        """
        src: (seq_len, batch_size, num_channels)
        Returns a tensor of shape
        (seq_len * upsample, batch_size, num_channels)
        """
        factor = self.upsample
        seq_len, batch_size, num_channels = src.shape
        # insert a repeat axis after time, broadcast, then fold it into time
        repeated = src.unsqueeze(1).expand(seq_len, factor, batch_size, num_channels)
        return repeated.reshape(seq_len * factor, batch_size, num_channels)
|
| 1405 |
+
|
| 1406 |
+
|
| 1407 |
+
class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module.  This version is "compact" meaning it is able to encode
    the important information about the relative position in a relatively small number of dimensions.
    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
    make very little difference to the embedding.   Such differences were potentially important
    when encoding absolute position, but not important when encoding relative position because there
    is now no need to compare two large offsets with each other.

    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
    using the atan() function, before doing the fourier transform of that fixed interval.  The
    atan() function would compress the "long tails" too small,
    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
    function to compress large offsets to a smaller range before applying atan().
    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)


    Args:
        embed_dim: Embedding dimension.
        dropout_rate: Dropout rate.
        max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
           less weight to small differences of offset near the origin.
    """

    def __init__(
        self,
        embed_dim: int,
        dropout_rate: FloatLike,
        max_len: int = 1000,
        length_factor: float = 1.0,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        assert embed_dim % 2 == 0
        self.dropout = Dropout2(dropout_rate)
        # self.pe caches the positional-encoding table; (re)built lazily by
        # extend_pe() whenever the needed length exceeds the cached one.
        self.pe = None
        assert length_factor >= 1.0
        self.length_factor = length_factor
        # pre-build the table for max_len positions (heuristic initial size)
        self.extend_pe(torch.tensor(0.0).expand(max_len))

    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
        """Reset the positional encodings if the cached table is too short."""
        T = x.size(0) + left_context_len

        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(0) >= T * 2 - 1:
                # cached table is long enough; just match dtype/device.
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)

        # frequencies 1..embed_dim//2 for the fourier features below
        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
        # for small time offsets but less resolution for large time offsets.
        compression_length = self.embed_dim**0.5
        # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity;
        # but it does so more slowly than T for large absolute values of T.
        # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
        # is important.
        x_compressed = (
            compression_length
            * x.sign()
            * ((x.abs() + compression_length).log() - math.log(compression_length))
        )

        # if self.length_factor == 1.0, then length_scale is chosen so that the
        # FFT can exactly separate points close to the origin (T == 0).  So this
        # part of the formulation is not really heuristic.
        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)

        # note for machine implementations: if atan is not available, we can use:
        #  x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
        x_atan = (x_compressed / length_scale).atan()  # results between -pi and pi

        cosines = (x_atan * freqs).cos()
        sines = (x_atan * freqs).sin()

        # interleave cos/sin into the even/odd channels of the table
        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
        pe[:, 0::2] = cosines
        pe[:, 1::2] = sines
        pe[:, -1] = 1.0  # for bias.

        self.pe = pe.to(dtype=x.dtype)

    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
        """Create positional encoding.

        Args:
            x (Tensor): Input tensor (time, batch, `*`).
            left_context_len: (int): Length of cached left context.

        Returns:
            positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
        """
        self.extend_pe(x, left_context_len)
        x_size_left = x.size(0) + left_context_len
        # slice the centred window out of the cached table:
        # length of positive side: x.size(0) + left_context_len
        # length of negative side: x.size(0)
        pos_emb = self.pe[
            self.pe.size(0) // 2
            - x_size_left
            + 1 : self.pe.size(0) // 2  # noqa E203
            + x.size(0),
            :,
        ]
        pos_emb = pos_emb.unsqueeze(0)
        return self.dropout(pos_emb)
|
| 1523 |
+
|
| 1524 |
+
|
| 1525 |
+
class RelPositionMultiheadAttentionWeights(nn.Module):
|
| 1526 |
+
r"""Module that computes multi-head attention weights with relative position encoding.
|
| 1527 |
+
Various other modules consume the resulting attention weights: see, for example, the
|
| 1528 |
+
SimpleAttention module which allows you to compute conventional attention.
|
| 1529 |
+
|
| 1530 |
+
This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context",
|
| 1531 |
+
we have to write up the differences.
|
| 1532 |
+
|
| 1533 |
+
|
| 1534 |
+
Args:
|
| 1535 |
+
embed_dim: number of channels at the input to this module, e.g. 256
|
| 1536 |
+
pos_dim: dimension of the positional encoding vectors, e.g. 128.
|
| 1537 |
+
num_heads: number of heads to compute weights for, e.g. 8
|
| 1538 |
+
query_head_dim: dimension of the query (and key), per head. e.g. 24.
|
| 1539 |
+
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
|
| 1540 |
+
dropout: dropout probability for attn_output_weights. Default: 0.0.
|
| 1541 |
+
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
|
| 1542 |
+
any given call to forward(), in training time.
|
| 1543 |
+
"""
|
| 1544 |
+
|
| 1545 |
+
def __init__(
|
| 1546 |
+
self,
|
| 1547 |
+
embed_dim: int,
|
| 1548 |
+
pos_dim: int,
|
| 1549 |
+
num_heads: int,
|
| 1550 |
+
query_head_dim: int,
|
| 1551 |
+
pos_head_dim: int,
|
| 1552 |
+
dropout: float = 0.0,
|
| 1553 |
+
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
|
| 1554 |
+
) -> None:
|
| 1555 |
+
super().__init__()
|
| 1556 |
+
self.embed_dim = embed_dim
|
| 1557 |
+
self.num_heads = num_heads
|
| 1558 |
+
self.query_head_dim = query_head_dim
|
| 1559 |
+
self.pos_head_dim = pos_head_dim
|
| 1560 |
+
self.dropout = dropout
|
| 1561 |
+
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
|
| 1562 |
+
self.name = None # will be overwritten in training code; for diagnostics.
|
| 1563 |
+
|
| 1564 |
+
key_head_dim = query_head_dim
|
| 1565 |
+
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
|
| 1566 |
+
|
| 1567 |
+
# the initial_scale is supposed to take over the "scaling" factor of
|
| 1568 |
+
# head_dim ** -0.5 that has been used in previous forms of attention,
|
| 1569 |
+
# dividing it between the query and key. Note: this module is intended
|
| 1570 |
+
# to be used with the ScaledAdam optimizer; with most other optimizers,
|
| 1571 |
+
# it would be necessary to apply the scaling factor in the forward function.
|
| 1572 |
+
self.in_proj = ScaledLinear(
|
| 1573 |
+
embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
|
| 1574 |
+
)
|
| 1575 |
+
|
| 1576 |
+
self.whiten_keys = Whiten(
|
| 1577 |
+
num_groups=num_heads,
|
| 1578 |
+
whitening_limit=_whitening_schedule(3.0),
|
| 1579 |
+
prob=(0.025, 0.25),
|
| 1580 |
+
grad_scale=0.025,
|
| 1581 |
+
)
|
| 1582 |
+
|
| 1583 |
+
# add a balancer for the keys that runs with very small probability, and
|
| 1584 |
+
# tries to enforce that all dimensions have mean around zero. The
|
| 1585 |
+
# weights produced by this module are invariant to adding a constant to
|
| 1586 |
+
# the keys, so the derivative of the bias is mathematically zero; but
|
| 1587 |
+
# due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
|
| 1588 |
+
# bias because the small numerical roundoff tends to have a non-random
|
| 1589 |
+
# sign. This module is intended to prevent that. Use a very small
|
| 1590 |
+
# probability; that should be suffixient to fix the problem.
|
| 1591 |
+
self.balance_keys = Balancer(
|
| 1592 |
+
key_head_dim * num_heads,
|
| 1593 |
+
channel_dim=-1,
|
| 1594 |
+
min_positive=0.4,
|
| 1595 |
+
max_positive=0.6,
|
| 1596 |
+
min_abs=0.0,
|
| 1597 |
+
max_abs=100.0,
|
| 1598 |
+
prob=0.025,
|
| 1599 |
+
)
|
| 1600 |
+
|
| 1601 |
+
# linear transformation for positional encoding.
|
| 1602 |
+
self.linear_pos = ScaledLinear(
|
| 1603 |
+
pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
|
| 1604 |
+
)
|
| 1605 |
+
|
| 1606 |
+
# the following are for diagnosics only, see --print-diagnostics option
|
| 1607 |
+
self.copy_pos_query = Identity()
|
| 1608 |
+
self.copy_query = Identity()
|
| 1609 |
+
|
| 1610 |
+
def forward(
|
| 1611 |
+
self,
|
| 1612 |
+
x: Tensor,
|
| 1613 |
+
pos_emb: Tensor,
|
| 1614 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 1615 |
+
attn_mask: Optional[Tensor] = None,
|
| 1616 |
+
) -> Tensor:
|
| 1617 |
+
r"""
|
| 1618 |
+
Args:
|
| 1619 |
+
x: input of shape (seq_len, batch_size, embed_dim)
|
| 1620 |
+
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
|
| 1621 |
+
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
|
| 1622 |
+
are True in this mask will be ignored as sources in the attention weighting.
|
| 1623 |
+
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
|
| 1624 |
+
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
|
| 1625 |
+
saying which positions are allowed to attend to which other positions.
|
| 1626 |
+
Returns:
|
| 1627 |
+
a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
|
| 1628 |
+
interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
|
| 1629 |
+
"""
|
| 1630 |
+
x = self.in_proj(x)
|
| 1631 |
+
query_head_dim = self.query_head_dim
|
| 1632 |
+
pos_head_dim = self.pos_head_dim
|
| 1633 |
+
num_heads = self.num_heads
|
| 1634 |
+
|
| 1635 |
+
seq_len, batch_size, _ = x.shape
|
| 1636 |
+
|
| 1637 |
+
query_dim = query_head_dim * num_heads
|
| 1638 |
+
|
| 1639 |
+
# self-attention
|
| 1640 |
+
q = x[..., 0:query_dim]
|
| 1641 |
+
k = x[..., query_dim : 2 * query_dim]
|
| 1642 |
+
# p is the position-encoding query
|
| 1643 |
+
p = x[..., 2 * query_dim :]
|
| 1644 |
+
assert p.shape[-1] == num_heads * pos_head_dim
|
| 1645 |
+
|
| 1646 |
+
q = self.copy_query(q) # for diagnostics only, does nothing.
|
| 1647 |
+
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
|
| 1648 |
+
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
|
| 1649 |
+
|
| 1650 |
+
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
|
| 1651 |
+
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
|
| 1652 |
+
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
|
| 1653 |
+
|
| 1654 |
+
# time1 refers to target, time2 refers to source.
|
| 1655 |
+
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
|
| 1656 |
+
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
|
| 1657 |
+
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
|
| 1658 |
+
|
| 1659 |
+
attn_scores = torch.matmul(q, k)
|
| 1660 |
+
|
| 1661 |
+
use_pos_scores = False
|
| 1662 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1663 |
+
# We can't put random.random() in the same line
|
| 1664 |
+
use_pos_scores = True
|
| 1665 |
+
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
| 1666 |
+
use_pos_scores = True
|
| 1667 |
+
|
| 1668 |
+
if use_pos_scores:
|
| 1669 |
+
pos_emb = self.linear_pos(pos_emb)
|
| 1670 |
+
seq_len2 = 2 * seq_len - 1
|
| 1671 |
+
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
| 1672 |
+
2, 0, 3, 1
|
| 1673 |
+
)
|
| 1674 |
+
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
| 1675 |
+
|
| 1676 |
+
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
| 1677 |
+
# [where seq_len2 represents relative position.]
|
| 1678 |
+
pos_scores = torch.matmul(p, pos_emb)
|
| 1679 |
+
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
| 1680 |
+
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
| 1681 |
+
# not, but let this code define which way round it is supposed to be.
|
| 1682 |
+
if torch.jit.is_tracing():
|
| 1683 |
+
(num_heads, batch_size, time1, n) = pos_scores.shape
|
| 1684 |
+
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
| 1685 |
+
cols = torch.arange(seq_len)
|
| 1686 |
+
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
| 1687 |
+
indexes = rows + cols
|
| 1688 |
+
pos_scores = pos_scores.reshape(-1, n)
|
| 1689 |
+
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
| 1690 |
+
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
| 1691 |
+
else:
|
| 1692 |
+
pos_scores = pos_scores.as_strided(
|
| 1693 |
+
(num_heads, batch_size, seq_len, seq_len),
|
| 1694 |
+
(
|
| 1695 |
+
pos_scores.stride(0),
|
| 1696 |
+
pos_scores.stride(1),
|
| 1697 |
+
pos_scores.stride(2) - pos_scores.stride(3),
|
| 1698 |
+
pos_scores.stride(3),
|
| 1699 |
+
),
|
| 1700 |
+
storage_offset=pos_scores.stride(3) * (seq_len - 1),
|
| 1701 |
+
)
|
| 1702 |
+
|
| 1703 |
+
attn_scores = attn_scores + pos_scores
|
| 1704 |
+
|
| 1705 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1706 |
+
pass
|
| 1707 |
+
elif self.training and random.random() < 0.1:
|
| 1708 |
+
# This is a harder way of limiting the attention scores to not be
|
| 1709 |
+
# too large. It incurs a penalty if any of them has an absolute
|
| 1710 |
+
# value greater than 50.0. this should be outside the normal range
|
| 1711 |
+
# of the attention scores. We use this mechanism instead of, say,
|
| 1712 |
+
# something added to the loss function involving the entropy,
|
| 1713 |
+
# because once the entropy gets very small gradients through the
|
| 1714 |
+
# softmax can become very small, and we'd get zero derivatives. The
|
| 1715 |
+
# choices of 1.0e-04 as the scale on the penalty makes this
|
| 1716 |
+
# mechanism vulnerable to the absolute scale of the loss function,
|
| 1717 |
+
# but we view this as a failsafe to avoid "implausible" parameter
|
| 1718 |
+
# values rather than a regularization method that should be active
|
| 1719 |
+
# under normal circumstances.
|
| 1720 |
+
attn_scores = penalize_abs_values_gt(
|
| 1721 |
+
attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
|
| 1722 |
+
)
|
| 1723 |
+
|
| 1724 |
+
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
| 1725 |
+
|
| 1726 |
+
if attn_mask is not None:
|
| 1727 |
+
assert attn_mask.dtype == torch.bool
|
| 1728 |
+
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
| 1729 |
+
# all scores zero. It's important that this be large enough that exp(-1000)
|
| 1730 |
+
# is exactly zero, for reasons related to const_attention_rate, it
|
| 1731 |
+
# compares the final weights with zero.
|
| 1732 |
+
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
| 1733 |
+
|
| 1734 |
+
if key_padding_mask is not None:
|
| 1735 |
+
assert key_padding_mask.shape == (
|
| 1736 |
+
batch_size,
|
| 1737 |
+
seq_len,
|
| 1738 |
+
), key_padding_mask.shape
|
| 1739 |
+
attn_scores = attn_scores.masked_fill(
|
| 1740 |
+
key_padding_mask.unsqueeze(1),
|
| 1741 |
+
-1000,
|
| 1742 |
+
)
|
| 1743 |
+
|
| 1744 |
+
# We use our own version of softmax, defined in scaling.py, which should
|
| 1745 |
+
# save a little of the memory used in backprop by, if we are in
|
| 1746 |
+
# automatic mixed precision mode (amp / autocast), by only storing the
|
| 1747 |
+
# half-precision output for backprop purposes.
|
| 1748 |
+
attn_weights = softmax(attn_scores, dim=-1)
|
| 1749 |
+
|
| 1750 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 1751 |
+
pass
|
| 1752 |
+
elif random.random() < 0.001 and not self.training:
|
| 1753 |
+
self._print_attn_entropy(attn_weights)
|
| 1754 |
+
|
| 1755 |
+
attn_weights = nn.functional.dropout(
|
| 1756 |
+
attn_weights, p=self.dropout, training=self.training
|
| 1757 |
+
)
|
| 1758 |
+
|
| 1759 |
+
return attn_weights
|
| 1760 |
+
|
| 1761 |
+
def streaming_forward(
|
| 1762 |
+
self,
|
| 1763 |
+
x: Tensor,
|
| 1764 |
+
pos_emb: Tensor,
|
| 1765 |
+
cached_key: Tensor,
|
| 1766 |
+
left_context_len: int,
|
| 1767 |
+
key_padding_mask: Tensor,
|
| 1768 |
+
) -> Tuple[Tensor, Tensor]:
|
| 1769 |
+
r"""
|
| 1770 |
+
Args:
|
| 1771 |
+
x: input of shape (seq_len, batch_size, embed_dim)
|
| 1772 |
+
pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
|
| 1773 |
+
cached_key: cached attention key tensor of left context,
|
| 1774 |
+
of shape (left_context_len, batch_size, key_dim)
|
| 1775 |
+
left_context_len: number of left context frames.
|
| 1776 |
+
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
|
| 1777 |
+
are True in this mask will be ignored as sources in the attention weighting.
|
| 1778 |
+
|
| 1779 |
+
Returns:
|
| 1780 |
+
- attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2),
|
| 1781 |
+
interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
|
| 1782 |
+
- updated cached attention key tensor of left context.
|
| 1783 |
+
"""
|
| 1784 |
+
x = self.in_proj(x)
|
| 1785 |
+
query_head_dim = self.query_head_dim
|
| 1786 |
+
pos_head_dim = self.pos_head_dim
|
| 1787 |
+
num_heads = self.num_heads
|
| 1788 |
+
|
| 1789 |
+
seq_len, batch_size, _ = x.shape
|
| 1790 |
+
|
| 1791 |
+
query_dim = query_head_dim * num_heads
|
| 1792 |
+
|
| 1793 |
+
# self-attention
|
| 1794 |
+
q = x[..., 0:query_dim]
|
| 1795 |
+
k = x[..., query_dim : 2 * query_dim]
|
| 1796 |
+
# p is the position-encoding query
|
| 1797 |
+
p = x[..., 2 * query_dim :]
|
| 1798 |
+
assert p.shape[-1] == num_heads * pos_head_dim
|
| 1799 |
+
|
| 1800 |
+
# Pad cached left contexts
|
| 1801 |
+
assert cached_key.shape[0] == left_context_len, (
|
| 1802 |
+
cached_key.shape[0],
|
| 1803 |
+
left_context_len,
|
| 1804 |
+
)
|
| 1805 |
+
k = torch.cat([cached_key, k], dim=0)
|
| 1806 |
+
# Update cached left contexts
|
| 1807 |
+
cached_key = k[-left_context_len:, ...]
|
| 1808 |
+
|
| 1809 |
+
# The length of key
|
| 1810 |
+
k_len = k.shape[0]
|
| 1811 |
+
|
| 1812 |
+
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
|
| 1813 |
+
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
|
| 1814 |
+
k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
|
| 1815 |
+
|
| 1816 |
+
# time1 refers to target, time2 refers to source.
|
| 1817 |
+
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
|
| 1818 |
+
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
|
| 1819 |
+
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
|
| 1820 |
+
|
| 1821 |
+
attn_scores = torch.matmul(q, k)
|
| 1822 |
+
|
| 1823 |
+
pos_emb = self.linear_pos(pos_emb)
|
| 1824 |
+
seq_len2 = 2 * seq_len - 1 + left_context_len
|
| 1825 |
+
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
| 1826 |
+
2, 0, 3, 1
|
| 1827 |
+
)
|
| 1828 |
+
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
| 1829 |
+
|
| 1830 |
+
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
| 1831 |
+
# [where seq_len2 represents relative position.]
|
| 1832 |
+
pos_scores = torch.matmul(p, pos_emb)
|
| 1833 |
+
|
| 1834 |
+
if torch.jit.is_tracing():
|
| 1835 |
+
(num_heads, batch_size, time1, n) = pos_scores.shape
|
| 1836 |
+
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
| 1837 |
+
cols = torch.arange(k_len)
|
| 1838 |
+
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
| 1839 |
+
indexes = rows + cols
|
| 1840 |
+
pos_scores = pos_scores.reshape(-1, n)
|
| 1841 |
+
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
| 1842 |
+
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
|
| 1843 |
+
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
| 1844 |
+
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
| 1845 |
+
# not, but let this code define which way round it is supposed to be.
|
| 1846 |
+
else:
|
| 1847 |
+
pos_scores = pos_scores.as_strided(
|
| 1848 |
+
(num_heads, batch_size, seq_len, k_len),
|
| 1849 |
+
(
|
| 1850 |
+
pos_scores.stride(0),
|
| 1851 |
+
pos_scores.stride(1),
|
| 1852 |
+
pos_scores.stride(2) - pos_scores.stride(3),
|
| 1853 |
+
pos_scores.stride(3),
|
| 1854 |
+
),
|
| 1855 |
+
storage_offset=pos_scores.stride(3) * (seq_len - 1),
|
| 1856 |
+
)
|
| 1857 |
+
|
| 1858 |
+
attn_scores = attn_scores + pos_scores
|
| 1859 |
+
|
| 1860 |
+
assert attn_scores.shape == (
|
| 1861 |
+
num_heads,
|
| 1862 |
+
batch_size,
|
| 1863 |
+
seq_len,
|
| 1864 |
+
k_len,
|
| 1865 |
+
), attn_scores.shape
|
| 1866 |
+
|
| 1867 |
+
if key_padding_mask is not None:
|
| 1868 |
+
assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
|
| 1869 |
+
attn_scores = attn_scores.masked_fill(
|
| 1870 |
+
key_padding_mask.unsqueeze(1),
|
| 1871 |
+
-1000,
|
| 1872 |
+
)
|
| 1873 |
+
|
| 1874 |
+
attn_weights = attn_scores.softmax(dim=-1)
|
| 1875 |
+
|
| 1876 |
+
return attn_weights, cached_key
|
| 1877 |
+
|
| 1878 |
+
def _print_attn_entropy(self, attn_weights: Tensor):
    """Log the mean per-head entropy of the attention distribution.

    Diagnostic only; has no effect on the computation.

    Args:
      attn_weights: attention probabilities of shape
        (num_heads, batch_size, seq_len, seq_len); each row along the
        last axis is expected to sum to 1 (softmax output).
    """
    # attn_weights: (num_heads, batch_size, seq_len, seq_len)
    (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

    with torch.no_grad():
        # Compute in fp32 with autocast disabled so the log/entropy is
        # numerically stable under mixed-precision training.
        with torch.cuda.amp.autocast(enabled=False):
            attn_weights = attn_weights.to(torch.float32)
            # Entropy -sum(p * log p) over the source axis; the 1.0e-20
            # offset guards against log(0).  Result is averaged over
            # batch and target positions, leaving one value per head.
            attn_weights_entropy = (
                -((attn_weights + 1.0e-20).log() * attn_weights)
                .sum(dim=-1)
                .mean(dim=(1, 2))
            )
            logging.info(
                f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
            )
+
|
| 1895 |
+
class SelfAttention(nn.Module):
    """
    The simplest possible attention module.  This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
          embed_dim: the input and output embedding dimension
          num_heads: the number of attention heads
          value_head_dim: the value dimension per head
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
        # Projects the input to per-head value vectors.
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)

        # ScaledLinear (project-local) with a small initial_scale keeps the
        # module's initial output small, for training stability.
        self.out_proj = ScaledLinear(
            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
        )

        # Whiten is a project-local regularization layer; applied only in
        # the non-streaming forward() below.
        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """
        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
              with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
              attn_weights.sum(dim=-1) == 1.
        Returns:
           a tensor with the same shape as x.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # Weighted sum of values per head.
        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)

        # Merge the heads back into a single feature dimension.
        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_val: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Streaming variant of forward(): prepends cached left-context value
        frames before applying the attention weights.

        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
          attn_weights: a tensor of shape
              (num_heads, batch_size, seq_len, left_context_len + seq_len),
              with the last axis being interpreted as src positions
              (left context first).  Expect attn_weights.sum(dim=-1) == 1.
          cached_val: cached attention value tensor of left context,
              of shape (left_context_len, batch_size, value_dim)
          left_context_len: number of left context frames.

        Returns:
          - attention weighted output, a tensor with the same shape as x.
          - updated cached attention value tensor of left context.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        seq_len2 = seq_len + left_context_len
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # Pad with cached left-context value frames.
        assert cached_val.shape[0] == left_context_len, (
            cached_val.shape[0],
            left_context_len,
        )
        x = torch.cat([cached_val, x], dim=0)
        # Update cached left contexts: keep the last left_context_len frames.
        cached_val = x[-left_context_len:, ...]

        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len2, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        # NOTE: self.whiten is intentionally not applied here (unlike forward()).
        x = self.out_proj(x)

        return x, cached_val
| 2021 |
+
|
| 2022 |
+
class FeedforwardModule(nn.Module):
    """Feedforward module in Zipformer2 model.

    Position-wise feedforward block:
    Linear -> Balancer -> (SwooshL activation + dropout + Linear) -> Whiten.

    Args:
      embed_dim: input and output embedding dimension.
      feedforward_dim: hidden dimension of the feedforward layer.
      dropout: dropout probability (a FloatLike; may be scheduled).
    """

    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
        super(FeedforwardModule, self).__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)

        # Balancer is a project-local activation regularizer constraining
        # the statistics of the hidden activations.
        self.hidden_balancer = Balancer(
            feedforward_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        # shared_dim=0 means we share the dropout mask along the time axis
        self.out_proj = ActivationDropoutAndLinear(
            feedforward_dim,
            embed_dim,
            activation="SwooshL",
            dropout_p=dropout,
            dropout_shared_dim=0,
            bias=True,
            initial_scale=0.1,
        )

        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor):
        """Apply the feedforward block; returns a tensor shaped like x."""
        x = self.in_proj(x)
        x = self.hidden_balancer(x)
        # out_proj contains SwooshL activation, then dropout, then linear.
        x = self.out_proj(x)
        x = self.out_whiten(x)
        return x
| 2064 |
+
|
| 2065 |
+
class NonlinAttention(nn.Module):
    """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
    from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
    one after the attention mechanism.

    The input projection produces three equal chunks (s, x, y):
    s gates x through a tanh, the attention weights mix x across time,
    and y gates the mixed result.

    Args:
        channels (int): The number of channels of conv layers.
        hidden_channels (int): internal channel dimension; in_proj outputs
            3 * hidden_channels.
    """

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
    ) -> None:
        super().__init__()

        self.hidden_channels = hidden_channels

        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)

        # Balancer applied to the gate chunk `s` before the tanh.
        # NOTE(review): an earlier comment here described a sigmoid gate with
        # min_abs=2.0; the current code uses tanh with min_abs=0.5 — the
        # rationale below may be stale.  Original rationale: well-trained
        # instances of this module had large abs-values before the gating
        # nonlinearity, poorly-trained ones smaller, so a min_abs constraint
        # helps.
        self.balancer = Balancer(
            hidden_channels,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
            min_abs=0.5,
            max_abs=5.0,
        )
        self.tanh = nn.Tanh()

        self.identity1 = Identity()  # for diagnostics.
        self.identity2 = Identity()  # for diagnostics.
        self.identity3 = Identity()  # for diagnostics.

        self.out_proj = ScaledLinear(
            hidden_channels, channels, bias=True, initial_scale=0.05
        )

        self.whiten1 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.whiten2 = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(5.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """.
        Args:
           x: a Tensor of shape (seq_len, batch_size, num_channels)
           attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        Returns:
           a Tensor with the same shape as x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        s, x, y = x.chunk(3, dim=-1)

        # s will go through tanh.

        s = self.balancer(s)
        s = self.tanh(s)

        # NOTE(review): unsqueeze(-1) followed by a reshape back to
        # (seq_len, batch_size, hidden_channels) is shape-wise a no-op.
        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = self.whiten1(x)
        x = x * s
        x = self.identity1(x)  # diagnostics only, it's the identity.

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        # Mix x across time with the (externally computed) attention weights.
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        # Second (linear) gate: multiply by y.
        y = self.identity2(y)
        x = x * y
        x = self.identity3(x)

        x = self.out_proj(x)
        x = self.whiten2(x)
        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_x: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """.
        Streaming variant of forward(); prepends cached left-context frames
        before the attention-weight multiplication.  The train-time-only
        layers (balancer, whiten, identity) are not applied here.

        Args:
           x: a Tensor of shape (seq_len, batch_size, num_channels)
           attn_weights: a Tensor of shape
               (num_heads, batch_size, seq_len, left_context_len + seq_len)
           cached_x: left context, a Tensor of shape
               (num_heads, batch_size, left_context_len, head_dim)
           left_context_len: number of left context frames.
        Returns:
           - a Tensor with the same shape as x
           - updated left context with same shape as cached_x
        """
        x = self.in_proj(x)

        (seq_len, batch_size, _) = x.shape
        hidden_channels = self.hidden_channels

        s, x, y = x.chunk(3, dim=-1)

        # s will go through tanh.
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
        x = x * s

        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (
            num_heads,
            batch_size,
            seq_len,
            left_context_len + seq_len,
        )

        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, head_dim)

        # Prepend cached left-context frames along the time axis.
        assert cached_x.shape[2] == left_context_len, (
            cached_x.shape[2],
            left_context_len,
        )
        x_pad = torch.cat([cached_x, x], dim=2)
        # Update cached tensor: keep the last left_context_len frames.
        cached_x = x_pad[:, :, -left_context_len:, :]

        x = torch.matmul(attn_weights, x_pad)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

        x = x * y

        x = self.out_proj(x)
        return x, cached_x
| 2229 |
+
|
| 2230 |
+
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Zipformer2 model.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py

    Structure: Linear (2x expansion) -> sigmoid gating -> depthwise conv
    (chunk-causal or 'SAME'-padded) -> SwooshR + Linear back to `channels`.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers (must be odd).
        causal (bool): if True, use ChunkCausalDepthwiseConv1d so the module
            supports chunked/streaming operation.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        causal: bool,
    ) -> None:
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        bottleneck_dim = channels
        self.causal = causal

        self.in_proj = nn.Linear(
            channels,
            2 * bottleneck_dim,
        )
        # the gradients on in_proj are a little noisy, likely to do with the
        # sigmoid in glu.

        # after in_proj we put x through a gated linear unit (nn.functional.glu).
        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
        # between 50 and 100 for different channels.  This will cause very peaky and
        # sparse derivatives for the sigmoid gating function, which will tend to make
        # the loss function not learn effectively.  (for most layers the average absolute values
        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
        # layers, which likely breaks down as 0.5 for the "linear" half and
        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
        # it will be in a better position to start learning something, i.e. to latch onto
        # the correct range.
        self.balancer1 = Balancer(
            bottleneck_dim,
            channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
            max_positive=1.0,
            min_abs=1.5,
            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
        )

        self.activation1 = Identity()  # for diagnostics

        self.sigmoid = nn.Sigmoid()

        self.activation2 = Identity()  # for diagnostics

        assert kernel_size % 2 == 1

        # Chunk-causal depthwise conv for streaming models; otherwise a
        # standard depthwise (groups == channels) Conv1d with SAME padding.
        self.depthwise_conv = (
            ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
            if causal
            else nn.Conv1d(
                in_channels=bottleneck_dim,
                out_channels=bottleneck_dim,
                groups=bottleneck_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )
        )

        self.balancer2 = Balancer(
            bottleneck_dim,
            channel_dim=1,
            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
            max_positive=1.0,
            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
            max_abs=10.0,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.out_proj = ActivationDropoutAndLinear(
            bottleneck_dim,
            channels,
            activation="SwooshR",
            dropout_p=0.0,
            initial_scale=0.05,
        )

    def forward(
        self,
        x: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
        chunk_size: int = -1,
    ) -> Tensor:
        """Compute convolution module.

        Args:
            x: Input tensor (#time, batch, channels).
           src_key_padding_mask: the mask for the src keys per batch (optional):
               (batch, #time), contains True in masked positions.
           chunk_size: if >= 0, run the (causal) depthwise conv in chunks of
               this size, simulating streaming decoding.

        Returns:
            Tensor: Output tensor (#time, batch, channels).

        """

        x = self.in_proj(x)  # (time, batch, 2*channels)

        # GLU-style gating: x is the linear half, s the gate half.
        x, s = x.chunk(2, dim=-1)
        s = self.balancer1(s)
        s = self.sigmoid(s)
        x = self.activation1(x)  # identity.
        x = x * s
        x = self.activation2(x)  # identity

        # (time, batch, channels)

        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        # Zero out padded frames so they do not leak through the conv.
        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        if (
            not torch.jit.is_scripting()
            and not torch.jit.is_tracing()
            and chunk_size >= 0
        ):
            # Not support exporting a model for simulated streaming decoding
            assert (
                self.causal
            ), "Must initialize model with causal=True if you use chunk_size"
            x = self.depthwise_conv(x, chunk_size=chunk_size)
        else:
            x = self.depthwise_conv(x)

        x = self.balancer2(x)
        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.whiten(x)  # (time, batch, channels)
        x = self.out_proj(x)  # (time, batch, channels)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cache: Tensor,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Compute convolution module in streaming forward mode.

        NOTE: the train-time-only layers (balancer1/2, whiten, identity)
        are intentionally skipped here.

        Args:
            x: Input tensor (#time, batch, channels).
            cache: cached left context for depthwise_conv of shape
              (#batch, channels, left_pad)
            src_key_padding_mask: the mask for the src keys per batch (optional):
              (batch, #time), contains True in masked positions.

        Returns:
            - Output tensor (#time, batch, channels).
            - Updated cache (#batch, channels, left_pad)
        """

        x = self.in_proj(x)  # (time, batch, 2*channels)

        x, s = x.chunk(2, dim=2)
        s = self.sigmoid(s)
        x = x * s
        # (time, batch, channels)

        # exchange the temporal dimension and the feature dimension
        x = x.permute(1, 2, 0)  # (#batch, channels, time).

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)

        x = x.permute(2, 0, 1)  # (time, batch, channels)

        x = self.out_proj(x)  # (time, batch, channels)

        return x, cache
| 2425 |
+
|
| 2426 |
+
class ScalarMultiply(nn.Module):
    """Multiplies its input elementwise by a fixed scalar constant."""

    def __init__(self, scale: float):
        super().__init__()
        # The constant factor applied in forward().
        self.scale = scale

    def forward(self, x):
        """Return x scaled by the stored constant."""
        scaled = self.scale * x
        return scaled
| 2434 |
+
|
| 2435 |
+
def _test_zipformer_main(causal: bool = False):
    """Smoke-test Zipformer2: run forward + backward in training mode, then a
    forward pass in eval mode, on random input.

    Args:
      causal: if True, build a causal (chunked, streaming-capable) model;
        otherwise a non-causal one.
    """
    # (Fix: these assignments and the comment below were previously
    # duplicated after the model construction.)
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.

    c = Zipformer2(
        encoder_dim=(64, 96),
        encoder_unmasked_dim=(48, 64),
        num_heads=(4, 4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,),
    )
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    # Check that gradients flow.
    f[0].sum().backward()
    c.eval()
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings
| 2463 |
+
|
| 2464 |
+
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    # Keep the test single-threaded for reproducible behavior.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    # Exercise both the non-causal and the causal configuration.
    for causal in (False, True):
        _test_zipformer_main(causal)