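"""Long-form speech recognition demo with a GigaAM-CTC checkpoint.

The script splits the input recording into speech chunks with a pyannote
voice-activity-detection pipeline, transcribes each chunk with the CTC model,
and prints one timestamped transcription per chunk.

Example invocation (the script name and file paths are placeholders; adjust
them to your setup, and make sure the Hugging Face token has access to the
"pyannote/voice-activity-detection" pipeline):

    python longform_inference.py \
        --model_config ctc_model_config.yaml \
        --model_weights ctc_model_weights.ckpt \
        --audio_path long_audio.wav \
        --hf_token <HF_TOKEN> \
        --device cuda
"""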
import argparse
from io import BytesIO
from typing import List, Tuple
import numpy as np
import torch
import torchaudio
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.asr.modules.audio_preprocessing import (
    AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
)
from nemo.collections.asr.parts.preprocessing.features import (
    FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
)
from pyannote.audio import Pipeline
from pydub import AudioSegment
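
# The two classes below override NeMo's mel-spectrogram preprocessing so that
# features are computed with torchaudio.transforms.MelSpectrogram (with a
# configurable mel scale) instead of NeMo's default filterbank featurizer, and
# so that deprecated window arguments coming from the config are dropped.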
class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
    def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
        if "window_size" in kwargs:
            del kwargs["window_size"]
        if "window_stride" in kwargs:
            del kwargs["window_stride"]

        super().__init__(**kwargs)

        self._mel_spec_extractor = torchaudio.transforms.MelSpectrogram(
            sample_rate=self._sample_rate,
            win_length=self.win_length,
            hop_length=self.hop_length,
            n_mels=kwargs["nfilt"],
            window_fn=self.torch_windows[kwargs["window"]],
            mel_scale=mel_scale,
            norm=kwargs["mel_norm"],
            n_fft=kwargs["n_fft"],
            f_max=kwargs.get("highfreq", None),
            f_min=kwargs.get("lowfreq", 0),
            wkwargs=wkwargs,
        )


class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
    def __init__(self, mel_scale: str = "htk", **kwargs):
        super().__init__(**kwargs)
        kwargs["nfilt"] = kwargs["features"]
        del kwargs["features"]
        self.featurizer = (
            FilterbankFeaturesTA(  # Deprecated arguments; kept for config compatibility
                mel_scale=mel_scale,
                **kwargs,
            )
        )
def audiosegment_to_numpy(audiosegment: AudioSegment) -> np.ndarray:
    """Convert AudioSegment to numpy array."""
    samples = np.array(audiosegment.get_array_of_samples())
    if audiosegment.channels == 2:
        samples = samples.reshape((-1, 2))

    samples = samples.astype(np.float32, order="C") / 32768.0
    return samples
def format_time(seconds: float) -> str:
    """Format a time in seconds as [HH:]MM:SS:hh, where hh is hundredths of a second."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    full_seconds = int(seconds)
    hundredths = int((seconds - full_seconds) * 100)
    if hours > 0:
        return f"{hours:02}:{minutes:02}:{full_seconds:02}:{hundredths:02}"
    else:
        return f"{minutes:02}:{full_seconds:02}:{hundredths:02}"
def segment_audio(
    audio_path: str,
    pipeline: Pipeline,
    max_duration: float = 22.0,
    min_duration: float = 15.0,
    new_chunk_threshold: float = 0.2,
) -> Tuple[List[np.ndarray], List[List[float]]]:
    """Split audio into speech chunks suitable for ASR, using a pyannote VAD pipeline."""
    # Prepare audio for the pyannote VAD pipeline
    audio = AudioSegment.from_wav(audio_path)
    audio_bytes = BytesIO()
    audio.export(audio_bytes, format="wav")
    audio_bytes.seek(0)

    # Run the pipeline to obtain segments with speech activity
    sad_segments = pipeline({"uri": "filename", "audio": audio_bytes})

    segments = []
    curr_duration = 0
    curr_start = 0
    curr_end = 0
    boundaries = []

    # Concatenate the VAD segments into chunks for ASR according to max/min duration
    for segment in sad_segments.get_timeline().support():
        start = max(0, segment.start)
        end = min(len(audio) / 1000, segment.end)
        # Start a new chunk when the current one is long enough and a pause follows,
        # or when extending it would exceed the maximum duration
        if (
            curr_duration > min_duration and start - curr_end > new_chunk_threshold
        ) or (curr_duration + (end - curr_end) > max_duration):
            audio_segment = audiosegment_to_numpy(
                audio[curr_start * 1000 : curr_end * 1000]
            )
            segments.append(audio_segment)
            boundaries.append([curr_start, curr_end])
            curr_start = start

        curr_end = end
        curr_duration = curr_end - curr_start

    # Flush the last chunk
    if curr_duration != 0:
        audio_segment = audiosegment_to_numpy(
            audio[curr_start * 1000 : curr_end * 1000]
        )
        segments.append(audio_segment)
        boundaries.append([curr_start, curr_end])

    return segments, boundaries
def _parse_args():
    parser = argparse.ArgumentParser(
        description="Run long-form inference using GigaAM-CTC checkpoint"
    )
    parser.add_argument("--model_config", help="Path to GigaAM-CTC config file (.yaml)")
    parser.add_argument(
        "--model_weights", help="Path to GigaAM-CTC checkpoint file (.ckpt)"
    )
    parser.add_argument("--audio_path", help="Path to audio signal")
    parser.add_argument(
        "--hf_token", help="HuggingFace token for using pyannote Pipeline"
    )
    parser.add_argument("--device", help="Device: cpu / cuda")
    parser.add_argument(
        "--fp16",
        help="Run in FP16 mode (pass 'False' to disable)",
        type=lambda x: str(x).lower() not in ("false", "0", "no"),
        default=True,
    )
    parser.add_argument(
        "--batch_size",
        help="Batch size for acoustic model inference",
        type=int,
        default=10,
    )
    return parser.parse_args()
def main(
    model_config: str,
    model_weights: str,
    device: str,
    audio_path: str,
    hf_token: str,
    fp16: bool,
    batch_size: int = 10,
):
    # Initialize the acoustic model
    model = EncDecCTCModel.from_config_file(model_config)
    ckpt = torch.load(model_weights, map_location="cpu")
    model.load_state_dict(ckpt, strict=False)
    model = model.to(device)
    if device != "cpu" and fp16:
        # Run the model in half precision but keep feature extraction in fp32
        model = model.half()
        model.preprocessor = model.preprocessor.float()
    model.eval()

    # Initialize the pyannote VAD pipeline
    pipeline = Pipeline.from_pretrained(
        "pyannote/voice-activity-detection", use_auth_token=hf_token
    )
    pipeline = pipeline.to(torch.device(device))

    # Segment audio into chunks
    segments, boundaries = segment_audio(audio_path, pipeline)

    # Transcribe the chunks
    if device != "cpu" and fp16:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            transcriptions = model.transcribe(segments, batch_size=batch_size)
    else:
        transcriptions = model.transcribe(segments, batch_size=batch_size)

    for transcription, boundary in zip(transcriptions, boundaries):
        print(
            f"[{format_time(boundary[0])} - {format_time(boundary[1])}]: {transcription}\n"
        )
if __name__ == "__main__":
    args = _parse_args()
    main(
        model_config=args.model_config,
        model_weights=args.model_weights,
        device=args.device,
        audio_path=args.audio_path,
        hf_token=args.hf_token,
        fp16=args.fp16,
        batch_size=args.batch_size,
    )
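# The printed output looks roughly as follows (timestamps and text are
# illustrative only; chunks shorter than an hour are formatted as MM:SS:hh):
#   [00:00:00 - 00:19:50]: first recognized utterance
#   [00:19:75 - 00:43:25]: next recognized utterance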