| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import ast |
| | import math |
| | from functools import partial |
| | from itertools import islice |
| | from pathlib import Path |
| | from typing import Callable, Iterable |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | from lhotse.cut import Cut |
| | from omegaconf import OmegaConf |
| |
|
| | import nemo.collections.speechlm2.data.salm_dataset |
| | from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper |
| | from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config |
| | from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize, tokenize_with_prompt |
| | from nemo.collections.common.data.lhotse.sampling import ( |
| | MultimodalFixedBucketBatchSizeConstraint2D, |
| | MultimodalSamplingConstraint, |
| | TokenCountFilter, |
| | TokenPerTokenFilter, |
| | ) |
| | from nemo.collections.common.prompts.formatter import PromptFormatter |
| | from nemo.collections.common.tokenizers import AggregateTokenizer, AutoTokenizer, SentencePieceTokenizer |
| |
|
| |
|
def _str2bool(value: str) -> bool:
    """
    Parse a command-line boolean such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including "False") as True,
    so boolean options use this strict parser instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (e.g. true/false), got {value!r}")


def parse_args():
    """
    Parse the command-line options of the token bin estimation script.

    Boolean options (``-q/--quiet``, ``-m/--measure-total-length``) may be passed either as bare
    flags (meaning True) or with an explicit value such as ``True``/``False``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Estimate token bins for Lhotse dynamic bucketing using a sample of the input dataset. "
        "The dataset is read from a data input configuration YAML file, which supports data weighting. "
        "Unlike estimate_duration_bins.py, this script is intended for text data only. "
        "It supports 2D bucketing. ",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "input",
        help='Path to a data input configuration YAML file. '
        'This is the only type of input specification supported for text data.',
    )
    parser.add_argument(
        "-t",
        "--tokenizer",
        nargs="+",
        required=True,
        help="Path to one or more SPE tokenizers. More than one means we'll use AggregateTokenizer and --langs argument must also be used. When provided, we'll estimate a 2D distribution for input and output sequence lengths.",
    )
    parser.add_argument(
        "-a", "--langs", nargs="+", help="Language names for each of AggregateTokenizer sub-tokenizers."
    )
    parser.add_argument(
        "-b",
        "--buckets",
        type=int,
        default=30,
        help="The desired number of buckets (dim0 => covers input sequence length / audio duration).",
    )
    parser.add_argument(
        "-s",
        "--sub-buckets",
        type=int,
        default=None,
        help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens). "
        "If not provided, we'll only perform 1D bucketing. ",
    )
    parser.add_argument(
        "-n",
        "--num_examples",
        type=int,
        default=-1,
        help="The number of examples (utterances) to estimate the bins. -1 means use all data "
        "(be careful: it could be iterated over infinitely).",
    )
    parser.add_argument(
        "-l",
        "--min_tokens",
        type=float,
        default=-float("inf"),
        help="If specified, we'll filter out examples with less tokens than this number.",
    )
    parser.add_argument(
        "-u",
        "--max_tokens",
        type=float,
        default=float("inf"),
        help="If specified, we'll filter out examples with more tokens than this number.",
    )
    parser.add_argument(
        "--max_tpt",
        type=float,
        default=float("inf"),
        help="If specified, we'll filter out examples with more output tokens per input token than this. ",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare `-q` means True
        default=False,
        help="When specified, only print the estimated duration bins.",
    )
    parser.add_argument(
        "-f",
        "--prompt-format",
        type=str,
        help="When specified, we'll use a prompt formatter in addition to the tokenizer for the purpose of estimating token count bins. "
        "This is useful for accurate 2D bucket estimation with models such as EncDecMultiTaskModel (Canary-1B), "
        "or any model where the label sequence consists of a user prompt and a model's response.",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        help="Prompt slots provided as a Python list of dicts. It is used together with --prompt-format option. "
        "For example, with Canary-1B you may use: [{'role':'user','slots':{'source_lang':'en','target_lang':'en','task':'asr','pnc':'yes'}}]",
    )
    parser.add_argument(
        "-m",
        "--measure-total-length",
        type=_str2bool,
        nargs="?",
        const=True,  # a bare `-m` means True
        default=False,
        help="When specified, we'll measure the total length (context+answer, i.e. input_ids) instead of context-only length. Total length is more suitable for decoder-only models while context-only length is more suitable for encoder-decoder models.",
    )
    return parser.parse_args()
| |
|
| |
|
def estimate_token_buckets(
    cuts: Iterable[Cut],
    num_buckets: int,
    num_subbuckets: int | None,
    quiet: bool,
) -> list[int] | list[tuple[int, int]]:
    """
    Estimate 1D or 2D token-count bucket bins from a sample of examples.

    This function is based on lhotse.dataset.sampling.dynamic_bucketing.estimate_duration_buckets.
    It extends it to a 2D bucketing case.

    Examples are sorted by input token count (dim0) and split so that every dim0 bucket holds
    roughly the same total number of input tokens. In the 2D case, the examples that fall into
    each dim0 bucket are further split by output token count (dim1) using the same equal-total rule.

    Args:
        cuts: Examples to measure. Lengths come from ``measure_length`` of a
            ``MultimodalFixedBucketBatchSizeConstraint2D`` (2D, ``measure_total_length=False``)
            or a ``MultimodalSamplingConstraint`` (1D, ``measure_total_length=True``).
        num_buckets: Number of dim0 buckets; must be greater than 1.
        num_subbuckets: Number of dim1 sub-buckets per dim0 bucket, or ``None`` for 1D bucketing.
        quiet: When ``False``, print summary statistics of the measured distributions.

    Returns:
        1D: the input-token-count boundaries between consecutive buckets (no extra bin is
        appended for the global maximum).
        2D: ``(max_input_tokens, max_output_tokens)`` pairs grouped by dim0 bucket in ascending
        order. The last pair of every dim0 bucket caps output tokens at
        ``ceil(max_input_tokens * max_tpt)``, where ``max_tpt`` is the largest output/input token
        ratio observed in the data.

    Raises:
        ValueError: If ``cuts`` yields no examples.
    """
    assert num_buckets > 1
    is_2d = num_subbuckets is not None
    percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]

    if is_2d:
        constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0], measure_total_length=False)
    else:
        constraint = MultimodalSamplingConstraint(measure_total_length=True)

    # Gather the token count statistics for the dataset.
    num_input_tokens = []
    num_output_tokens = []
    for c in cuts:
        ans = constraint.measure_length(c)
        if is_2d:
            itoks, otoks = ans
            num_input_tokens.append(itoks)
            num_output_tokens.append(otoks)
        else:
            num_input_tokens.append(ans)
    if not num_input_tokens:
        raise ValueError(
            "Cannot estimate token bins: no examples were read "
            "(check the input configuration and the token filtering options)."
        )

    num_input_tokens = np.array(num_input_tokens, dtype=np.int32)
    if is_2d:
        # Sort by input token count (ties broken by output token count) while keeping
        # every example's (input, output) pair aligned.
        num_output_tokens = np.array(num_output_tokens, dtype=np.int32)
        joint = np.rec.fromarrays([num_input_tokens, num_output_tokens])
        joint.sort()
        num_input_tokens = joint.f0
        num_output_tokens = joint.f1
    else:
        num_input_tokens.sort()

    # Buckets are built to hold an equal total number of input tokens;
    # determine how many tokens to allocate per bucket.
    size_per_bucket = num_input_tokens.sum() / num_buckets

    if not quiet:
        print("Duration distribution:")
        print(pd.Series(num_input_tokens).describe(percentiles=percentiles))
    max_input_tokens = num_input_tokens[-1]

    if is_2d:
        # Examples without input tokens are excluded from the ratio: dividing by zero would make
        # max_tpt inf/NaN and break the math.ceil() caps computed below.
        has_input = num_input_tokens > 0
        tpt = num_output_tokens[has_input] / num_input_tokens[has_input]
        if not quiet:
            print("Output tokens per input token distribution:")
            print(pd.Series(tpt).describe(percentiles=percentiles))
        max_tpt = tpt.max() if tpt.size else 0.0
        del tpt, has_input

    bins = []

    def _append_output_token_bins(start: int, end: int, max_bucket_duration) -> None:
        """
        Append the dim1 bins of the dim0 bucket made of sorted examples ``start:end``.

        Applies the same equal-total rule along dim1 (output token count) as along dim0,
        aiming for each sub-bucket to hold roughly the same number of output tokens.
        """
        # np.sort returns a sorted copy, so the shared, input-aligned array stays untouched.
        # tolist() yields Python ints, so the running total cannot overflow int32.
        bucket_output_tokens = np.sort(num_output_tokens[start:end]).tolist()
        tokens_per_subbucket = sum(bucket_output_tokens) / num_subbuckets
        tot_toks = 0
        for num_toks in bucket_output_tokens:
            # Threshold hit: close a (max_input_tokens, max_output_tokens) bin.
            if tot_toks > tokens_per_subbucket:
                bins.append((max_bucket_duration, num_toks))
                tot_toks = 0
            tot_toks += num_toks
        # The last sub-bucket takes the remaining examples; cap it via the global max ratio.
        bins.append((max_bucket_duration, math.ceil(max_bucket_duration * max_tpt)))

    # Iterate over the sorted data; whenever the running total exceeds size_per_bucket, close a bucket.
    bucket_start = 0  # sorted index of the first example in the current dim0 bucket
    tot = 0.0
    for idx, size in enumerate(num_input_tokens):
        if tot > size_per_bucket:
            # Threshold hit: examples [bucket_start, idx) form a complete dim0 bucket.
            if is_2d:
                _append_output_token_bins(bucket_start, idx, max_bucket_duration=size)
                bucket_start = idx
            else:
                bins.append(size)
            tot = 0.0
        tot += size

    # In 2D, the remaining examples (through the very last one) form the final dim0 bucket,
    # bounded by the global maximum input token count.
    if is_2d:
        _append_output_token_bins(bucket_start, len(num_input_tokens), max_bucket_duration=max_input_tokens)

    return bins
| |
|
| |
|
def load_tokenizer(paths: list[str], langs: list[str] | None = None) -> TokenizerWrapper:
    """
    Build the tokenizer used to count tokens.

    Args:
        paths: Either a single tokenizer spec, or several SentencePiece model paths.
            A single spec that exists on disk is loaded with ``SentencePieceTokenizer``;
            otherwise it is passed to ``AutoTokenizer``.
        langs: One language name per path. Required when more than one path is given,
            in which case an ``AggregateTokenizer`` keyed by these names is built.

    Returns:
        The tokenizer wrapped in ``TokenizerWrapper``.

    Raises:
        ValueError: If several paths are given without a ``langs`` list of the same length.
    """
    if len(paths) == 1:
        (p,) = paths
        if Path(p).exists():
            tok = SentencePieceTokenizer(p)
        else:
            # Not a local file: presumably a pretrained tokenizer name that AutoTokenizer can resolve.
            tok = AutoTokenizer(p, use_fast=True)
    else:
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if langs is None or len(paths) != len(langs):
            raise ValueError(
                "Cannot create AggregateTokenizer; each tokenizer must have assigned a language via --langs option "
                f"(we got --tokenizer={paths} and --langs={langs})"
            )
        tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)})
    return TokenizerWrapper(tok)
| |
|
| |
|
def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None):
    """
    Tokenize ``cut``, preferring the prompt formatter when one is given.

    With a ``prompt``, ``tokenize_with_prompt(cut, tokenizer, prompt)`` is returned. Otherwise,
    with a ``tokenizer``, ``tokenize(cut, tokenizer)`` is returned. With neither, the very same
    ``cut`` object is returned.
    """
    if prompt is None and tokenizer is None:
        return cut
    if prompt is None:
        return tokenize(cut, tokenizer)
    return tokenize_with_prompt(cut, tokenizer, prompt)
| |
|
| |
|
class RejectionsCounter:
    """
    Callable wrapper around a filtering predicate that counts how many examples it rejected.

    An instance can be used wherever the wrapped predicate is expected (e.g. ``cuts.filter``);
    the predicate's return value is passed through unchanged.
    """

    def __init__(self, predicate: Callable, message: str):
        # Decides whether an example is kept (truthy result) or rejected (falsy result).
        self.predicate = predicate
        # Label printed in front of the rejection statistics.
        self.message = message
        # Number of examples evaluated so far.
        self.total = 0
        # Number of examples for which the predicate returned a falsy value.
        self.rejected = 0

    def __call__(self, example) -> bool:
        keep = self.predicate(example)
        self.total += 1
        self.rejected += int(not keep)
        return keep

    def print_report(self) -> None:
        """Print the rejection statistics, but only when at least one example was rejected."""
        if not self.rejected:
            return
        print(f"{self.message} | Rejected {self.rejected}/{self.total} examples.")
| |
|
| |
|
def main():
    """
    Command-line entry point: estimate token bins for the data described by a YAML input config.

    Tokenizes every example (through a prompt formatter when ``--prompt-format`` is given), drops
    examples rejected by the token-count and output-per-input-token filters, optionally keeps only
    the first ``--num_examples`` of them, and prints the estimated bins in a config-ready format.

    Raises:
        ValueError: If the positional ``input`` argument is not a ``.yaml`` path.
    """
    args = parse_args()

    if not args.quiet:
        pd.set_option('display.float_format', lambda x: '%.2f' % x)

    tokenizer = None
    prompt = None
    if args.tokenizer is not None:
        tokenizer = load_tokenizer(args.tokenizer, args.langs)
        if args.prompt_format is not None:
            prompt_defaults = None
            if args.prompt is not None:
                # literal_eval only accepts Python literals, so --prompt cannot execute code.
                prompt_defaults = ast.literal_eval(args.prompt)
            prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults)

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not args.input.endswith(".yaml"):
        raise ValueError(f"Expected a path to a data input configuration YAML file (*.yaml), got: {args.input}")
    # NOTE(review): force_finite / metadata_only presumably make the cut set finite and skip loading
    # the actual data payloads — confirm against LhotseDataLoadingConfig.
    config = OmegaConf.merge(
        OmegaConf.structured(LhotseDataLoadingConfig),
        OmegaConf.from_dotlist([f"input_cfg={args.input}", "force_finite=True", "metadata_only=True"]),
    )
    cuts, _ = read_cutset_from_config(config)
    cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None)
    if hasattr(cuts, "prefetch"):
        cuts = cuts.prefetch()
    # Wrap the filters so that the number of rejected examples can be reported at the end.
    token_filter = RejectionsCounter(
        TokenCountFilter(args.min_tokens, args.max_tokens, args.measure_total_length), "Token count filtering"
    )
    cuts = cuts.filter(token_filter)
    tpt_filter = RejectionsCounter(TokenPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering")
    cuts = cuts.filter(tpt_filter)
    if (N := args.num_examples) > 0:
        cuts = islice(cuts, N)

    token_bins = estimate_token_buckets(
        cuts,
        num_buckets=args.buckets,
        num_subbuckets=args.sub_buckets,
        quiet=args.quiet,
    )
    # Render the bins as a config-ready list literal: [[in,out],...] for 2D, [in,...] for 1D.
    if args.sub_buckets is not None:
        token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]"
    else:
        token_bins = "[" + ','.join(f"{b:d}" for b in token_bins) + "]"
    if args.quiet:
        print(token_bins)
        return
    token_filter.print_report()
    tpt_filter.print_report()
    print("Use the following options in your config:")
    print(f"\tnum_buckets={args.buckets}")
    print(f"\tbucket_duration_bins={token_bins}")


if __name__ == "__main__":
    main()
| |
|