3outeille's picture
3outeille HF Staff
Sync eval.py
3862b8f verified
Raw
History Blame Contribute Delete
19.9 kB
import argparse
import os
import re
import math
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import (
AutoConfig,
AutoModelForTextToWaveform,
AutoModelForTDT,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
CompileConfig,
)
from transformers.audio_utils import load_audio
import evaluate
from tqdm import tqdm
import random
import numpy as np
import pandas as pd
from datasets import load_dataset
# Default location of the results table written by `upsert`.
OUTPUT_PATH = "results.csv"


def upsert(record: dict, file_path: str = OUTPUT_PATH):
    """
    Insert or update a record in the stored DataFrame.

    Uses 'model_id' as the unique key — updates the row if found, appends if not.
    When updating, keys of `record` that are not yet columns of the stored table
    are added as new columns (other rows get NaN there) instead of being dropped.

    Args:
        record: Mapping of column name to value. Must contain 'model_id'.
        file_path: Table to read and write. Paths ending in '.csv' are handled
            as CSV; any other path is read/written as an Excel file.

    Returns:
        The full DataFrame as it was written to `file_path`.
    """
    assert "model_id" in record, "'model_id' key is required in the record dict."

    is_csv = file_path.endswith(".csv")
    if os.path.exists(file_path):
        df = pd.read_csv(file_path) if is_csv else pd.read_excel(file_path)
    else:
        df = pd.DataFrame()

    new_row = pd.DataFrame([record])
    model_id = record["model_id"]

    if not df.empty and "model_id" in df.columns and model_id in df["model_id"].values:
        # Update existing row
        df = df.set_index("model_id")
        new_row = new_row.set_index("model_id")
        # `DataFrame.update` only touches columns that already exist, so remember
        # the brand-new ones and join them in afterwards.
        new_columns = new_row.columns.difference(df.columns, sort=False)
        df.update(new_row)  # updates only fields present in new_row
        if len(new_columns) > 0:
            df = df.join(new_row[new_columns])
        df = df.reset_index()
        print(f"Updated model_id='{record['model_id']}'")
    else:
        # Append new row. A frame without any columns has nothing to contribute,
        # and concatenating it could only degrade the dtypes of the new row.
        if df.columns.empty:
            df = new_row
        else:
            df = pd.concat([df, new_row], ignore_index=True)
        print(f"Inserted model_id='{record['model_id']}'")

    if is_csv:
        df.to_csv(file_path, index=False)
    else:
        df.to_excel(file_path, index=False)
    return df
# Select the 'high' float32 matmul precision mode globally (see the PyTorch docs
# for the exact speed/accuracy trade-off of each mode).
torch.set_float32_matmul_precision("high")

# Word-error-rate metric, loaded once at import time and reused by `main`.
wer_metric = evaluate.load("wer")
# Bytes occupied by one element of each supported dtype (fractional for sub-byte types).
DTYPE_BYTES = {
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.int8: 1,
}
# `torch.int4` is not defined in every PyTorch release; referencing it directly
# in the literal above would raise AttributeError at import time on those builds.
if hasattr(torch, "int4"):
    DTYPE_BYTES[torch.int4] = 0.5
def get_free_gpu_memory_bytes(device: int = 0) -> int:
    """
    Return the amount of free memory, in bytes, on a CUDA device.

    Args:
        device: Index of the CUDA device to query.

    Returns:
        The first value reported by `torch.cuda.mem_get_info(device)`.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    if torch.cuda.is_available():
        # Wait for pending work on the device before reading the memory counters.
        torch.cuda.synchronize(device)
        return torch.cuda.mem_get_info(device)[0]
    raise RuntimeError("No CUDA device found.")
def infer_batch_size(
    model: torch.nn.Module,
    longest_input_length: int,
    dtype: torch.dtype = torch.bfloat16,
    usable_fraction: float = 0.70,
    device: int = 0,
) -> int:
    """
    Estimate a safe batch size for generation.

    Budget: `usable_fraction` (default 70%) of total GPU memory for everything
    (weights + activations + batch). The remainder is left untouched as a buffer
    for KV cache growth, CUDA kernels, and other overhead.

    Per-sample activation cost is estimated as ~2 bytes * params^0.6 * seq_len,
    an empirically-derived heuristic that's dtype-agnostic and arch-agnostic.

    Args:
        model: Any nn.Module (HF, custom, etc.)
        longest_input_length: max(len(input_ids)) across your dataset. Must be > 0.
        dtype: Dtype the model is loaded in. Either a `torch.dtype` or the name
            of one (e.g. "bfloat16"), which is resolved with `getattr(torch, ...)`.
        usable_fraction: Fraction of *total* VRAM to budget for everything.
            Default 0.70 — leaves 30% for cache and other overhead.
        device: Index of the CUDA device whose memory is inspected.

    Returns:
        Recommended batch size (>= 1).

    Raises:
        ValueError: If `longest_input_length` is not positive or `dtype` has no
            entry in `DTYPE_BYTES`.
        RuntimeError: If no CUDA device is available, or the model weights alone
            exceed the usable budget.
    """
    # A zero-length input would make the per-sample cost 0 and the division below fail.
    if longest_input_length <= 0:
        raise ValueError(f"longest_input_length must be positive, got {longest_input_length}")

    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)
    bpe = DTYPE_BYTES.get(dtype)
    if bpe is None:
        raise ValueError(f"Unsupported dtype: {dtype}")

    # --- Total VRAM budget ---
    free_bytes = get_free_gpu_memory_bytes(device)
    _, total_bytes = torch.cuda.mem_get_info(device)
    budget_bytes = total_bytes * usable_fraction
    print(f"[gpu] Total VRAM : {total_bytes / 1e9:.2f} GB")
    print(f"[gpu] Free VRAM : {free_bytes / 1e9:.2f} GB")
    print(f"[gpu] Usable budget : {budget_bytes / 1e9:.2f} GB ({usable_fraction*100:.0f}% of total)")

    # --- Model weights ---
    num_params = sum(p.numel() for p in model.parameters())
    model_bytes = num_params * bpe
    print(f"[model] Params : {num_params / 1e9:.3f}B")
    print(f"[model] Weight memory : {model_bytes / 1e9:.2f} GB (dtype={dtype})")

    remaining_bytes = budget_bytes - model_bytes
    if remaining_bytes <= 0:
        # The budget grows with usable_fraction, so the remedy is to raise it.
        raise RuntimeError(
            f"Model weights alone ({model_bytes/1e9:.2f} GB) exceed the usable budget "
            f"({budget_bytes/1e9:.2f} GB). Try a smaller model or higher usable_fraction."
        )

    # --- Per-sample activation cost (arch-agnostic heuristic) ---
    # Activations ≈ f(params, seq_len). Empirically, hidden states + attention
    # buffers scale roughly as params^0.6 per token across diverse architectures.
    # Multiply by 2 bytes as a base unit (independent of weight dtype, since
    # activations are often kept in fp16/bf16 regardless).
    bytes_per_token = 2 * (num_params ** 0.6)
    bytes_per_sample = bytes_per_token * longest_input_length
    print(f"[input] longest_input_length : {longest_input_length} tokens")
    print(f"[input] Est. bytes/sample : {bytes_per_sample / 1e6:.1f} MB")

    # --- Batch size ---
    batch_size = max(1, math.floor(remaining_bytes / bytes_per_sample))
    print(f"\n✅ Recommended batch_size : {batch_size}")
    return batch_size
def main(args):
    """
    Benchmark a text-to-speech model: generate audio, transcribe it, report WER and RTFx.

    The model named by `args.model_id` generates audio for every text in the
    dataset. The generated files are transcribed with the
    `nvidia/parakeet-tdt-0.6b-v3` ASR model, and the word error rate between the
    lower-cased input texts and the lower-cased transcriptions is computed.
    RTFx is the total generated audio duration divided by the total generation time.

    Args:
        args: Parsed command-line namespace. Attributes read here: `model_id`,
            `revision`, `dtype`, `device`, `attn_implementation`, `max_new_tokens`,
            `torch_compile`, `compile_fullgraph`, `warmup_steps`, `batch_size`,
            `dataset_path`, `split`, `text_column`, `max_eval_samples`.
            `args.warmup_steps` is overwritten with 10 when compilation is
            requested without any warm-up.

    Returns:
        Dict with keys "model_id", "WER" (percent, 2 decimals) and "RTFX" (2 decimals).

    Raises:
        ValueError: If `max_new_tokens` is set for a model that cannot generate.
    """
    # Set seed due to randomness in some models (e.g. VibeVoice's acoustic tokenizer sampling)
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    torch_dtype = getattr(torch, args.dtype)
    config = AutoConfig.from_pretrained(args.model_id, revision=args.revision)

    if "dia" in config.model_type:
        model_cls = AutoModelForSpeechSeq2Seq
    else:
        model_cls = AutoModelForTextToWaveform
    # Load the weights from the same revision as the config and the processor;
    # otherwise `--revision` would silently evaluate the main-branch weights.
    model = model_cls.from_pretrained(
        args.model_id,
        revision=args.revision,
        dtype=torch_dtype,
        device_map=args.device,
        attn_implementation=args.attn_implementation,
    )
    num_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"Model size: {num_params:.2f}B parameters")

    processor_kwargs = {"device_map": args.device} if "higgs" in config.model_type else {}
    processor = AutoProcessor.from_pretrained(args.model_id, revision=args.revision, **processor_kwargs)

    # Set generate arguments. Always bound, so the later `**gen_kwargs` expansion
    # cannot hit an undefined name for models that cannot generate.
    gen_kwargs = {}
    if model.can_generate():
        gen_kwargs["max_new_tokens"] = args.max_new_tokens
        if "higgs" not in config.model_type:
            gen_kwargs["min_new_tokens"] = args.max_new_tokens
        if "csm" in config.model_type:
            gen_kwargs["output_audio"] = True
    elif args.max_new_tokens:
        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a non-generative model.")

    if args.torch_compile is not None:
        if model.can_generate():
            gen_kwargs["compile_config"] = CompileConfig(mode=args.torch_compile, fullgraph=args.compile_fullgraph)
            # enable static k/v cache for autoregressive models
            model.generation_config.cache_implementation = "static"
        else:
            model = torch.compile(model, mode=args.torch_compile, fullgraph=args.compile_fullgraph)
        # Ensure warm-up runs when using torch.compile
        if args.warmup_steps is None or args.warmup_steps < 1:
            print("`--torch_compile` is enabled; forcing `--warmup_steps=10` to trigger compilation before timed runs.")
            args.warmup_steps = 10

    def benchmark(batch, text_column):
        """Generate audio for one minibatch, save it to disk and record timing columns."""
        # Keep the untouched dataset texts as WER references. The working copy
        # below gets padded and chat-templated, neither of which belongs in the
        # reference column (and padding would change its length).
        reference_texts = list(batch[text_column])
        texts_to_generate = list(reference_texts)
        minibatch_size = len(texts_to_generate)

        # Fall back to 16 kHz when the processor exposes no feature extractor
        # or the feature extractor has no sampling rate.
        feature_extractor = getattr(processor, "feature_extractor", None)
        sampling_rate = getattr(feature_extractor, "sampling_rate", 16_000)

        # START TIMING
        torch.cuda.synchronize(device=args.device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # 1. Pre-Processing
        # 1.1 Pad texts to max batch size if using torch compile to prevent re-compilations
        padding_size = None
        if minibatch_size != args.batch_size and args.torch_compile is not None:
            padding_size = args.batch_size - minibatch_size
            texts_to_generate.extend([texts_to_generate[-1]] * padding_size)

        # Apply jinja template if the processor has one saved
        inputs = None
        if getattr(processor, "chat_template", None) is not None:
            # CSM uses speaker ID and not role in conversation
            if "csm" in config.model_type:
                texts_to_generate = [
                    processor.apply_chat_template(
                        [
                            {
                                "role": "0",
                                "content": [{"type": "text", "text": text}],
                            }
                        ],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]
            elif "higgs" in config.model_type:
                inputs = processor.apply_chat_template(
                    [
                        [
                            {
                                "role": "system",
                                "content": [
                                    {
                                        "type": "text",
                                        "text": "Generate audio following instruction.",
                                    }
                                ],
                            },
                            {
                                "role": "user",
                                "content": text,
                            },
                        ]
                        for text in texts_to_generate
                    ],
                    tokenize=True,
                    add_generation_prompt=True,
                    return_dict=True,
                    return_tensors="pt",
                    sampling_rate=24000,
                )
            else:
                texts_to_generate = [
                    processor.apply_chat_template(
                        [
                            {
                                "role": "user",
                                "content": text,
                            }
                        ],
                        tokenize=False,
                        add_generation_prompt=True,
                        return_dict=False,
                    )
                    for text in texts_to_generate
                ]

        # Only the higgs chat-template path tokenizes on its own; everything else
        # (including higgs without a chat template) goes through the plain call.
        if inputs is None:
            inputs = processor(text=texts_to_generate, return_tensors="pt")
        inputs = inputs.to(args.device)

        # 2. Model Inference
        if args.torch_compile is not None:
            sdpa_backends = [SDPBackend.MATH]
        else:
            sdpa_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        with sdpa_kernel(sdpa_backends):
            pred_waveform = model.generate(**inputs, **gen_kwargs)

        # 3. Post-processing
        # 3.1 Strip the predictions that belong to padded samples
        if padding_size is not None:
            pred_waveform = pred_waveform[:-padding_size, ...]

        # 3.2 Decode model outputs into audio where the processor requires it
        if config.model_type == 'dia':
            prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])
            outputs = processor.batch_decode(pred_waveform, audio_prompt_len=prompt_len)
        elif "higgs" in config.model_type:
            outputs = processor.batch_decode(pred_waveform)
        else:
            outputs = pred_waveform

        # END TIMING
        end_event.record()
        torch.cuda.synchronize(device=args.device)
        runtime = start_event.elapsed_time(end_event) / 1000.0

        # normalize by minibatch size since we want the per-sample time
        batch["generation_time_s"] = minibatch_size * [runtime / minibatch_size]

        gen_paths = []
        audio_length_s = []
        output_dir = f"results_{args.model_id}"
        os.makedirs(output_dir, exist_ok=True)
        for audio, sample_id in zip(outputs, batch['id']):
            saving_path = f"{output_dir}/output_{sample_id}.wav"
            try:
                processor.save_audio(audio, saving_path=saving_path)
                duration_s = len(audio) / sampling_rate
            except Exception as exc:
                # Prob the processor is not yet standard. Stay best-effort, but
                # report the failure instead of hiding it.
                print(f"Could not save audio for sample {sample_id}: {exc!r}")
                continue
            # Append both values together so the two columns stay aligned.
            gen_paths.append(saving_path)
            audio_length_s.append(duration_s)

        batch["predictions"] = gen_paths
        batch["input_text"] = [sample.lower() for sample in reference_texts]
        batch["audio_length_s"] = audio_length_s
        return batch

    dataset = load_dataset(args.dataset_path, split=args.split)
    dataset = dataset.add_column("id", list(range(len(dataset))))  # Pass id for easier reference when batch mapping

    # Infer batch size from model param count and longest input text in the dataset. Reserve 25% of VRAM for cache and overhead
    def add_input_token_length(batch, text_column):
        """Add an `input_token_length` column holding the token count of each text."""
        input_ids = processor.tokenizer(batch[text_column], return_tensors=None, padding=False, truncation=False).input_ids
        batch["input_token_length"] = [len(ids) for ids in input_ids]
        return batch

    # TODO:
    # dataset = dataset.map(
    #     add_input_token_length, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    # )
    # longest_input_length = max(dataset["input_token_length"])
    # inferred_bs = infer_batch_size(
    #     model=model,
    #     longest_input_length=longest_input_length,
    #     dtype=args.dtype,
    # )

    # Skip warm-up entirely for 0 steps rather than mapping over an empty selection.
    if args.warmup_steps is not None and args.warmup_steps > 0:
        num_warmup_samples = args.warmup_steps * args.batch_size
        warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}))
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"text_column": args.text_column}
    )

    asr_processor = AutoProcessor.from_pretrained("nvidia/parakeet-tdt-0.6b-v3")
    asr_model = AutoModelForTDT.from_pretrained("nvidia/parakeet-tdt-0.6b-v3", device_map="auto")

    def transcribe_audio(batch):
        """
        Takes in a minibatch of audio paths and adds the transcribed text for each file.
        Each audio is transcribed with the ParakeetTDT model.

        Returns the batch with a lower-cased `asr_outputs` column added.
        """
        sr_rate = asr_processor.feature_extractor.sampling_rate
        speech_samples = [load_audio(path, sampling_rate=sr_rate) for path in batch["predictions"]]
        inputs = asr_processor(speech_samples, sampling_rate=sr_rate).to(asr_model.device, dtype=asr_model.dtype)
        outputs = asr_model.generate(**inputs, return_dict_in_generate=False)
        # With `return_dict_in_generate=False` the result may already be the bare
        # sequences rather than an output object; accept either form.
        sequences = getattr(outputs, "sequences", outputs)
        transcripts = asr_processor.batch_decode(sequences, skip_special_tokens=True)
        batch["asr_outputs"] = [sample.lower() for sample in transcripts]
        return batch

    dataset = dataset.map(transcribe_audio, batch_size=args.batch_size, batched=True)

    all_results = {
        "audio_length_s": [],
        "generation_time_s": [],
        "predictions": [],
        "input_text": [],
        "asr_outputs": [],
    }
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    # Filtering of empty references is handled inside write_manifest.
    # manifest_path = data_utils.write_manifest(
    #     all_results["references"],
    #     all_results["predictions"],
    #     args.model_id,
    #     args.dataset_path,
    #     args.dataset,
    #     args.split,
    #     audio_length=all_results["audio_length_s"],
    #     transcription_time=all_results["transcription_time_s"],
    #     audio_filepaths=all_results["audio_filepath"],
    # )
    # print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["input_text"], predictions=all_results["asr_outputs"]
    )
    wer = round(100 * wer, 2)
    # Real-time factor: seconds of audio produced per second of generation time.
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["generation_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)
    return {"model_id": args.model_id, "WER": wer, "RTFX": rtfx}
if __name__ == "__main__":
    # Command-line interface: every entry is (flag, keyword options for
    # `add_argument`), registered in the listed order.
    _ARGUMENT_SPECS = [
        (
            "--model_id",
            dict(
                type=str,
                required=True,
                help="Model identifier. Should be loadable with 🤗 Transformers",
            ),
        ),
        (
            "--dataset_path",
            dict(
                type=str,
                default="hf-audio/tts_leaderboard",
                help="Dataset path. By default, it is `hf-audio/tts_leaderboard`",
            ),
        ),
        (
            "--split",
            dict(
                type=str,
                default="train",
                help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
            ),
        ),
        (
            "--text_column",
            dict(
                type=str,
                default="text",
                help="Name of the column corresponding to the text that has to be generated in dataset with the given `split`.",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                default="cuda",
                help="The device to run the pipeline on. `auto` for auto-inferring, 'cpu' for CPU, 'cuda' for the GPU (default).",
            ),
        ),
        (
            "--batch_size",
            dict(
                type=int,
                default=32,
                help="Number of samples to go through each streamed batch.",
            ),
        ),
        (
            "--max_eval_samples",
            dict(
                type=int,
                default=None,
                help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
            ),
        ),
        (
            "--max_new_tokens",
            dict(
                type=int,
                default=8192,
                help="Maximum number of tokens to generate (for auto-regressive models).",
            ),
        ),
        (
            "--torch_compile",
            dict(
                type=str,
                default=None,
                help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
            ),
        ),
        (
            "--compile_fullgraph",
            dict(
                action="store_true",
                help="Whether to do full graph compilation.",
            ),
        ),
        (
            "--dtype",
            dict(
                type=str,
                default="bfloat16",
                help="The dtype to use for model loading and inference. E.g. 'bfloat16', 'float16', 'float32'.",
            ),
        ),
        (
            "--attn_implementation",
            dict(
                type=str,
                default="sdpa",
                help="Attention implementation to use for model loading (e.g. 'sdpa', 'eager', 'flash_attention_2').",
            ),
        ),
        (
            "--warmup_steps",
            dict(
                type=int,
                default=0,
                help="Number of warm-up steps to run before launching the timed runs.",
            ),
        ),
        (
            "--revision",
            dict(
                type=str,
                default=None,
                help="Model revision to use (e.g. 'refs/pr/11' for a PR branch). Defaults to the main branch.",
            ),
        ),
    ]

    parser = argparse.ArgumentParser()
    for _flag, _options in _ARGUMENT_SPECS:
        parser.add_argument(_flag, **_options)
    args = parser.parse_args()

    print("*" * 100)
    print(f"Evaluating {args.model_id} on {args.dataset_path} / {args.split}")
    print("*" * 100)

    # Run the evaluation, then record the scores in the results table.
    output_dict = main(args)
    upsert(output_dict)