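"""
Run inference with a NeMo speech intent classification and slot filling (SLU) model
and write the predictions to a JSON-lines manifest.

Example usage, a sketch assuming a local checkpoint and manifest (the paths are
illustrative; the script is invoked through Hydra, so config fields are set as
command-line overrides):

    python <this_script>.py \
        model_path=slu_model.nemo \
        dataset_manifest=test_manifest.json \
        batch_size=32

Each input manifest line is a JSON object with at least an "audio_filepath" field,
e.g. {"audio_filepath": "sample.wav", "duration": 2.5}. The output manifest mirrors
the input lines with an added "pred_text" field.
"""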
import contextlib
import glob
import json
import os
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import List, Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from nemo.collections.asr.models import SLUIntentSlotBPEModel
from nemo.collections.asr.parts.utils.slu_utils import SequenceGeneratorConfig
from nemo.core.config import hydra_runner
from nemo.utils import logging


@dataclass
class InferenceConfig:
    # Model selection: exactly one of these must be set.
    model_path: Optional[str] = None  # path to a local .nemo checkpoint
    pretrained_name: Optional[str] = None  # name of a pretrained model

    # Input: exactly one of these must be set.
    audio_dir: Optional[str] = None  # directory containing audio files
    dataset_manifest: Optional[str] = None  # path to a JSON-lines manifest

    # Output and data loading.
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 8

    # Device and precision.
    cuda: Optional[int] = None  # GPU index; auto-detected when None
    amp: bool = False
    audio_type: str = "wav"

    # If False, skip inference when the output file already exists.
    overwrite_transcripts: bool = True

    # Use default_factory for the mutable dataclass default (required on Python 3.11+).
    sequence_generator: SequenceGeneratorConfig = field(
        default_factory=lambda: SequenceGeneratorConfig(type="greedy")
    )

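# The decoding strategy can be changed via Hydra command-line overrides, e.g. to
# switch from greedy decoding to beam search (a sketch; any field other than
# `type` is an assumption about SequenceGeneratorConfig in your NeMo version):
#   python <this_script>.py ... sequence_generator.type=beam sequence_generator.beam_size=4
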

def slurp_inference(model, path2manifest: str, batch_size: int = 4, num_workers: int = 0) -> List[str]:
    """Transcribe all utterances in a manifest and return the predicted semantics strings."""
    if num_workers is None:
        num_workers = min(batch_size, os.cpu_count() - 1)

    hypotheses = []

    # Save states that inference modifies, so they can be restored afterwards.
    mode = model.training
    device = next(model.parameters()).device
    dither_value = model.preprocessor.featurizer.dither
    pad_to_value = model.preprocessor.featurizer.pad_to
    logging_level = logging.get_verbosity()

    try:
        # Disable dithering and padding for deterministic evaluation.
        model.preprocessor.featurizer.dither = 0.0
        model.preprocessor.featurizer.pad_to = 0

        model.eval()

        # Lower the log verbosity so per-batch messages do not flood the progress bar.
        logging.set_verbosity(logging.WARNING)

        config = {
            'manifest_filepath': path2manifest,
            'batch_size': batch_size,
            'num_workers': num_workers,
        }

        temporary_datalayer = model._setup_transcribe_dataloader(config)
        for test_batch in tqdm(temporary_datalayer, desc="Transcribing", ncols=80):
            predictions = model.predict(
                input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
            )

            hypotheses += predictions

            del predictions
            del test_batch
    finally:
        # Restore the model and preprocessor to their original states.
        model.train(mode=mode)
        model.preprocessor.featurizer.dither = dither_value
        model.preprocessor.featurizer.pad_to = pad_to_value
        logging.set_verbosity(logging_level)
    return hypotheses

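# Minimal programmatic sketch of calling slurp_inference directly, bypassing Hydra
# (the checkpoint and manifest paths are hypothetical):
#   model = SLUIntentSlotBPEModel.restore_from("slu_model.nemo")
#   predictions = slurp_inference(model, "test_manifest.json", batch_size=8, num_workers=4)
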
@hydra_runner(config_name="InferenceConfig", schema=InferenceConfig)
def run_inference(cfg: InferenceConfig) -> InferenceConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("At least one of cfg.model_path and cfg.pretrained_name must be set!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("At least one of cfg.audio_dir and cfg.dataset_manifest must be set!")

    # Select the device: the given GPU index, the first GPU if available, else CPU.
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')

    # Load the model from a local checkpoint or from a pretrained name.
    if cfg.model_path is not None:
        logging.info(f"Restoring model: {cfg.model_path}")
        model = SLUIntentSlotBPEModel.restore_from(restore_path=cfg.model_path, map_location=map_location)
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        model = SLUIntentSlotBPEModel.from_pretrained(model_name=cfg.pretrained_name, map_location=map_location)
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    model.set_trainer(trainer)
    model = model.eval()

    model.set_decoding_strategy(cfg.sequence_generator)

    # Collect the audio file paths to be transcribed.
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
    else:
        filepaths = []
        if os.stat(cfg.dataset_manifest).st_size == 0:
            logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!")
            return None

        # Resolve relative audio paths against the manifest's directory.
        manifest_dir = Path(cfg.dataset_manifest).parent
        with open(cfg.dataset_manifest, 'r') as f:
            for line in f:
                item = json.loads(line)
                audio_file = Path(item['audio_filepath'])
                if not audio_file.is_file() and not audio_file.is_absolute():
                    audio_file = manifest_dir / audio_file
                filepaths.append(str(audio_file.absolute()))

    logging.info(f"\nStart inference with {len(filepaths)} files...\n")

    # Use AMP autocast when requested and available; otherwise fall back to a no-op context.
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Derive a default output filename from the audio directory or the input manifest.
    if cfg.output_filename is None:
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            cfg.output_filename = cfg.dataset_manifest.replace('.json', f'_{model_name}.json')

    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    # Run inference under autocast with gradient tracking disabled.
    with autocast():
        with torch.no_grad():
            predictions = slurp_inference(
                model=model,
                path2manifest=cfg.dataset_manifest,
                batch_size=cfg.batch_size,
                num_workers=cfg.num_workers,
            )

    logging.info(f"Finished transcribing {len(filepaths)} files!")
    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # Write one JSON object per line; when a manifest was given, carry over its fields.
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(predictions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = predictions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions!")
    return cfg


if __name__ == '__main__':
    run_inference()