File size: 10,488 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# ASR transcribe/inference with multi-GPU/multi-node support for large datasets
# It supports both tarred and non-tarred datasets
# Arguments
#    model: path to a nemo/PTL checkpoint file or name of a pretrained model
#    predict_ds: config of the dataset/dataloader
#    output_path: path to store the predictions
#    return_predictions: whether to also keep the predictions in memory and return them, in addition to writing them into the files
#    use_cer: whether to calculate the error in terms of CER or use the default WER
#
# The results of each GPU/worker are written into a file named 'predictions_{rank}.json', and the aggregated results of all workers are written into 'predictions_all.json'

Example for non-tarred datasets:

python transcribe_speech_parallel.py \
    model=stt_en_conformer_ctc_large \
    predict_ds.manifest_filepath=/dataset/manifest_file.json \
    predict_ds.batch_size=16 \
    output_path=/tmp/

Example for Hybrid-CTC/RNNT models with non-tarred datasets:

python transcribe_speech_parallel.py \
    model=stt_en_fastconformer_hybrid_large \
    decoder_type=ctc \
    predict_ds.manifest_filepath=/dataset/manifest_file.json \
    predict_ds.batch_size=16 \
    output_path=/tmp/

Example for tarred datasets:

python transcribe_speech_parallel.py \
    predict_ds.is_tarred=true \
    predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \
    predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \
    ...

By default the trainer uses all the GPUs available and default precision is FP32.
By setting the trainer config you may control these configs. For example to do the predictions with AMP on just two GPUs:

python transcribe_speech_parallel.py \
    trainer.precision=16 \
    trainer.devices=2 \
    ...

You may control the dataloader's config by setting the predict_ds:

python transcribe_speech_parallel.py \
    predict_ds.num_workers=8 \
    predict_ds.min_duration=2.0 \
    predict_ds.sample_rate=16000 \
    model=stt_en_conformer_ctc_small \
    ...

"""

import itertools
import json
import os
from dataclasses import dataclass, field, is_dataclass
from typing import Optional

import lightning.pytorch as ptl
import torch
from omegaconf import MISSING, OmegaConf

from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
from nemo.collections.asr.models.configs import ASRDatasetConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyBatchedRNNTInferConfig
from nemo.core.config import TrainerConfig, hydra_runner
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero


def _default_predict_ds() -> ASRDatasetConfig:
    # Prediction dataset defaults: sample ids are returned, 4 dataloader workers,
    # and min_duration/max_duration set to 0 and 40.
    return ASRDatasetConfig(return_sample_id=True, num_workers=4, min_duration=0, max_duration=40)


def _default_rnnt_decoding() -> RNNTDecodingConfig:
    # TODO: double check whether fused_batch_size=-1 is right
    return RNNTDecodingConfig(fused_batch_size=-1)


def _default_trainer() -> TrainerConfig:
    # devices=-1 with the "gpu" accelerator and the "ddp" strategy.
    return TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp")


@dataclass
class ParallelTranscriptionConfig:
    """Schema of the options accepted by the parallel transcription script."""

    # Path to a .nemo/.ckpt file, or the name of a pretrained model.
    model: Optional[str] = None
    # Dataset/dataloader configuration used for prediction.
    predict_ds: ASRDatasetConfig = field(default_factory=_default_predict_ds)
    # Directory where the prediction files are written (mandatory).
    output_path: str = MISSING

    # If enabled, the prediction call keeps every prediction in memory and
    # hands them back once prediction has finished.
    return_predictions: bool = False
    # Report CER instead of the default WER.
    use_cer: bool = False

    # Decoding strategy applied to RNNT models.
    rnnt_decoding: RNNTDecodingConfig = field(default_factory=_default_rnnt_decoding)

    # Decoding strategy applied to CTC models.
    ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig)

    # Either "ctc" or "rnnt"; selects the decoder of Hybrid RNNT/CTC models.
    decoder_type: Optional[str] = None
    # Meant for cache-aware streaming models with multiple look-aheads.
    # NOTE(review): this field is not read anywhere in the visible part of this script.
    att_context_size: Optional[list] = None

    # Trainer options used to run the prediction loop.
    trainer: TrainerConfig = field(default_factory=_default_trainer)


def match_train_config(predict_ds, train_ds):
    """Copy the training-time dataset settings into the prediction dataset config.

    The sample rate is always taken from ``train_ds`` (falling back to 16000),
    and each of the optional keys listed below is copied only when ``train_ds``
    has it as an attribute.

    Args:
        predict_ds: prediction dataset config. A dataclass instance is converted
            into an OmegaConf structured config before the optional keys are copied.
        train_ds: training dataset config of the model, or None. It must offer
            ``.get`` as well as attribute access.

    Returns:
        The updated prediction config. When ``train_ds`` is None there is nothing
        to copy and ``predict_ds`` is returned unchanged.
    """
    if train_ds is None:
        # Return the config untouched rather than None: the caller assigns the
        # result back to its config and would otherwise lose predict_ds entirely.
        return predict_ds

    predict_ds.sample_rate = train_ds.get("sample_rate", 16000)
    cfg_name_list = (
        "int_values",
        "use_start_end_token",
        "blank_index",
        "unk_index",
        "normalize",
        "parser",
        "eos_id",
        "bos_id",
        "pad_id",
    )

    if is_dataclass(predict_ds):
        predict_ds = OmegaConf.structured(predict_ds)
    for cfg_name in cfg_name_list:
        if hasattr(train_ds, cfg_name):
            setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name))

    return predict_ds


@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig)
def main(cfg: ParallelTranscriptionConfig):
    """Run prediction on every rank and aggregate the per-rank results on rank zero.

    Steps performed:
      1. Load the model on CPU from a .nemo file, a .ckpt file, or a pretrained model name.
      2. Apply the RNNT or CTC decoding config when the model supports changing its decoding strategy.
      3. Copy the model's training dataset settings into ``cfg.predict_ds``.
      4. Build the trainer and dataloader, and attach an ``ASRPredictionWriter`` that writes
         ``predictions_{rank}.json`` into ``cfg.output_path``.
      5. After all ranks are done, rank zero concatenates the per-rank files into
         ``predictions_all.json`` and logs the WER (or CER when ``cfg.use_cer`` is set).

    Args:
        cfg: the transcription config, following ``ParallelTranscriptionConfig``.

    Raises:
        ValueError: if ``cfg.model`` is not set.
    """
    # `model` defaults to None in the schema; fail with a clear message instead of
    # an AttributeError raised by calling `.endswith` on None.
    if cfg.model is None:
        raise ValueError("`model` must be set to a .nemo/.ckpt file path or to the name of a pretrained model.")

    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")

    # Setup decoding strategy
    if hasattr(model, 'change_decoding_strategy') and hasattr(model, 'decoding'):
        if cfg.decoder_type is not None:
            # Any decoder_type other than 'rnnt' selects the CTC decoding config.
            decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
            if hasattr(model, 'cur_decoder'):
                # Models exposing `cur_decoder` also take the decoder type to switch to.
                model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
            else:
                model.change_decoding_strategy(decoding_cfg)

        # Check if ctc or rnnt model
        elif hasattr(model, 'joint'):  # RNNT model
            model.change_decoding_strategy(cfg.rnnt_decoding)
        else:
            model.change_decoding_strategy(cfg.ctc_decoding)

    # The prediction writer needs the sample ids, so this is always forced on.
    cfg.predict_ds.return_sample_id = True
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)

    if cfg.predict_ds.use_lhotse:
        # Struct mode is lifted temporarily so that keys can be set on predict_ds.
        OmegaConf.set_struct(cfg.predict_ds, False)
        cfg.trainer.use_distributed_sampler = False
        cfg.predict_ds.force_finite = True
        cfg.predict_ds.force_map_dataset = True
        cfg.predict_ds.do_transcribe = True
        OmegaConf.set_struct(cfg.predict_ds, True)

    if isinstance(model, EncDecMultiTaskModel):
        # Multitask models always go through Lhotse.
        # NOTE(review): this runs after the Lhotse block above, so force_finite /
        # force_map_dataset / do_transcribe are not set when use_lhotse was originally
        # false -- confirm that this is intended.
        cfg.trainer.use_distributed_sampler = False
        OmegaConf.set_struct(cfg.predict_ds, False)
        cfg.predict_ds.use_lhotse = True
        cfg.predict_ds.lang_field = "target_lang"
        OmegaConf.set_struct(cfg.predict_ds, True)

    trainer = ptl.Trainer(**cfg.trainer)

    if cfg.predict_ds.use_lhotse:
        # Pass the rank information of the trainer to the dataset config.
        OmegaConf.set_struct(cfg.predict_ds, False)
        cfg.predict_ds.global_rank = trainer.global_rank
        cfg.predict_ds.world_size = trainer.world_size
        OmegaConf.set_struct(cfg.predict_ds, True)

    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
    predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
    trainer.callbacks.extend([predictor_writer])

    predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions)
    if predictions is not None:
        # Flatten the per-batch results into a single list.
        # NOTE(review): the flattened list is not consumed any further in this function.
        predictions = list(itertools.chain.from_iterable(predictions))
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    # Wait until every rank has closed its output file before aggregating.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    pred_text_list = []
    text_list = []
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(f"Prediction files are being aggregated in {output_file}.")
        with open(output_file, 'w') as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
                with open(input_file, 'r') as inpf:
                    # Stream the file line by line instead of loading it all into memory.
                    for line in inpf:
                        item = json.loads(line)
                        pred_text_list.append(item["pred_text"])
                        text_list.append(item["text"])
                        outf.write(json.dumps(item) + "\n")
                        samples_num += 1
        wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
        logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))


# Script entry point: run the transcription only when executed directly.
if __name__ == "__main__":
    main()