import argparse
import math
import os
import random
import re

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSpeechSeq2Seq,
    AutoModelForTDT,
    AutoModelForTextToWaveform,
    AutoProcessor,
    CompileConfig,
)
from transformers.audio_utils import load_audio


OUTPUT_PATH = "results.csv"


def upsert(record: dict, file_path: str = OUTPUT_PATH) -> pd.DataFrame:
    """Insert or update ``record`` in the results table stored at ``file_path``.

    ``model_id`` is the unique key. If a row with the same ``model_id`` already
    exists, every field of ``record`` is written into that row (fields that are
    not yet columns become new columns). Otherwise ``record`` is appended as a
    new row. Paths ending in ``.csv`` are read/written as CSV; any other path
    goes through ``pd.read_excel`` / ``DataFrame.to_excel``.

    Args:
        record: Field values for one model; must contain ``"model_id"``.
        file_path: Path of the results table. Created if it does not exist.

    Returns:
        The full, updated DataFrame (also written back to ``file_path``).

    Raises:
        ValueError: If ``record`` has no ``"model_id"`` key.
    """
    if "model_id" not in record:
        raise ValueError("'model_id' key is required in the record dict.")

    is_csv = file_path.endswith(".csv")
    if os.path.exists(file_path):
        df = pd.read_csv(file_path) if is_csv else pd.read_excel(file_path)
    else:
        df = pd.DataFrame()

    model_id = record["model_id"]
    if not df.empty and "model_id" in df.columns and (df["model_id"] == model_id).any():
        # Assign column by column through `.loc` instead of `DataFrame.update`:
        # `update` silently skips columns that do not exist in `df` yet.
        row_mask = df["model_id"] == model_id
        for column, value in record.items():
            df.loc[row_mask, column] = value
        print(f"Updated model_id='{model_id}'")
    else:
        new_row = pd.DataFrame([record])
        # Concatenating onto an empty frame is deprecated in recent pandas.
        df = new_row if df.empty else pd.concat([df, new_row], ignore_index=True)
        print(f"Inserted model_id='{model_id}'")

    if is_csv:
        df.to_csv(file_path, index=False)
    else:
        df.to_excel(file_path, index=False)
    return df


wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision("high")

# Bytes per element for the dtypes `infer_batch_size` understands.
DTYPE_BYTES = {
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.int8: 1,
}
# `torch.int4` only exists in recent PyTorch releases; register it only when available
# so that importing this script does not fail on older versions.
if hasattr(torch, "int4"):
    DTYPE_BYTES[torch.int4] = 0.5


def get_free_gpu_memory_bytes(device: int = 0) -> int:
    """Return the free memory (in bytes) on CUDA ``device``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA device found.")
    torch.cuda.synchronize(device)
    free, _ = torch.cuda.mem_get_info(device)
    return free


def infer_batch_size(
    model: torch.nn.Module,
    longest_input_length: int,
    dtype: torch.dtype = torch.bfloat16,
    usable_fraction: float = 0.70,
) -> int:
    """Estimate a safe batch size for generation.

    Budget: ``usable_fraction`` (default 70%) of *total* GPU memory for everything
    (weights + activations + batch). The rest (30% by default) is left untouched
    as a buffer for KV cache growth, CUDA kernels, and other overhead.

    Per-sample activation cost is estimated as ~2 bytes * params^0.6 * seq_len,
    an empirically-derived heuristic that is dtype-agnostic and arch-agnostic.

    Args:
        model: Any nn.Module (HF, custom, etc.).
        longest_input_length: max(len(input_ids)) across your dataset.
        dtype: Dtype the model is loaded in (a ``torch.dtype`` or its name, e.g. "bfloat16").
        usable_fraction: Fraction of *total* VRAM to budget for everything.

    Returns:
        Recommended batch size (>= 1).

    Raises:
        ValueError: If ``dtype`` is not listed in ``DTYPE_BYTES``.
        RuntimeError: If no CUDA device exists, or the weights alone exceed the budget.
    """
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)
    bpe = DTYPE_BYTES.get(dtype)
    if bpe is None:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # --- Total VRAM budget ---
    free_bytes = get_free_gpu_memory_bytes(0)
    _, total_bytes = torch.cuda.mem_get_info(0)
    budget_bytes = total_bytes * usable_fraction

    print(f"[gpu] Total VRAM      : {total_bytes / 1e9:.2f} GB")
    print(f"[gpu] Free VRAM       : {free_bytes / 1e9:.2f} GB")
    print(f"[gpu] Usable budget   : {budget_bytes / 1e9:.2f} GB ({usable_fraction*100:.0f}% of total)")

    # --- Model weights ---
    num_params = sum(p.numel() for p in model.parameters())
    model_bytes = num_params * bpe

    print(f"[model] Params        : {num_params / 1e9:.3f}B")
    print(f"[model] Weight memory : {model_bytes / 1e9:.2f} GB (dtype={dtype})")

    remaining_bytes = budget_bytes - model_bytes
    if remaining_bytes <= 0:
        raise RuntimeError(
            f"Model weights alone ({model_bytes/1e9:.2f} GB) exceed the usable budget "
            f"({budget_bytes/1e9:.2f} GB). Try a smaller model or lower usable_fraction."
        )

    # --- Per-sample activation cost (arch-agnostic heuristic) ---
    # Activations ≈ f(params, seq_len). Empirically, hidden states + attention
    # buffers scale roughly as params^0.6 per token across diverse architectures.
    # Multiply by 2 bytes as a base unit (independent of weight dtype, since
    # activations are often kept in fp16/bf16 regardless).
    bytes_per_token = 2 * (num_params ** 0.6)
    bytes_per_sample = bytes_per_token * longest_input_length

    print(f"[input] longest_input_length : {longest_input_length} tokens")
    print(f"[input] Est. bytes/sample    : {bytes_per_sample / 1e6:.1f} MB")

    # --- Batch size ---
    batch_size = max(1, math.floor(remaining_bytes / bytes_per_sample))
    print(f"\n✅ Recommended batch_size : {batch_size}")
    return batch_size


def main(args):
    """Generate speech for every dataset text, transcribe it back with ASR, and score it.

    Loads the TTS model/processor, optionally warms up, then times batched
    generation. Each waveform is written to ``results_{model_id}/output_{id}.wav``.
    The saved audio is transcribed with ``nvidia/parakeet-tdt-0.6b-v3``, and WER
    is computed as raw input text vs. transcript, both lower-cased. RTFx is
    total generated audio seconds divided by total generation seconds.

    Returns:
        dict with keys ``model_id``, ``WER`` (percent) and ``RTFX``.
    """
    # Set seed due to randomness in some models (e.g. VibeVoice's acoustic tokenizer sampling)
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    torch_dtype = getattr(torch, args.dtype)
    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)

    # Load the weights from the same revision as the config/processor.
    model_cls = AutoModelForSpeechSeq2Seq if "dia" in config.model_type else AutoModelForTextToWaveform
    model = model_cls.from_pretrained(
        args.model_id,
        revision=args.revision,
        dtype=torch_dtype,
        device_map=args.device,
        attn_implementation=args.attn_implementation,
    )

    num_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"Model size: {num_params:.2f}B parameters")

    processor_kwargs = {"device_map": args.device} if "higgs" in config.model_type else {}
    processor = AutoProcessor.from_pretrained(args.model_id, revision=args.revision, **processor_kwargs)

    # Set generate arguments (always defined so later `model.generate(**gen_kwargs)` calls cannot NameError)
    gen_kwargs = {}
    if model.can_generate():
        gen_kwargs["max_new_tokens"] = args.max_new_tokens
        if "higgs" not in config.model_type:
            # NOTE(review): this pins the generated length to exactly `max_new_tokens`
            # for every non-Higgs model; confirm this is intended for the benchmark.
            gen_kwargs["min_new_tokens"] = args.max_new_tokens
        if "csm" in config.model_type:
            gen_kwargs["output_audio"] = True
    elif args.max_new_tokens:
        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a non-generative model.")

    if args.torch_compile is not None:
        if model.can_generate():
            gen_kwargs["compile_config"] = CompileConfig(mode=args.torch_compile, fullgraph=args.compile_fullgraph)
            # enable static k/v cache for autoregressive models
            model.generation_config.cache_implementation = "static"
        else:
            model = torch.compile(model, mode=args.torch_compile, fullgraph=args.compile_fullgraph)

        # Ensure warm-up runs when using torch.compile
        if args.warmup_steps is None or args.warmup_steps < 1:
            print("`--torch_compile` is enabled; forcing `--warmup_steps=10` to trigger compilation before timed runs.")
            args.warmup_steps = 10

    def benchmark(batch, text_column):
        """Time generation for one dataset batch and save each waveform to disk.

        Adds these columns to ``batch``:
        - ``generation_time_s``: each sample's equal share of the batch runtime.
        - ``predictions``: saved .wav path, or None if saving failed.
        - ``input_text``: lower-cased raw input text, used as the WER reference.
        - ``audio_length_s``: duration of each saved waveform.
        """
        # Work on a copy: padding and chat templating below must not leak into
        # `batch`, whose columns must keep equal lengths and the raw text.
        input_texts = list(batch[text_column])
        minibatch_size = len(input_texts)
        texts_to_generate = list(input_texts)

        feature_extractor = getattr(processor, "feature_extractor", None)
        sampling_rate = getattr(feature_extractor, "sampling_rate", 16_000)

        # START TIMING
        torch.cuda.synchronize(device=args.device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # 1. Pre-processing
        # 1.1 Pad the batch up to `args.batch_size` under torch.compile to prevent re-compilations
        padding_size = None
        if args.torch_compile is not None and minibatch_size != args.batch_size:
            padding_size = args.batch_size - minibatch_size
            texts_to_generate.extend([texts_to_generate[-1]] * padding_size)

        # 1.2 Apply the processor's chat template, if it ships one
        inputs = None
        if getattr(processor, "chat_template", None) is not None:
            if "csm" in config.model_type:
                # CSM uses the speaker ID ("0") in the `role` field instead of a chat role
                texts_to_generate = [
                    processor.apply_chat_template(
                        [{"role": "0", "content": [{"type": "text", "text": text}]}],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]
            elif "higgs" in config.model_type:
                # Higgs inputs are tokenized directly by its chat template
                conversations = [
                    [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": "Generate audio following instruction."}],
                        },
                        {"role": "user", "content": text},
                    ]
                    for text in texts_to_generate
                ]
                inputs = processor.apply_chat_template(
                    conversations,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                    sampling_rate=24000,
                )
            else:
                texts_to_generate = [
                    processor.apply_chat_template(
                        [{"role": "user", "content": text}],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]

        if "higgs" not in config.model_type:
            inputs = processor(text=texts_to_generate, return_tensors="pt")
        inputs = inputs.to(args.device)

        # 2. Model inference (the torch.compile path is restricted to the MATH SDPA backend)
        if args.torch_compile is not None:
            sdpa_backends = [SDPBackend.MATH]
        else:
            sdpa_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        with sdpa_kernel(sdpa_backends):
            pred_waveform = model.generate(**inputs, **gen_kwargs)

        # 3. Post-processing
        # 3.1 Drop outputs of the padded duplicates (`[:n]` works for both tensors and lists)
        if padding_size is not None:
            pred_waveform = pred_waveform[:minibatch_size]

        # 3.2 Decode model outputs into waveforms where the processor requires it
        if config.model_type == "dia":
            prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])
            outputs = processor.batch_decode(pred_waveform, audio_prompt_len=prompt_len)
        elif "higgs" in config.model_type:
            outputs = processor.batch_decode(pred_waveform)
        else:
            outputs = pred_waveform

        # END TIMING
        end_event.record()
        torch.cuda.synchronize(device=args.device)
        runtime = start_event.elapsed_time(end_event) / 1000.0

        # normalize by minibatch size since we want the per-sample time
        batch["generation_time_s"] = minibatch_size * [runtime / minibatch_size]

        out_dir = f"results_{args.model_id}"
        os.makedirs(out_dir, exist_ok=True)
        gen_paths = []
        audio_length_s = []
        for audio, sample_id in zip(outputs, batch["id"]):
            path = f"{out_dir}/output_{sample_id}.wav"
            try:
                processor.save_audio(audio, saving_path=path)
                # NOTE(review): assumes `len(audio)` is the number of samples and that
                # `sampling_rate` matches the generated audio — confirm per model.
                length_s = len(audio) / sampling_rate
            except Exception as err:  # the processor may not implement a standard `save_audio` yet
                # Keep one entry per sample so every column stays aligned with the batch.
                print(f"[warn] Could not save audio for id={sample_id}: {err}")
                gen_paths.append(None)
                audio_length_s.append(0.0)
                continue
            gen_paths.append(path)
            audio_length_s.append(length_s)

        batch["predictions"] = gen_paths
        batch["input_text"] = [text.lower() for text in input_texts]
        batch["audio_length_s"] = audio_length_s
        return batch

    dataset = load_dataset(args.dataset_path, split=args.split)
    dataset = dataset.add_column("id", list(range(len(dataset))))  # Pass id for easier reference when batch mapping

    def add_input_token_length(batch, text_column):
        """Add the un-padded, un-truncated token count of each text as `input_token_length`."""
        input_ids = processor.tokenizer(batch[text_column], return_tensors=None, padding=False, truncation=False).input_ids
        batch["input_token_length"] = [len(ids) for ids in input_ids]
        return batch

    # TODO: infer the batch size from the model's parameter count and the longest input text in
    # the dataset. `infer_batch_size` budgets 70% of VRAM by default, leaving 30% for cache and overhead.
    # dataset = dataset.map(
    #     add_input_token_length, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    # )
    # longest_input_length = max(dataset["input_token_length"])
    # inferred_bs = infer_batch_size(
    #     model=model,
    #     longest_input_length=longest_input_length,
    #     dtype=args.dtype,
    # )

    if args.warmup_steps:
        num_warmup_samples = args.warmup_steps * args.batch_size
        warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(
            warmup_dataset.map(
                benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
            )
        )
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    )

    asr_processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")
    asr_model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", device_map="auto")

    def transcribe_audio(batch):
        """Transcribe the saved .wav files in `batch["predictions"]` with the Parakeet-TDT model.

        Samples whose audio could not be saved (path is None) get an empty transcript.
        Adds the lower-cased transcripts as `batch["asr_outputs"]`.
        """
        sr_rate = asr_processor.feature_extractor.sampling_rate
        paths = batch["predictions"]
        valid_idx = [i for i, path in enumerate(paths) if path is not None]
        transcripts = [""] * len(paths)

        if valid_idx:
            speech_samples = [load_audio(paths[i], sampling_rate=sr_rate) for i in valid_idx]
            inputs = asr_processor(speech_samples, sampling_rate=sr_rate).to(asr_model.device, dtype=asr_model.dtype)
            outputs = asr_model.generate(**inputs, return_dict_in_generate=False)
            # Accept both a generate-output object and a plain tensor of token ids.
            sequences = getattr(outputs, "sequences", outputs)
            decoded = asr_processor.batch_decode(sequences, skip_special_tokens=True)
            for i, text in zip(valid_idx, decoded):
                transcripts[i] = text

        batch["asr_outputs"] = [text.lower() for text in transcripts]
        return batch

    dataset = dataset.map(transcribe_audio, batch_size=args.batch_size, batched=True)

    all_results = {
        "audio_length_s": [],
        "generation_time_s": [],
        "predictions": [],
        "input_text": [],
        "asr_outputs": [],
    }
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # TODO: write a per-sample results manifest (previously sketched with `data_utils.write_manifest`,
    # which this script does not import).

    wer = wer_metric.compute(references=all_results["input_text"], predictions=all_results["asr_outputs"])
    wer = round(100 * wer, 2)
    total_generation_time = sum(all_results["generation_time_s"])
    if total_generation_time > 0:
        rtfx = round(sum(all_results["audio_length_s"]) / total_generation_time, 2)
    else:
        rtfx = float("nan")
    print("WER:", wer, "%", "RTFx:", rtfx)

    return {"model_id": args.model_id, "WER": wer, "RTFX": rtfx}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="hf-audio/tts_leaderboard",
        help="Dataset path. By default, it is `hf-audio/tts_leaderboard`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text",
        help="Name of the column corresponding to the text that has to be generated in dataset with the given `split`.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="The device to run the pipeline on. `auto` for auto-inferring, 'cpu' for CPU, 'cuda' for the GPU (default).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=8192,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--torch_compile",
        type=str,
        default=None,
        help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
    )
    parser.add_argument(
        "--compile_fullgraph",
        action="store_true",
        help="Whether to do full graph compilation.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="The dtype to use for model loading and inference. E.g. 'bfloat16', 'float16', 'float32'.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        help="Attention implementation to use for model loading (e.g. 'sdpa', 'eager', 'flash_attention_2').",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=0,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision to use (e.g. 'refs/pr/11' for a PR branch). Defaults to the main branch.",
    )
    args = parser.parse_args()

    print("*" * 100)
    print(f"Evaluating {args.model_id} on {args.dataset_path} / {args.split}")
    print("*" * 100)

    output_dict = main(args)
    upsert(output_dict)