|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
from dataclasses import dataclass, is_dataclass |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from omegaconf import OmegaConf |
|
|
from utils.data_prep import ( |
|
|
get_batch_starts_ends, |
|
|
get_batch_tensors_and_boundary_info, |
|
|
get_manifest_lines_batch, |
|
|
is_entry_in_all_lines, |
|
|
is_entry_in_any_lines, |
|
|
) |
|
|
from utils.make_output_files import make_ctm, make_new_manifest |
|
|
from utils.viterbi_decoding import viterbi_decoding |
|
|
|
|
|
from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
|
|
from nemo.collections.asr.parts.utils.transcribe_utils import setup_model |
|
|
from nemo.core.config import hydra_runner |
|
|
from nemo.utils import logging |
|
|
|
|
|
|
|
|
""" |
|
|
Align the utterances in manifest_filepath. |
|
|
Results are saved in ctm files in output_dir. |
|
|
|
|
|
Arguments: |
|
|
pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded |
|
|
from NGC and used for generating the log-probs which we will use to do alignment. |
|
|
Note: NFA can only use CTC models (not Transducer models) at the moment. |
|
|
model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the |
|
|
log-probs which we will use to do alignment. |
|
|
Note: NFA can only use CTC models (not Transducer models) at the moment. |
|
|
Note: if a model_path is provided, it will override the pretrained_name. |
|
|
model_downsample_factor: an int indicating the downsample factor of the ASR model, ie the ratio of input |
|
|
timesteps to output timesteps. |
|
|
If the ASR model is a QuartzNet model, its downsample factor is 2. |
|
|
If the ASR model is a Conformer CTC model, its downsample factor is 4. |
|
|
        If the ASR model is a Citrinet model, its downsample factor is 8.
|
|
manifest_filepath: filepath to the manifest of the data you want to align, |
|
|
containing 'audio_filepath' and 'text' fields. |
|
|
output_dir: the folder where output CTM files and new JSON manifest will be saved. |
|
|
align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription |
|
|
as the 'ground truth' for the forced alignment. |
|
|
transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing"). |
|
|
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available |
|
|
(otherwise will set it to 'cpu'). |
|
|
viterbi_device: None, or string specifying the device that will be used for doing Viterbi decoding. |
|
|
The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available |
|
|
(otherwise will set it to 'cpu'). |
|
|
batch_size: int specifying batch size that will be used for generating log-probs and doing Viterbi decoding. |
|
|
additional_ctm_grouping_separator: the string used to separate CTM segments if you want to obtain CTM files at a |
|
|
level that is not the token level or the word level. NFA will always produce token-level and word-level CTM |
|
|
files in: `<output_dir>/tokens/<utt_id>.ctm` and `<output_dir>/words/<utt_id>.ctm`. |
|
|
If `additional_ctm_grouping_separator` is specified, an additional folder |
|
|
`<output_dir>/{tokens/words/additional_segments}/<utt_id>.ctm` will be created containing CTMs |
|
|
        for `additional_ctm_grouping_separator`-separated segments.
|
|
remove_blank_tokens_from_ctm: a boolean denoting whether to remove <blank> tokens from token-level output CTMs. |
|
|
audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath |
|
|
we will use (starting from the final part of the audio_filepath) to determine the |
|
|
utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath |
|
|
will be replaced with dashes, so as not to change the number of space-separated elements in the |
|
|
CTM files. |
|
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1" |
|
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1" |
|
|
e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1" |
|
|
minimum_timestamp_duration: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any |
|
|
line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the |
|
|
middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio |
|
|
file. Note that this may cause timestamps to overlap. |
|
|
""" |
|
|
|
|
|
|
|
|
@dataclass
class AlignmentConfig:
    # Config schema for the NFA alignment script; consumed by hydra_runner in main().
    # Exactly one of `pretrained_name` / `model_path` must be set (validated in main()).

    # Name of a CTC NeMo ASR model to download from NGC for generating log-probs.
    pretrained_name: Optional[str] = None
    # Local filepath to a CTC NeMo ASR model; overrides `pretrained_name` if given.
    model_path: Optional[str] = None
    # Ratio of input timesteps to output timesteps of the ASR model
    # (e.g. QuartzNet: 2, Conformer CTC: 4, Citrinet: 8). Required.
    model_downsample_factor: Optional[int] = None
    # Manifest of the data to align; lines must contain 'audio_filepath' (and 'text'
    # unless align_using_pred_text is True). Required.
    manifest_filepath: Optional[str] = None
    # Folder where output CTM files and the new JSON manifest are saved. Required.
    output_dir: Optional[str] = None

    # If True, transcribe the audio and align against the predicted text instead of
    # the manifest's 'text' field.
    align_using_pred_text: bool = False
    # torch.device() string for the transcription step; None => 'cuda' if available else 'cpu'.
    transcribe_device: Optional[str] = None
    # torch.device() string for Viterbi decoding; None => 'cuda' if available else 'cpu'.
    viterbi_device: Optional[str] = None
    # Batch size used both for generating log-probs and for Viterbi decoding.
    batch_size: int = 1
    # Optional separator for an extra segment-level CTM output; must not be "" or " ".
    additional_ctm_grouping_separator: Optional[str] = None
    # Whether to drop <blank> tokens from the token-level CTM output.
    remove_blank_tokens_from_ctm: bool = False
    # Minimum duration (seconds) for CTM timestamps; shorter entries are enlarged
    # from the middle outwards (may cause overlaps).
    minimum_timestamp_duration: float = 0
    # How many trailing path components of audio_filepath form the utt_id in CTMs.
    audio_filepath_parts_in_utt_id: int = 1
|
|
|
|
|
|
|
|
@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
def main(cfg: AlignmentConfig):
    """
    Align the utterances in cfg.manifest_filepath with a CTC NeMo ASR model.

    Validates the config and manifest, loads the model onto the transcription
    device, then processes the manifest in batches: obtains log-probs and
    token/word/segment boundary info, runs Viterbi decoding, and writes
    token-level, word-level, and (optionally) additional-segment-level CTM
    files plus a new JSON manifest into cfg.output_dir. See the module
    docstring for the meaning of every config field.

    Raises:
        ValueError: if required config fields are missing or invalid.
        RuntimeError: if the manifest is missing required entries.
        NotImplementedError: if the loaded model is not an EncDecCTCModel.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    # --- Validate config ---
    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None")

    if cfg.model_path is not None and cfg.pretrained_name is not None:
        raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None")

    if cfg.model_downsample_factor is None:
        raise ValueError("cfg.model_downsample_factor must be specified")

    if cfg.manifest_filepath is None:
        raise ValueError("cfg.manifest_filepath must be specified")

    if cfg.output_dir is None:
        raise ValueError("cfg.output_dir must be specified")

    if cfg.batch_size < 1:
        raise ValueError("cfg.batch_size cannot be zero or a negative number")

    if cfg.additional_ctm_grouping_separator == "" or cfg.additional_ctm_grouping_separator == " ":
        # BUGFIX: message previously referred to a non-existent field name
        # 'cfg.additional_grouping_separator'.
        raise ValueError("cfg.additional_ctm_grouping_separator cannot be empty string or space character")

    if cfg.minimum_timestamp_duration < 0:
        raise ValueError("cfg.minimum_timestamp_duration cannot be a negative number")

    # --- Validate manifest contents ---
    if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"):
        raise RuntimeError(
            "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. "
            "All lines must contain an 'audio_filepath' entry."
        )

    if cfg.align_using_pred_text:
        # Existing 'pred_text' would be silently superseded by a fresh transcription,
        # so refuse the ambiguous combination outright.
        if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"):
            raise RuntimeError(
                "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath "
                "contains 'pred_text' entries. This is because the audio will be transcribed and may produce "
                "a different 'pred_text'. This may cause confusion."
            )
    else:
        if not is_entry_in_all_lines(cfg.manifest_filepath, "text"):
            # BUGFIX: message previously said "=True" but this branch runs when
            # align_using_pred_text is False, i.e. ground-truth text is required.
            raise RuntimeError(
                "At least one line in cfg.manifest_filepath does not contain a 'text' entry. "
                "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False."
            )

    # --- Init devices ---
    # BUGFIX: torch.cuda.is_available is a function; the original passed the function
    # object itself, which is always truthy, so 'cuda' was selected even on CPU-only
    # machines. It must be called.
    if cfg.transcribe_device is None:
        transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        transcribe_device = torch.device(cfg.transcribe_device)
    logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}")

    if cfg.viterbi_device is None:
        viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        viterbi_device = torch.device(cfg.viterbi_device)
    logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}")

    if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda':
        logging.warning(
            'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors '
            'it may help to change both devices to be the CPU.'
        )

    # --- Load model ---
    model, _ = setup_model(cfg, transcribe_device)
    model.eval()

    if not isinstance(model, EncDecCTCModel):
        # BUGFIX: original referenced cfg.model_name, which is not a field of
        # AlignmentConfig and would itself error while building this message.
        # Exactly one of model_path / pretrained_name is non-None at this point.
        raise NotImplementedError(
            f"Model {cfg.model_path or cfg.pretrained_name} is not an instance of NeMo EncDecCTCModel."
            " Currently only instances of EncDecCTCModels are supported"
        )

    if cfg.minimum_timestamp_duration > 0:
        logging.warning(
            f"cfg.minimum_timestamp_duration has been set to {cfg.minimum_timestamp_duration} seconds. "
            "This may cause the alignments for some tokens/words/additional segments to be overlapping."
        )

    # --- Get start and end line IDs of each batch in the manifest ---
    starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)

    if cfg.align_using_pred_text:
        # Accumulate transcriptions across batches so they can be written into
        # the new output manifest at the end.
        pred_text_all_lines = []
    else:
        pred_text_all_lines = None

    # --- Process one batch at a time: log-probs -> Viterbi alignment -> CTM files ---
    for start, end in zip(starts, ends):
        manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end)

        (
            log_probs_batch,
            y_batch,
            T_batch,
            U_batch,
            token_info_batch,
            word_info_batch,
            segment_info_batch,
            pred_text_batch,
        ) = get_batch_tensors_and_boundary_info(
            manifest_lines_batch, model, cfg.additional_ctm_grouping_separator, cfg.align_using_pred_text,
        )

        if cfg.align_using_pred_text:
            pred_text_all_lines.extend(pred_text_batch)

        alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)

        # Token-level CTMs (only here is remove_blank_tokens_from_ctm honored).
        make_ctm(
            token_info_batch,
            alignments_batch,
            manifest_lines_batch,
            model,
            cfg.model_downsample_factor,
            os.path.join(cfg.output_dir, "tokens"),
            cfg.remove_blank_tokens_from_ctm,
            cfg.audio_filepath_parts_in_utt_id,
            cfg.minimum_timestamp_duration,
        )

        # Word-level CTMs (blanks never appear at word level, hence False).
        make_ctm(
            word_info_batch,
            alignments_batch,
            manifest_lines_batch,
            model,
            cfg.model_downsample_factor,
            os.path.join(cfg.output_dir, "words"),
            False,
            cfg.audio_filepath_parts_in_utt_id,
            cfg.minimum_timestamp_duration,
        )

        # Optional separator-delimited segment-level CTMs.
        if cfg.additional_ctm_grouping_separator:
            make_ctm(
                segment_info_batch,
                alignments_batch,
                manifest_lines_batch,
                model,
                cfg.model_downsample_factor,
                os.path.join(cfg.output_dir, "additional_segments"),
                False,
                cfg.audio_filepath_parts_in_utt_id,
                cfg.minimum_timestamp_duration,
            )

    # Write a new manifest pointing at the produced CTMs (and pred_text, if any).
    make_new_manifest(
        cfg.output_dir,
        cfg.manifest_filepath,
        cfg.additional_ctm_grouping_separator,
        cfg.audio_filepath_parts_in_utt_id,
        pred_text_all_lines,
    )

    return None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the hydra_runner decorator on main() parses CLI
    # overrides into an AlignmentConfig before invoking the function body.
    main()
|
|
|