Davidsamuel101 committed
Commit 818c18e · 1 Parent(s): 249f9a4

Add test wav file and jit script with LFS support

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
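For context, the added rule is exactly what `git lfs track` writes: it routes every `*.wav` through the LFS filter, so the audio file added below is committed as a small pointer rather than the raw blob. A minimal sketch of the equivalent commands (assuming git-lfs is installed for this repo):

    git lfs track "*.wav"    # appends the rule above to .gitattributes
    git add .gitattributes test_waves/sample_1.wav
    git commit -m "Add test wav file and jit script with LFS support"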
jit_pretrained_streaming.py ADDED
@@ -0,0 +1,271 @@
+ #!/usr/bin/env python3
+ # flake8: noqa
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
+ #
+ # See ../../../../LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ This script loads torchscript models exported by `torch.jit.script()`
+ and uses them to decode waves.
+ You can use the following command to get the exported models:
+
+ ./zipformer/export.py \
+   --exp-dir ./zipformer/exp \
+   --causal 1 \
+   --chunk-size 16 \
+   --left-context-frames 128 \
+   --tokens data/lang_bpe_500/tokens.txt \
+   --epoch 30 \
+   --avg 9 \
+   --jit 1
+
+ Usage of this script:
+
+ ./zipformer/jit_pretrained_streaming.py \
+   --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \
+   --tokens ./data/lang_bpe_500/tokens.txt \
+   /path/to/foo.wav
+ """
+
+ import argparse
+ import logging
+ from typing import List, Optional
+
+ import k2
+ import torch
+ import torchaudio
+ from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument(
+         "--nn-model-filename",
+         type=str,
+         required=True,
+         help="Path to the torchscript model jit_script.pt",
+     )
+
+     parser.add_argument(
+         "--tokens",
+         type=str,
+         help="""Path to tokens.txt.""",
+     )
+
+     parser.add_argument(
+         "--sample-rate",
+         type=int,
+         default=16000,
+         help="The sample rate of the input sound file",
+     )
+
+     parser.add_argument(
+         "sound_file",
+         type=str,
+         help="The input sound file to transcribe. "
+         "Supported formats are those supported by torchaudio.load(). "
+         "For example, wav and flac are supported. "
+         "The sample rate has to be 16kHz.",
+     )
+
+     return parser
+
+
+ def read_sound_files(
+     filenames: List[str], expected_sample_rate: float
+ ) -> List[torch.Tensor]:
+     """Read a list of sound files into a list of 1-D float32 torch tensors.
+     Args:
+       filenames:
+         A list of sound filenames.
+       expected_sample_rate:
+         The expected sample rate of the sound files.
+     Returns:
+       Return a list of 1-D float32 torch tensors.
+     """
+     ans = []
+     for f in filenames:
+         wave, sample_rate = torchaudio.load(f)
+         assert (
+             sample_rate == expected_sample_rate
+         ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+         # We use only the first channel
+         ans.append(wave[0])
+     return ans
+
+
+ def greedy_search(
+     decoder: torch.jit.ScriptModule,
+     joiner: torch.jit.ScriptModule,
+     encoder_out: torch.Tensor,
+     decoder_out: Optional[torch.Tensor] = None,
+     hyp: Optional[List[int]] = None,
+     device: torch.device = torch.device("cpu"),
+ ):
+     """Greedy-search one chunk of encoder output, carrying the decoder
+     state (hyp, decoder_out) over from the previous chunk."""
+     assert encoder_out.ndim == 2
+     context_size = decoder.context_size
+     blank_id = decoder.blank_id
+
+     if decoder_out is None:
+         # First chunk: start from a hypothesis of `context_size` blanks.
+         assert hyp is None, hyp
+         hyp = [blank_id] * context_size
+         decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
+         # decoder_input.shape: (1, context_size)
+         decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
+     else:
+         assert decoder_out.ndim == 2
+         assert hyp is not None, hyp
+
+     T = encoder_out.size(0)
+     for i in range(T):
+         cur_encoder_out = encoder_out[i : i + 1]
+         joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
+         y = joiner_out.argmax(dim=0).item()
+
+         if y != blank_id:
+             # Emit a non-blank token and refresh the decoder output.
+             hyp.append(y)
+             decoder_input = hyp[-context_size:]
+
+             decoder_input = torch.tensor(
+                 decoder_input, dtype=torch.int32, device=device
+             ).unsqueeze(0)
+             decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
+
+     return hyp, decoder_out
+
+
+ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
+     """Create a CPU streaming feature extractor.
+
+     At present, we assume it returns a fbank feature extractor with
+     fixed options. In the future, we will support passing in the options
+     from outside.
+
+     Returns:
+       Return a CPU streaming feature extractor.
+     """
+     opts = FbankOptions()
+     opts.device = "cpu"
+     opts.frame_opts.dither = 0
+     opts.frame_opts.snip_edges = False
+     opts.frame_opts.samp_freq = sample_rate
+     opts.mel_opts.num_bins = 80
+     opts.mel_opts.high_freq = -400
+     return OnlineFbank(opts)
+
+
+ @torch.no_grad()
+ def main():
+     parser = get_parser()
+     args = parser.parse_args()
+     logging.info(vars(args))
+
+     device = torch.device("cpu")
+     if torch.cuda.is_available():
+         device = torch.device("cuda", 0)
+
+     logging.info(f"device: {device}")
+
+     model = torch.jit.load(args.nn_model_filename)
+     model.eval()
+     model.to(device)
+
+     encoder = model.encoder
+     decoder = model.decoder
+     joiner = model.joiner
+
+     token_table = k2.SymbolTable.from_file(args.tokens)
+     context_size = decoder.context_size
+
+     logging.info("Constructing Fbank computer")
+     online_fbank = create_streaming_feature_extractor(args.sample_rate)
+
+     logging.info(f"Reading sound files: {args.sound_file}")
+     wave_samples = read_sound_files(
+         filenames=[args.sound_file],
+         expected_sample_rate=args.sample_rate,
+     )[0]
+     logging.info(wave_samples.shape)
+
+     logging.info("Decoding started")
+
+     chunk_length = encoder.chunk_size * 2
+     T = chunk_length + encoder.pad_length
+
+     logging.info(f"chunk_length: {chunk_length}")
+     logging.info(f"T: {T}")
+
+     states = encoder.get_init_states(device=device)
+
+     tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
+
+     wave_samples = torch.cat([wave_samples, tail_padding])
+
+     chunk = int(0.25 * args.sample_rate)  # 0.25 seconds
+     num_processed_frames = 0
+
+     hyp = None
+     decoder_out = None
+
+     start = 0
+     while start < wave_samples.numel():
+         logging.info(f"{start}/{wave_samples.numel()}")
+         end = min(start + chunk, wave_samples.numel())
+         samples = wave_samples[start:end]
+         start += chunk
+         online_fbank.accept_waveform(
+             sampling_rate=args.sample_rate,
+             waveform=samples,
+         )
+         while online_fbank.num_frames_ready - num_processed_frames >= T:
+             frames = []
+             for i in range(T):
+                 frames.append(online_fbank.get_frame(num_processed_frames + i))
+             frames = torch.cat(frames, dim=0).to(device).unsqueeze(0)
+             x_lens = torch.tensor([T], dtype=torch.int32, device=device)
+             encoder_out, out_lens, states = encoder(
+                 features=frames,
+                 feature_lengths=x_lens,
+                 states=states,
+             )
+             num_processed_frames += chunk_length
+
+             hyp, decoder_out = greedy_search(
+                 decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
+             )
+
+     text = ""
+     for i in hyp[context_size:]:
+         text += token_table[i]
+     text = text.replace("▁", " ").strip()
+
+     logging.info(args.sound_file)
+     logging.info(text)
+
+     logging.info("Decoding Done")
+
+
+ torch.set_num_threads(4)
+ torch.set_num_interop_threads(1)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._set_graph_executor_optimize(False)
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+     logging.basicConfig(format=formatter, level=logging.INFO)
+     main()
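With the script and the test clip from this commit in place, a run might look like the sketch below. The checkpoint and tokens paths are assumptions taken from the script's docstring; neither file is part of this commit:

    python3 jit_pretrained_streaming.py \
      --nn-model-filename ./exp-causal/jit_script_chunk_16_left_128.pt \
      --tokens ./data/lang_bpe_500/tokens.txt \
      test_waves/sample_1.wav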
test_waves/sample_1.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5fbc9032780bc73e4a396f21806138a44f520fcfe07cbcae4f13ca48f44d0198
+ size 229420
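The three lines above are a Git LFS pointer, not the audio itself: the actual 229,420-byte wav lives in LFS storage, keyed by the sha256 oid. In a fresh clone without LFS content, a sketch of how to materialize just this file:

    git lfs install
    git lfs pull --include "test_waves/sample_1.wav"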