Spaces:
Paused
Paused
| import argparse | |
| import os | |
| import re | |
| import math | |
| import torch | |
| from torch.nn.attention import sdpa_kernel, SDPBackend | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForTextToWaveform, | |
| AutoModelForTDT, | |
| AutoModelForSpeechSeq2Seq, | |
| AutoProcessor, | |
| CompileConfig, | |
| ) | |
| from transformers.audio_utils import load_audio | |
| import evaluate | |
| from tqdm import tqdm | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
| from datasets import load_dataset | |
# Default results file that `upsert` reads (if present) and rewrites.
OUTPUT_PATH = "results.csv"


def upsert(record: dict, file_path: str = OUTPUT_PATH):
    """
    Insert or update a record in the stored DataFrame.

    Uses 'model_id' as the unique key: updates the matching row(s) if found,
    appends a new row if not. Every field in ``record`` is written. Fields that
    are not yet columns of the file are added as new columns, and other rows get
    NaN there. The file format is chosen by extension: ``.csv`` is read/written
    as CSV, anything else as Excel.

    Args:
        record: Mapping of column name -> value; must contain ``"model_id"``.
        file_path: Path of the results file to read (if it exists) and rewrite.

    Returns:
        The updated DataFrame that was written to ``file_path``.

    Raises:
        ValueError: if ``record`` has no ``"model_id"`` key.
    """
    # Raise instead of assert so the check still runs under `python -O`.
    if "model_id" not in record:
        raise ValueError("'model_id' key is required in the record dict.")

    is_csv = file_path.endswith(".csv")
    if os.path.exists(file_path):
        df = pd.read_csv(file_path) if is_csv else pd.read_excel(file_path)
    else:
        df = pd.DataFrame()

    model_id = record["model_id"]
    if not df.empty and "model_id" in df.columns and (df["model_id"] == model_id).any():
        # Update existing row(s) field by field. Unlike `DataFrame.update`, which
        # only touches columns that already exist, `.loc` assignment also creates
        # missing columns, so newly added metrics are not silently dropped.
        mask = df["model_id"] == model_id
        for key, value in record.items():
            df.loc[mask, key] = value
        print(f"Updated model_id='{model_id}'")
    else:
        # Append new row
        df = pd.concat([df, pd.DataFrame([record])], ignore_index=True)
        print(f"Inserted model_id='{model_id}'")

    if is_csv:
        df.to_csv(file_path, index=False)
    else:
        df.to_excel(file_path, index=False)
    return df
# WER metric from the `evaluate` library. `main` uses it to compare the ASR
# transcripts of the generated audio against the input texts.
wer_metric = evaluate.load("wer")
# Per the PyTorch docs, "high" allows float32 matmuls to use faster
# reduced-precision internal math (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision('high')
# Bytes per element for each dtype that `infer_batch_size` accepts. It is used
# there to estimate how much memory the model weights take.
# NOTE(review): `torch.int4` only exists in recent PyTorch releases. Confirm the
# minimum supported torch version, since an older torch fails here at import time.
DTYPE_BYTES = {
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.int8: 1,
    torch.int4: 0.5,  # 4 bits = half a byte per element
}
def get_free_gpu_memory_bytes(device: int = 0) -> int:
    """Return the number of currently free bytes on CUDA ``device``.

    Pending work on the device is synchronized first, so the reading reflects
    completed allocations.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
        free_bytes = torch.cuda.mem_get_info(device)[0]
        return free_bytes
    raise RuntimeError("No CUDA device found.")
def infer_batch_size(
    model: torch.nn.Module,
    longest_input_length: int,
    dtype: torch.dtype = torch.bfloat16,
    usable_fraction: float = 0.70,
    device: int = 0,
) -> int:
    """
    Estimate a safe batch size for generation.

    Budget: ``usable_fraction`` (default 70%) of the *total* VRAM on ``device``
    covers everything (weights + activations + batch). The remainder (default
    30%) is left untouched as a buffer for KV cache growth, CUDA kernels, and
    other overhead.

    Per-sample activation cost is estimated as ~2 bytes * params^0.6 * seq_len.
    This is an empirically-derived heuristic that ignores the weight dtype and
    the model architecture.

    Args:
        model: Any nn.Module (HF, custom, etc.)
        longest_input_length: max(len(input_ids)) across your dataset; must be >= 1.
        dtype: Dtype the model is loaded in, either a ``torch.dtype`` or its name
            (e.g. ``"bfloat16"``). Must be a key of ``DTYPE_BYTES``.
        usable_fraction: Fraction of *total* VRAM to budget for everything, in
            (0, 1]. Default 0.70, which leaves 30% for cache and other overhead.
        device: CUDA device index to query. Default 0.

    Returns:
        Recommended batch size (>= 1).

    Raises:
        ValueError: if ``longest_input_length`` < 1, ``usable_fraction`` is not
            in (0, 1], ``dtype`` is unknown/unsupported, or the model has no
            parameters.
        RuntimeError: if no CUDA device is available, or the weights alone
            exceed the budget.
    """
    # Validate the cheap arguments first so bad input fails before any GPU query.
    if longest_input_length < 1:
        raise ValueError(f"longest_input_length must be >= 1, got {longest_input_length}")
    if not 0 < usable_fraction <= 1:
        raise ValueError(f"usable_fraction must be in (0, 1], got {usable_fraction}")
    requested_dtype = dtype
    if isinstance(dtype, str):
        # Default of None so an unknown name becomes a ValueError below, not AttributeError.
        dtype = getattr(torch, dtype, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported dtype: {requested_dtype}")
    bpe = DTYPE_BYTES.get(dtype)
    if bpe is None:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # --- Total VRAM budget ---
    free_bytes = get_free_gpu_memory_bytes(device)
    _, total_bytes = torch.cuda.mem_get_info(device)
    budget_bytes = total_bytes * usable_fraction
    print(f"[gpu] Total VRAM : {total_bytes / 1e9:.2f} GB")
    print(f"[gpu] Free VRAM : {free_bytes / 1e9:.2f} GB")
    print(f"[gpu] Usable budget : {budget_bytes / 1e9:.2f} GB ({usable_fraction*100:.0f}% of total)")

    # --- Model weights ---
    num_params = sum(p.numel() for p in model.parameters())
    if num_params == 0:
        # The activation heuristic below would be 0 and the division would fail.
        raise ValueError("Model has no parameters; cannot estimate activation memory.")
    model_bytes = num_params * bpe
    print(f"[model] Params : {num_params / 1e9:.3f}B")
    print(f"[model] Weight memory : {model_bytes / 1e9:.2f} GB (dtype={dtype})")
    remaining_bytes = budget_bytes - model_bytes
    if remaining_bytes <= 0:
        raise RuntimeError(
            f"Model weights alone ({model_bytes/1e9:.2f} GB) exceed the usable budget "
            f"({budget_bytes/1e9:.2f} GB). Try a smaller model or lower usable_fraction."
        )

    # --- Per-sample activation cost (arch-agnostic heuristic) ---
    # Activations ≈ f(params, seq_len). Empirically, hidden states + attention
    # buffers scale roughly as params^0.6 per token across diverse architectures.
    # Multiply by 2 bytes as a base unit (independent of weight dtype, since
    # activations are often kept in fp16/bf16 regardless).
    bytes_per_token = 2 * (num_params ** 0.6)
    bytes_per_sample = bytes_per_token * longest_input_length
    print(f"[input] longest_input_length : {longest_input_length} tokens")
    print(f"[input] Est. bytes/sample : {bytes_per_sample / 1e6:.1f} MB")

    # --- Batch size ---
    batch_size = max(1, math.floor(remaining_bytes / bytes_per_sample))
    print(f"\n✅ Recommended batch_size : {batch_size}")
    return batch_size
def main(args):
    """
    Benchmark a text-to-speech model. Generate audio for every text in the
    dataset, transcribe it back with an ASR model, and score intelligibility
    (WER) and speed (RTFx).

    Args:
        args: Parsed command-line arguments (model_id, dataset_path, split,
            text_column, device, batch_size, max_eval_samples, max_new_tokens,
            torch_compile, compile_fullgraph, dtype, attn_implementation,
            warmup_steps, revision).

    Returns:
        dict with ``"model_id"``, ``"WER"`` (percent, rounded to 2 decimals) and
        ``"RTFX"`` (total generated audio seconds / total generation seconds,
        rounded to 2 decimals).
    """
    # Set seed due to randomness in some models (e.g. VibeVoice's acoustic tokenizer sampling)
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    torch_dtype = getattr(torch, args.dtype)
    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)
    # Dia is loaded as a speech seq2seq model; everything else as text-to-waveform.
    model_cls = AutoModelForSpeechSeq2Seq if "dia" in config.model_type else AutoModelForTextToWaveform
    model = model_cls.from_pretrained(
        args.model_id,
        revision=args.revision,  # keep weights on the same revision as config/processor
        dtype=torch_dtype,
        device_map=args.device,
        attn_implementation=args.attn_implementation,
    )
    num_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"Model size: {num_params:.2f}B parameters")

    processor_kwargs = {"device_map": args.device} if "higgs" in config.model_type else {}
    processor = AutoProcessor.from_pretrained(args.model_id, revision=args.revision, **processor_kwargs)

    # Set generate arguments. Initialised up front so `benchmark` can always
    # reference it, even for non-generative models.
    gen_kwargs = {}
    if model.can_generate():
        gen_kwargs["max_new_tokens"] = args.max_new_tokens
        if "higgs" not in config.model_type:
            # Forces generation to run for `max_new_tokens` steps for every sample.
            gen_kwargs["min_new_tokens"] = args.max_new_tokens
        if "csm" in config.model_type:
            gen_kwargs["output_audio"] = True
    elif args.max_new_tokens:
        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a non-generative model.")

    if args.torch_compile is not None:
        if model.can_generate():
            gen_kwargs["compile_config"] = CompileConfig(mode=args.torch_compile, fullgraph=args.compile_fullgraph)
            # enable static k/v cache for autoregressive models
            model.generation_config.cache_implementation = "static"
        else:
            model = torch.compile(model, mode=args.torch_compile, fullgraph=args.compile_fullgraph)
        # Ensure warm-up runs when using torch.compile
        if args.warmup_steps is None or args.warmup_steps < 1:
            print("`--torch_compile` is enabled; forcing `--warmup_steps=10` to trigger compilation before timed runs.")
            args.warmup_steps = 10

    # Sampling rate used to convert generated sample counts into seconds. It is
    # loop-invariant, so it is resolved once here.
    # NOTE(review): falls back to 16 kHz when the processor exposes none. Confirm
    # this matches the rate the model actually generates at.
    feature_extractor = getattr(processor, "feature_extractor", None)
    sampling_rate = getattr(feature_extractor, "sampling_rate", 16_000)

    def benchmark(batch, text_column):
        """Generate, time and save audio for one minibatch of texts."""
        # Copy the raw texts. Padding below must not mutate the dataset's column,
        # and the un-templated texts are the WER references.
        raw_texts = list(batch[text_column])
        minibatch_size = len(raw_texts)
        texts_to_generate = list(raw_texts)

        # START TIMING
        torch.cuda.synchronize(device=args.device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # 1. Pre-Processing
        # 1.1 Pad texts to the full batch size when using torch compile to prevent re-compilations
        padding_size = None
        if minibatch_size != args.batch_size and args.torch_compile is not None:
            padding_size = args.batch_size - minibatch_size
            texts_to_generate.extend([texts_to_generate[-1]] * padding_size)

        # Apply the jinja chat template if the processor has one saved
        if getattr(processor, "chat_template", None) is not None:
            # CSM uses speaker ID and not role in conv
            if "csm" in config.model_type:
                texts_to_generate = [
                    processor.apply_chat_template(
                        [
                            {
                                "role": "0",
                                "content": [{"type": "text", "text": text}],
                            }
                        ],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]
            elif "higgs" in config.model_type:
                inputs = processor.apply_chat_template(
                    [[{
                        "role": "system",
                        "content": [
                            {
                                "type": "text",
                                "text": "Generate audio following instruction."
                            }
                        ],
                    },
                    {
                        "role": "user",
                        "content": text,
                    }] for text in texts_to_generate],
                    tokenize=True,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                    sampling_rate=24000,
                )
            else:
                texts_to_generate = [
                    processor.apply_chat_template(
                        [
                            {
                                "role": "user",
                                "content": text,
                            }
                        ],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]

        if "higgs" not in config.model_type:
            inputs = processor(text=texts_to_generate, return_tensors="pt")
        inputs = inputs.to(args.device)

        # 2. Model Inference
        if args.torch_compile is not None:
            sdpa_backends = [SDPBackend.MATH]
        else:
            sdpa_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        with sdpa_kernel(sdpa_backends):
            pred_waveform = model.generate(**inputs, **gen_kwargs)

        # 3. Post-processing
        # 3.1 Decode the generated outputs into waveforms
        if config.model_type == 'dia':
            prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])
            outputs = processor.batch_decode(pred_waveform, audio_prompt_len=prompt_len)
        elif "higgs" in config.model_type:
            outputs = processor.batch_decode(pred_waveform)
        else:
            outputs = pred_waveform
        # 3.2 Drop outputs generated for padding duplicates. Slicing to the real
        # minibatch size works for both tensors and lists, and doing it after
        # decoding keeps batch-level inputs (e.g. dia's prompt lengths) aligned.
        if padding_size is not None:
            outputs = outputs[:minibatch_size]

        # END TIMING
        end_event.record()
        torch.cuda.synchronize(device=args.device)
        runtime = start_event.elapsed_time(end_event) / 1000.0

        # normalize by minibatch size since we want the per-sample time
        batch["generation_time_s"] = minibatch_size * [runtime / minibatch_size]

        output_dir = f"results_{args.model_id}"
        os.makedirs(output_dir, exist_ok=True)
        gen_paths = []
        audio_length_s = []
        for audio, sample_id in zip(outputs, batch['id']):
            save_path = f"{output_dir}/output_{sample_id}.wav"
            try:
                processor.save_audio(audio, saving_path=save_path)
            except Exception as err:
                # Skipping a sample would leave the returned columns with different
                # lengths, which `Dataset.map` rejects, and the ASR pass needs every path.
                raise RuntimeError(
                    f"Could not save generated audio for sample {sample_id} to {save_path}; "
                    "the processor may not implement a standard `save_audio`."
                ) from err
            gen_paths.append(save_path)
            # NOTE(review): assumes `audio` is 1-D (samples); a (channels, samples)
            # output would make len() return the channel count. Confirm per model.
            audio_length_s.append(len(audio) / sampling_rate)

        batch["predictions"] = gen_paths
        # References for WER: the raw, un-templated and un-padded input texts.
        batch["input_text"] = [sample.lower() for sample in raw_texts]
        batch["audio_length_s"] = audio_length_s
        return batch

    dataset = load_dataset(args.dataset_path, split=args.split)
    dataset = dataset.add_column("id", list(range(len(dataset))))  # Pass id for easier reference when batch mapping

    # TODO: infer the batch size from the model's parameter count and the longest
    # input text in the dataset (`infer_batch_size` budgets 70% of total VRAM by
    # default, leaving ~30% for cache and overhead).
    def add_input_token_length(batch, text_column):
        input_ids = processor.tokenizer(batch[text_column], return_tensors=None, padding=False, truncation=False).input_ids
        batch["input_token_length"] = [len(ids) for ids in input_ids]
        return batch

    # dataset = dataset.map(
    #     add_input_token_length, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    # )
    # longest_input_length = max(dataset["input_token_length"])
    # inferred_bs = infer_batch_size(
    #     model=model,
    #     longest_input_length=longest_input_length,
    #     dtype=args.dtype,
    # )

    if args.warmup_steps is not None:
        num_warmup_samples = args.warmup_steps * args.batch_size
        warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}))
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    )

    asr_processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")
    asr_model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", device_map="auto")

    def transcribe_audio(batch):
        """
        Transcribe the audio files listed in ``batch["predictions"]`` with the
        Parakeet TDT model and store the lower-cased text in
        ``batch["asr_outputs"]``. Returns the updated batch.
        """
        sr_rate = asr_processor.feature_extractor.sampling_rate
        speech_samples = [load_audio(path, sampling_rate=sr_rate) for path in batch["predictions"]]
        inputs = asr_processor(speech_samples, sampling_rate=sr_rate).to(asr_model.device, dtype=asr_model.dtype)
        outputs = asr_model.generate(**inputs, return_dict_in_generate=False)
        # With return_dict_in_generate=False the result may be a plain tensor of
        # token ids rather than an output object; accept either.
        sequences = getattr(outputs, "sequences", outputs)
        transcriptions = asr_processor.batch_decode(sequences, skip_special_tokens=True)
        batch["asr_outputs"] = [sample.lower() for sample in transcriptions]
        return batch

    dataset = dataset.map(transcribe_audio, batch_size=args.batch_size, batched=True)

    all_results = {
        "audio_length_s": [],
        "generation_time_s": [],
        "predictions": [],
        "input_text": [],
        "asr_outputs": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # TODO: writing a manifest is currently disabled (`data_utils` is not
    # imported). The call below is kept for reference; it expects keys that
    # `all_results` does not currently collect.
    # manifest_path = data_utils.write_manifest(
    #     all_results["references"],
    #     all_results["predictions"],
    #     args.model_id,
    #     args.dataset_path,
    #     args.dataset,
    #     args.split,
    #     audio_length=all_results["audio_length_s"],
    #     transcription_time=all_results["transcription_time_s"],
    #     audio_filepaths=all_results["audio_filepath"],
    # )
    # print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["input_text"], predictions=all_results["asr_outputs"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["generation_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)
    return {"model_id": args.model_id, "WER": wer, "RTFX": rtfx}
if __name__ == "__main__":
    # Command-line entry point: parse options, run the benchmark, and upsert the
    # resulting metrics row into the results file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="hf-audio/tts_leaderboard",
        help="Dataset path. By default, it is `hf-audio/tts_leaderboard`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default="text",
        help="Name of the column corresponding to the text that has to be generated in dataset with the given `split`.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="The CUDA device to run the pipeline on, e.g. 'cuda' (default) or 'cuda:0'. "
        "Timing uses CUDA events, so a CUDA device is required.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of samples per generation batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=8192,
        help="Maximum number of tokens to generate (for auto-regressive models). For most models it is "
        "also used as `min_new_tokens`. Set to 0 for non-generative models.",
    )
    parser.add_argument(
        "--torch_compile",
        type=str,
        default=None,
        help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
    )
    parser.add_argument(
        "--compile_fullgraph",
        action="store_true",
        help="Whether to do full graph compilation.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="The dtype to use for model loading and inference. E.g. 'bfloat16', 'float16', 'float32'.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        help="Attention implementation to use for model loading (e.g. 'sdpa', 'eager', 'flash_attention_2').",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=0,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision to use (e.g. 'refs/pr/11' for a PR branch). Defaults to the main branch.",
    )
    args = parser.parse_args()

    print("*" * 100)
    print(f"Evaluating {args.model_id} on {args.dataset_path} / {args.split}")
    print("*" * 100)
    output_dict = main(args)
    upsert(output_dict)