marcoyang commited on
Commit
0bdcf6a
·
1 Parent(s): a7df827

add real chunk-wise streaming inference code

Browse files
inference_600m_streaming_forward.py CHANGED
@@ -1,14 +1,14 @@
1
  import argparse
2
  import math
3
- from typing import Dict, List, Optional, Tuple
4
 
5
  from model import MultiKDModel
6
  from scaling import ScheduledFloat
7
  from subsampling import Conv2dSubsampling
8
  from zipformer import Zipformer2
9
 
10
- from lhotse import Fbank, FbankConfig
11
  import torchaudio
 
12
  import torch
13
  from torch import Tensor
14
  import torch.nn as nn
@@ -311,6 +311,44 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
311
 
312
  return state_list
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  def streaming_forward(
315
  features: Tensor,
316
  feature_lens: Tensor,
@@ -318,7 +356,7 @@ def streaming_forward(
318
  states: List[Tensor],
319
  chunk_size: int,
320
  left_context_len: int,
321
- ) -> Tuple[Tensor, Tensor, List[Tensor]]:
322
  """
323
  Returns encoder outputs, output lengths, and updated states.
324
  """
@@ -351,6 +389,7 @@ def streaming_forward(
351
  encoder_out,
352
  encoder_out_lens,
353
  new_encoder_states,
 
354
  ) = model.encoder.streaming_forward(
355
  x=x,
356
  x_lens=x_lens,
@@ -358,12 +397,13 @@ def streaming_forward(
358
  src_key_padding_mask=src_key_padding_mask,
359
  )
360
  encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
 
361
 
362
  new_states = new_encoder_states + [
363
  new_cached_embed_left_pad,
364
  new_processed_lens,
365
  ]
366
- return encoder_out, encoder_out_lens, new_states
367
 
368
  def chunk_forward(
369
  audio: torch.Tensor,
@@ -373,36 +413,47 @@ def chunk_forward(
373
  left_context_frames: int = 256,
374
  ):
375
  # Perform chunk by chunk forward for the encoder. Each chunk is conditioned on the current chunk and left context (maintained by the states)
376
- # At each step, we take a chunk of audio and forward the encoder
377
- # For the first chunk, we wait until the accumulated audio duration to reach (buffer + chunk_duration), the buffer
378
- # is necessary for the convolution subsampling module in the encoder.
 
 
 
 
 
379
  # After the first chunk, we perform normal chunk-by-chunk inference when the accumulated audio reaches chunk_duration
 
380
  # An example of Buffer=2 frames, chunk=5 frames, the latency for the first chunk is 7 frames (as we need to accumulate 7 frames
381
  # for decoding), the rest chunks have latency of 5 frames.
382
- # Each time we feed (5 + 2) frames to the encoder, and then shift 5 frames
383
  # Chunk 1: AAAAAAA
384
  # Chunk 2: AAAAAAA
385
  # Chunk 3: AAAAAAA
386
 
387
- # NOTE: params.chunk_size is the chunk_size regarding to the input of the zipformer encoder, so at fbank level, the chunk size
388
- # is 2 * params.chunk_size
389
-
390
- # fbank extractor
391
- extractor = Fbank(FbankConfig(num_mel_bins=feature_dim))
392
 
393
  device = next(model.parameters()).device
394
 
395
  chunk_size = int(chunk_size)
396
  chunk_size_samples = int(chunk_size * 2 * 160) # chunk size represented in audio samples of 16kHz sampling rate
397
  left_context_len = int(left_context_frames)
398
- pad_length = 7 + 2 * 3 # buffer required by encoder_embed module (i.e. convolution subsampling)
399
- pad_length_samples = (7 + 2 * 3) * 160
400
 
401
- # intialize states, to be maintained during chunk-wise forward
402
- initial_states = get_init_states(model=model, batch_size=1, device=device)
 
 
 
 
 
 
403
 
404
- # start forward chunk by chunk
 
 
 
405
  encoder_outs = []
 
406
  encoder_out_lens = 0
407
  states = initial_states
408
 
@@ -411,21 +462,24 @@ def chunk_forward(
411
 
412
  # the actual loop performing the chunk-wise inference of the encoder
413
  while True:
414
- # prepare the input for processing current chunk
415
- # compute fbank for the current chunk
416
- audio_chunk = audio[:, num_processed_samples: num_processed_samples + (chunk_size_samples + pad_length_samples)]
417
- features = extractor.extract(audio_chunk, sampling_rate=16000)
 
 
 
 
 
 
418
  features = features.to(device)
419
  feature_lens = features.shape[0]
420
-
421
- feature_lens = torch.tensor([feature_lens], device=device) # shape: (1)
422
- features = features.unsqueeze(0) # shape: (1,T,num_mels)
423
 
424
  # the audio chunk could be shorter than the expected length, for example in the last two chunks
425
- # pad the chunk so that the input shape is (chunk_size + buffer)
426
- tail_length = chunk_size * 2 + 7 + 2 * 3 # each prepared chunk should have this length
427
- if features.size(1) < tail_length:
428
- pad_length = tail_length - features.size(1)
429
  feature_lens += pad_length
430
  features = torch.nn.functional.pad(
431
  features,
@@ -437,7 +491,7 @@ def chunk_forward(
437
  states = stack_states([states])
438
 
439
  # forward current chunk in batch=1
440
- encoder_out, encoder_out_len, new_states = streaming_forward(
441
  features=features,
442
  feature_lens=feature_lens,
443
  model=model,
@@ -447,22 +501,26 @@ def chunk_forward(
447
  )
448
 
449
  encoder_outs.append(encoder_out)
 
450
  encoder_out_lens += encoder_out_len
451
 
452
  # update the states
453
  states = unstack_states(new_states)[0]
454
 
455
  num_chunk += 1
456
- num_processed_samples += chunk_size_samples
457
 
458
  if num_processed_samples > audio.shape[1]:
459
  print(f"Audio is exhausted.")
460
  break
461
 
462
  encoder_outs = torch.cat(encoder_outs, dim=1) # shape: (1,T,C)
463
-
464
- return encoder_outs, encoder_out_lens
465
-
 
 
 
466
 
467
 
468
  def main(args):
@@ -484,18 +542,19 @@ def main(args):
484
  audio, fs = torchaudio.load(args.audio)
485
  assert fs == 16000
486
 
487
- encoder_out, encoder_out_lens = chunk_forward(
488
  audio=audio, # shape (1, num_samples)
489
  model=model,
490
  feature_dim=128,
491
  chunk_size=args.chunk_size,
492
  left_context_frames=args.left_context_frames,
493
  )
494
-
495
 
496
  print(encoder_out)
497
  print(encoder_out.shape)
498
- # torch.save(encoder_out, "streaming_forward_encoder_out_no_k2.pt")
 
 
499
 
500
  if __name__=="__main__":
501
  parser = get_parser()
 
1
  import argparse
2
  import math
3
+ from typing import List, Tuple
4
 
5
  from model import MultiKDModel
6
  from scaling import ScheduledFloat
7
  from subsampling import Conv2dSubsampling
8
  from zipformer import Zipformer2
9
 
 
10
  import torchaudio
11
+ from torchaudio.compliance.kaldi import fbank
12
  import torch
13
  from torch import Tensor
14
  import torch.nn as nn
 
311
 
312
  return state_list
313
 
314
def compute_fbank(
    wavs: torch.Tensor, wav_lens: torch.Tensor
):
    """Compute Kaldi-compatible log-mel fbank features for a batch of waveforms.

    Args:
        wavs (torch.Tensor): the mono-channel input waveforms, shape (N, T)
        wav_lens (torch.Tensor): the valid length of each waveform in samples, shape (N)

    Returns:
        A tuple of (features, feat_len): the fbank features padded to a common
        frame count, shape (N, T', num_mel_bins), and the per-utterance frame
        counts, shape (N).
    """
    assert wavs.ndim == 2, wavs.shape

    # One fbank call per utterance; only the valid samples of each waveform
    # are fed to the extractor.
    feats = [
        fbank(
            wav[: wav_lens[i]].unsqueeze(0),
            sample_frequency=16000,  # this is fixed to 16000
            num_mel_bins=128,
            low_freq=20.0,
            snip_edges=False,
            high_freq=-400.0,  # negative: offset below the Nyquist frequency
            dither=0.0,
            energy_floor=1.0e-10,
        )
        for i, wav in enumerate(wavs)
    ]

    feat_len = torch.tensor([f.shape[0] for f in feats]).to(wavs.device)
    # Pad shorter utterances with LOG_EPS (log of a tiny energy) so the batch
    # is rectangular.
    features = torch.nn.utils.rnn.pad_sequence(
        feats, batch_first=True, padding_value=LOG_EPS
    ).to(wavs.device)
    return features, feat_len
350
+
351
+
352
  def streaming_forward(
353
  features: Tensor,
354
  feature_lens: Tensor,
 
356
  states: List[Tensor],
357
  chunk_size: int,
358
  left_context_len: int,
359
+ ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]:
360
  """
361
  Returns encoder outputs, output lengths, updated states, and the intermediate (middle) layer outputs.
362
  """
 
389
  encoder_out,
390
  encoder_out_lens,
391
  new_encoder_states,
392
+ middle_outs,
393
  ) = model.encoder.streaming_forward(
394
  x=x,
395
  x_lens=x_lens,
 
397
  src_key_padding_mask=src_key_padding_mask,
398
  )
399
  encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
400
+ middle_outs = [m.permute(1, 0, 2) for m in middle_outs] # (T, N, C) ->(N, T, C)
401
 
402
  new_states = new_encoder_states + [
403
  new_cached_embed_left_pad,
404
  new_processed_lens,
405
  ]
406
+ return encoder_out, encoder_out_lens, new_states, middle_outs
407
 
408
  def chunk_forward(
409
  audio: torch.Tensor,
 
413
  left_context_frames: int = 256,
414
  ):
415
  # Perform chunk by chunk forward for the encoder. Each chunk is conditioned on the current chunk and left context (maintained by the states)
416
+ # At each step, we take a chunk of audio and forward the encoder.
417
+ # For the first chunk, we wait until the accumulated audio duration reaches (chunk_duration + buffer); the buffer
418
+ # is necessary for the convolution subsampling modules in the encoder to produce accurate output.
419
+
420
+ # The buffer consists of two parts:
421
+ # 1. Some trailing fbank frames, covered by the convolution kernels in the encoder_embed
422
+ # 2. Some extra tolerance frames, to make the last fbank frame precise (the tolerance fbank frames will be removed)
423
+
424
  # After the first chunk, we perform normal chunk-by-chunk inference when the accumulated audio reaches chunk_duration
425
+
426
  # An example of Buffer=2 frames, chunk=5 frames, the latency for the first chunk is 7 frames (as we need to accumulate 7 frames
427
  # for decoding), the rest chunks have latency of 5 frames.
428
+ # Each time we feed (5 + 2) frames to the encoder, and then shift 5 frames
429
  # Chunk 1: AAAAAAA
430
  # Chunk 2: AAAAAAA
431
  # Chunk 3: AAAAAAA
432
 
433
+ # NOTE: chunk_size is the chunk_size regarding to the input of the zipformer encoder, so at fbank level, the chunk size
434
+ # is 2 * chunk_size
 
 
 
435
 
436
  device = next(model.parameters()).device
437
 
438
  chunk_size = int(chunk_size)
439
  chunk_size_samples = int(chunk_size * 2 * 160) # chunk size represented in audio samples of 16kHz sampling rate
440
  left_context_len = int(left_context_frames)
 
 
441
 
442
+ # Buffer-related
443
+ # 1. extra frames required by encoder_embed module (i.e. convolution subsampling)
444
+ pad_length = 7 + 2 * 3 #
445
+ pad_length_samples = (7 + 2 * 3) * 160 # in samples
446
+
447
+ extra_tolerance = 0.01 # 10 ms
448
+ extra_tolerance_samples = int(extra_tolerance * 16000)
449
+ buffer_samples = pad_length_samples + extra_tolerance_samples
450
 
451
+ chunk_size_with_pad = chunk_size * 2 + 7 + 2 * 3 # This is the total number of fbank frames we need to compute for each chunk forward
452
+
453
+ # initializations, to be maintained during chunk-wise forward
454
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
455
  encoder_outs = []
456
+ middle_outs = []
457
  encoder_out_lens = 0
458
  states = initial_states
459
 
 
462
 
463
  # the actual loop performing the chunk-wise inference of the encoder
464
  while True:
465
+ # Get the audio samples
466
+ audio_chunk = audio[
467
+ :,
468
+ num_processed_samples: num_processed_samples + (chunk_size_samples + buffer_samples)
469
+ ] # (1, num_samples)
470
+
471
+ # compute the fbank features for the current chunk
472
+ features, _ = compute_fbank(audio_chunk, torch.tensor([audio_chunk.shape[-1]])) # shape: (T, num_mels)
473
+
474
+ features = features[:, :chunk_size_with_pad, :] # only keep the required fbank frames for current chunk
475
  features = features.to(device)
476
  feature_lens = features.shape[0]
477
+ feature_lens = torch.tensor([features.shape[1]], device=device) # shape: (1)
 
 
478
 
479
  # the audio chunk could be shorter than the expected length, for example in the last two chunks
480
+ # so we need to pad the chunk to the expected length
481
+ if features.size(1) < chunk_size_with_pad:
482
+ pad_length = chunk_size_with_pad - features.size(1)
 
483
  feature_lens += pad_length
484
  features = torch.nn.functional.pad(
485
  features,
 
491
  states = stack_states([states])
492
 
493
  # forward current chunk in batch=1
494
+ encoder_out, encoder_out_len, new_states, middle_out = streaming_forward(
495
  features=features,
496
  feature_lens=feature_lens,
497
  model=model,
 
501
  )
502
 
503
  encoder_outs.append(encoder_out)
504
+ middle_outs.append(middle_out)
505
  encoder_out_lens += encoder_out_len
506
 
507
  # update the states
508
  states = unstack_states(new_states)[0]
509
 
510
  num_chunk += 1
511
+ num_processed_samples += chunk_size_samples # move one chunk forward
512
 
513
  if num_processed_samples > audio.shape[1]:
514
  print(f"Audio is exhausted.")
515
  break
516
 
517
  encoder_outs = torch.cat(encoder_outs, dim=1) # shape: (1,T,C)
518
+ layerwise_outs = []
519
+ for i in range(len(middle_outs[0])): # for each intermediate layer
520
+ layerwise_outs.append(torch.cat([m[i] for m in middle_outs], dim=1)) # shape: (1,T,C)
521
+
522
+ return encoder_outs, encoder_out_lens, layerwise_outs
523
+
524
 
525
 
526
  def main(args):
 
542
  audio, fs = torchaudio.load(args.audio)
543
  assert fs == 16000
544
 
545
+ encoder_out, encoder_out_lens, intermediate_hidden_states = chunk_forward(
546
  audio=audio, # shape (1, num_samples)
547
  model=model,
548
  feature_dim=128,
549
  chunk_size=args.chunk_size,
550
  left_context_frames=args.left_context_frames,
551
  )
 
552
 
553
  print(encoder_out)
554
  print(encoder_out.shape)
555
+ print(intermediate_hidden_states[-1])
556
+ print(intermediate_hidden_states[-1].shape)
557
+
558
 
559
  if __name__=="__main__":
560
  parser = get_parser()
inference_600m_streaming_forward.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run chunk-wise streaming inference with the 600M Zipformer encoder.

# Fail fast: abort on any command error, on use of an unset variable,
# and propagate failures through pipelines.
set -euo pipefail

model_version=600m_uniform_out_ds1
causal=1                  # use the causal (streaming) model variant
left_context_frames=128   # left context, in encoder frames
chunk_size=8              # chunk size, in encoder frames

python inference_600m_streaming_forward.py \
  --model-version "$model_version" \
  --ckpt-path v0.2/iter-500000-avg-4.pt \
  --causal "$causal" \
  --left-context-frames "$left_context_frames" \
  --chunk-size "$chunk_size" \
  --audio 1284-1180-0027.flac
zipformer.py CHANGED
@@ -434,6 +434,7 @@ class Zipformer2(nn.Module):
434
  x_lens: Tensor,
435
  states: List[Tensor],
436
  src_key_padding_mask: Tensor,
 
437
  ) -> Tuple[Tensor, Tensor, List[Tensor]]:
438
  """
439
  Args:
@@ -456,6 +457,7 @@ class Zipformer2(nn.Module):
456
  - updated states
457
  """
458
  outputs = []
 
459
  new_states = []
460
  layer_offset = 0
461
 
@@ -464,14 +466,16 @@ class Zipformer2(nn.Module):
464
  ds = self.downsampling_factor[i]
465
  x = convert_num_channels(x, self.encoder_dim[i])
466
 
467
- x, new_layer_states = module.streaming_forward(
468
  x,
469
  states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
470
  left_context_len=self.left_context_frames[0] // ds,
471
  src_key_padding_mask=src_key_padding_mask[..., ::ds],
 
472
  )
473
  layer_offset += num_layers
474
  outputs.append(x)
 
475
  new_states += new_layer_states
476
 
477
  # if the last output has the largest dimension, x will be unchanged,
@@ -479,17 +483,20 @@ class Zipformer2(nn.Module):
479
  # from different pieces of 'outputs', taking each dimension from the
480
  # most recent output that has it present.
481
  x = self._get_full_dim_output(outputs)
482
- x = self.downsample_output(x)
483
- # class Downsample has this rounding behavior..
484
- assert self.output_downsampling_factor == 2
485
- if torch.jit.is_scripting() or torch.jit.is_tracing():
486
- lengths = (x_lens + 1) // 2
487
- else:
488
- with warnings.catch_warnings():
489
- warnings.simplefilter("ignore")
490
  lengths = (x_lens + 1) // 2
 
 
 
 
 
 
491
 
492
- return x, lengths, new_states
493
 
494
  @torch.jit.export
495
  def get_init_states(
 
434
  x_lens: Tensor,
435
  states: List[Tensor],
436
  src_key_padding_mask: Tensor,
437
+ return_middle_out: bool = True,
438
  ) -> Tuple[Tensor, Tensor, List[Tensor]]:
439
  """
440
  Args:
 
457
  - updated states
458
  """
459
  outputs = []
460
+ middle_outputs = []
461
  new_states = []
462
  layer_offset = 0
463
 
 
466
  ds = self.downsampling_factor[i]
467
  x = convert_num_channels(x, self.encoder_dim[i])
468
 
469
+ x, new_layer_states, cur_middle_out = module.streaming_forward(
470
  x,
471
  states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
472
  left_context_len=self.left_context_frames[0] // ds,
473
  src_key_padding_mask=src_key_padding_mask[..., ::ds],
474
+ return_middle_out=return_middle_out,
475
  )
476
  layer_offset += num_layers
477
  outputs.append(x)
478
+ middle_outputs += cur_middle_out
479
  new_states += new_layer_states
480
 
481
  # if the last output has the largest dimension, x will be unchanged,
 
483
  # from different pieces of 'outputs', taking each dimension from the
484
  # most recent output that has it present.
485
  x = self._get_full_dim_output(outputs)
486
+ if self.output_downsampling_factor >= 2:
487
+ x = self.downsample_output(x)
488
+ # class Downsample has this rounding behavior..
489
+ assert self.output_downsampling_factor == 2
490
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
 
 
 
491
  lengths = (x_lens + 1) // 2
492
+ else:
493
+ with warnings.catch_warnings():
494
+ warnings.simplefilter("ignore")
495
+ lengths = (x_lens + 1) // 2
496
+ else:
497
+ lengths = x_lens
498
 
499
+ return x, lengths, new_states, middle_outputs
500
 
501
  @torch.jit.export
502
  def get_init_states(