# Provenance: uploaded to a Hugging Face Space via huggingface_hub (commit fdbb877, verified).
"""Phoneme ASR using wav2vec2 CTC model."""
import os
import time
import torch
import numpy as np
from typing import List
from config import (
PHONEME_ASR_MODELS, PHONEME_ASR_MODEL_DEFAULT, DTYPE, CPU_DTYPE,
IS_HF_SPACE, TORCH_COMPILE,
BATCHING_STRATEGY, INFERENCE_BATCH_SIZE,
MAX_BATCH_SECONDS, MAX_BATCH_SECONDS_CPU, MAX_PAD_WASTE, MIN_BATCH_SIZE,
)
from ..core.zero_gpu import ZERO_GPU_AVAILABLE, is_user_forced_cpu, model_device_lock
_cache = {} # model_name -> {"model": Model, "processor": Processor, "device": str}
_TORCH_DTYPE = torch.float16 if DTYPE == "float16" else torch.float32
_CPU_TORCH_DTYPE = (
torch.bfloat16 if CPU_DTYPE in ("bfloat16", "bf16")
else torch.float16 if CPU_DTYPE == "float16"
else torch.float32
)
def _get_hf_token():
    """Return a Hugging Face token, or None/empty if unavailable.

    Checks the HF_TOKEN environment variable first, then falls back to the
    token persisted by a prior `huggingface-cli login`. Any failure in the
    fallback (package missing, no stored login) is treated as "no token".
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    try:
        from huggingface_hub import HfFolder
        return HfFolder.get_token()
    except Exception:
        # Best-effort fallback only — propagate whatever the env lookup gave.
        return token
def _get_device_and_dtype():
    """Pick the inference device and matching torch dtype.

    CPU dtype is governed by the CPU_DTYPE setting (default fp32). fp16 is only
    safe on CPUs with AVX512_FP16 (e.g. zero-a10g AMD EPYC) — on cpu-basic
    workers without AVX512_FP16 it falls back to scalar/copy-cast ops that
    can be 10-100x slower than fp32. The GPU path re-casts to _TORCH_DTYPE when
    transitioning to CUDA.

    On HF Spaces with ZeroGPU this always reports CPU, deferring CUDA init
    until inside a @gpu_decorator function.
    """
    defer_cuda = IS_HF_SPACE or ZERO_GPU_AVAILABLE
    if not defer_cuda and torch.cuda.is_available():
        return torch.device("cuda"), _TORCH_DTYPE
    return torch.device("cpu"), _CPU_TORCH_DTYPE
def load_phoneme_asr(model_name=PHONEME_ASR_MODEL_DEFAULT):
    """Load phoneme ASR model on CPU. Returns (model, processor).

    Models are loaded once and cached per model_name. Both base and large
    can be cached simultaneously. Use move_phoneme_asr_to_gpu() inside
    GPU-decorated functions to move to CUDA.

    Thread-safe: uses model_device_lock with double-checked locking.
    """
    # Fast path: lock-free cache read. Safe because entries are only ever
    # added while holding the lock below and never mutated after insertion.
    if model_name in _cache:
        entry = _cache[model_name]
        return entry["model"], entry["processor"]
    with model_device_lock:
        # Re-check after acquiring lock — another thread may have loaded it
        if model_name in _cache:
            entry = _cache[model_name]
            return entry["model"], entry["processor"]
        import logging
        from transformers import AutoModelForCTC, AutoProcessor
        # Suppress verbose transformers logging during load
        logging.getLogger("transformers").setLevel(logging.WARNING)
        model_path = PHONEME_ASR_MODELS[model_name]
        print(f"Loading phoneme ASR: {model_path} ({model_name})")
        # Use HF_TOKEN for private model access
        hf_token = _get_hf_token()
        device, dtype = _get_device_and_dtype()
        model = AutoModelForCTC.from_pretrained(
            model_path, token=hf_token, attn_implementation="sdpa"
        )
        model.to(device, dtype=dtype)
        model.eval()
        # torch.compile is skipped on Spaces/ZeroGPU, where the model may
        # later be moved between CPU and short-lived CUDA leases.
        if TORCH_COMPILE and not (IS_HF_SPACE or ZERO_GPU_AVAILABLE):
            model = torch.compile(model, mode="reduce-overhead")
        processor = AutoProcessor.from_pretrained(model_path, token=hf_token)
        # Insert into the cache only after full construction so concurrent
        # lock-free readers never observe a half-built entry.
        _cache[model_name] = {
            "model": model,
            "processor": processor,
            "device": device.type,
        }
        print(f"Phoneme ASR ({model_name}) loaded on {device}")
        return model, processor
def move_phoneme_asr_to_gpu(model_name=None):
    """Move cached phoneme ASR model(s) to GPU.

    Args:
        model_name: Move only this model. If None, move all cached models.

    Call this inside @gpu_decorator functions on HF Spaces.
    Idempotent: checks current device before moving.
    Skips if quota exhausted or CUDA unavailable.
    """
    if is_user_forced_cpu() or not torch.cuda.is_available():
        return
    cuda = torch.device("cuda")
    targets = [model_name] if model_name else list(_cache.keys())
    with model_device_lock:
        for name in targets:
            entry = _cache.get(name)
            if entry is None:
                continue
            model = entry["model"]
            # Already resident on CUDA — nothing to do.
            if next(model.parameters()).device.type == "cuda":
                continue
            try:
                entry["model"] = model.to(cuda, dtype=_TORCH_DTYPE)
                entry["device"] = "cuda"
                print(f"[PHONEME ASR] Moved '{name}' to CUDA")
            except RuntimeError as e:
                print(f"[PHONEME ASR] CUDA move failed for '{name}', staying on CPU: {e}")
def invalidate_asr_cache(model_name=None):
    """Drop cached ASR model(s) so the next load_phoneme_asr() creates fresh ones.

    Args:
        model_name: Invalidate only this model. If None, invalidate all.

    Called from _drain_stale_models() inside a GPU lease. No CUDA ops —
    just removes references and lets GC reclaim tensors.
    """
    if not model_name:
        # Invalidate everything that is currently cached.
        if _cache:
            names = list(_cache.keys())
            _cache.clear()
            print(f"[PHONEME ASR] Cache invalidated: {names}")
        return
    # Cached entries are dicts (never None), so a non-None pop means it existed.
    if _cache.pop(model_name, None) is not None:
        print(f"[PHONEME ASR] Cache invalidated for '{model_name}'")
def ids_to_phoneme_list(ids: List[int], tokenizer, pad_id: int) -> List[str]:
    """Convert token IDs to a phoneme list via CTC greedy collapse.

    Decoding rules:
      1. Drop pad/blank tokens.
      2. Merge runs of identical consecutive tokens (CTC collapse).
      3. Drop the word delimiter "|".

    Pad and "|" still act as run separators: the same phoneme on both sides
    of a blank/delimiter is emitted twice.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids)
    if not tokens:
        return []
    # Resolve the pad token's string form (falls back to the literal "[PAD]").
    pad_token = "[PAD]" if pad_id is None else tokenizer.convert_ids_to_tokens([pad_id])[0]
    phonemes: List[str] = []
    previous = None
    for token in tokens:
        if token in (pad_token, "|"):
            # Separators are skipped but still break duplicate runs.
            previous = token
            continue
        if token != previous:
            phonemes.append(token)
            previous = token
    return phonemes
def build_batches_naive(sorted_indices: List[int], batch_size: int) -> List[List[int]]:
    """Chunk indices into fixed-size batches (original fixed-count behavior)."""
    batches: List[List[int]] = []
    for start in range(0, len(sorted_indices), batch_size):
        batches.append(sorted_indices[start:start + batch_size])
    return batches
def build_batches(sorted_indices: List[int], durations: List[float],
                  max_batch_seconds: float = MAX_BATCH_SECONDS) -> List[List[int]]:
    """Build dynamic batches from duration-sorted (ascending) indices.

    Per-batch constraints:
      - total audio seconds <= max_batch_seconds
      - pad-waste fraction <= MAX_PAD_WASTE (1 - sum/[n*max], the share of
        padded tensor compute that is wasted)
      - a batch is never cut below MIN_BATCH_SIZE (avoids underutilization)
    """
    batches: List[List[int]] = []
    batch: List[int] = []
    batch_seconds = 0.0
    for idx in sorted_indices:
        dur = durations[idx]
        if not batch:
            # First element always starts a batch.
            batch = [idx]
            batch_seconds = dur
            continue
        # Indices arrive sorted ascending, so the candidate is the new max.
        candidate_seconds = batch_seconds + dur
        candidate_size = len(batch) + 1
        waste = 1.0 - candidate_seconds / (candidate_size * dur) if dur > 0 else 0.0
        over_seconds = candidate_seconds > max_batch_seconds
        over_waste = waste > MAX_PAD_WASTE
        if (over_seconds or over_waste) and len(batch) >= MIN_BATCH_SIZE:
            # Close out the current batch and start a new one with idx.
            batches.append(batch)
            batch = [idx]
            batch_seconds = dur
        else:
            batch.append(idx)
            batch_seconds = candidate_seconds
    if batch:
        batches.append(batch)
    return batches
def _transcribe_batch_pytorch(
    segment_audios: List[np.ndarray],
    durations: List[float],
    batches: List[List[int]],
    model,
    processor,
    tokenizer,
    pad_id: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple:
    """PyTorch inference path (GPU or CPU fallback).

    Runs feature extraction, CTC forward pass, and greedy decode per batch,
    recording per-batch timing and padding/attention-size statistics.

    Returns:
        (results, batch_profiling) — results[i] is the phoneme list for
        segment_audios[i]; batch_profiling is a list of per-batch stat dicts.
    """
    # Pre-size results so each decode can write back by original index.
    results: List[List[str]] = [[] for _ in segment_audios]
    batch_profiling = []
    # Read once for the QK^T-bytes estimate logged per batch. Both wav2vec2
    # base and xls-r downsample audio by 320× (16 kHz → 50 fps), so seq_len in
    # frames is `padded_audio_seconds * 50`. The QK^T tensor materialised by
    # the CPU SDPA `math` backend (fp16 path on PyTorch 2.8) is sized
    # `(batch, heads, seq, seq) * dtype_bytes` — when this exceeds the host's
    # L3 the attention layers go DRAM-bound and per-layer cost spikes 10–20×.
    num_heads = getattr(model.config, "num_attention_heads", 0)
    dtype_bytes = torch.tensor([], dtype=dtype).element_size()
    FRAMES_PER_SEC = 50  # wav2vec2 frame rate after the 320× conv downsample
    for batch_num_idx, batch_idx in enumerate(batches):
        batch_audios = [segment_audios[i] for i in batch_idx]
        batch_durations = [durations[i] for i in batch_idx]
        batch_num = batch_num_idx + 1  # 1-based for human-readable logs
        t0 = time.time()
        # Feature extraction + GPU transfer
        t_feat_start = time.time()
        inputs = processor(
            batch_audios,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(device=device, dtype=dtype)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device=device)
        feat_time = time.time() - t_feat_start
        # Model inference
        t_infer_start = time.time()
        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)
            logits = outputs.logits
        if device.type == "cuda":
            try:
                # Ensure infer_time measures kernel completion, not launch.
                torch.cuda.synchronize()
            except RuntimeError:
                pass  # GPU may have been deallocated at lease boundary
        infer_time = time.time() - t_infer_start
        # CTC greedy decode
        t_decode_start = time.time()
        predicted_ids = torch.argmax(logits, dim=-1)
        for j in range(predicted_ids.shape[0]):
            ids_list = predicted_ids[j].cpu().tolist()
            phoneme_list = ids_to_phoneme_list(ids_list, tokenizer, pad_id)
            # Write back to the segment's original (pre-sort) position.
            results[batch_idx[j]] = phoneme_list
        decode_time = time.time() - t_decode_start
        # Drop tensor references eagerly so memory is reclaimed between batches.
        del input_values, attention_mask, outputs, logits, predicted_ids
        batch_time = time.time() - t0
        max_dur = max(batch_durations)
        seq_len = int(round(max_dur * FRAMES_PER_SEC))
        qk_bytes_per_head = len(batch_audios) * seq_len * seq_len * dtype_bytes
        qk_bytes_all_heads = qk_bytes_per_head * num_heads
        total_seconds = sum(batch_durations)
        batch_profiling.append({
            "batch_num": batch_num,
            "size": len(batch_audios),
            "time": round(batch_time, 3),
            "feat_time": round(feat_time, 3),
            "infer_time": round(infer_time, 3),
            "decode_time": round(decode_time, 3),
            "min_dur": round(min(batch_durations), 3),
            "max_dur": round(max_dur, 3),
            "total_seconds": round(total_seconds, 3),
            "pad_waste": round(1.0 - total_seconds / (len(batch_durations) * max_dur), 4) if max_dur > 0 else 0.0,
            "seq_len": seq_len,
            "qk_mb_per_head": round(qk_bytes_per_head / (1024 * 1024), 2),
            "qk_mb_all_heads": round(qk_bytes_all_heads / (1024 * 1024), 2),
        })
    return results, batch_profiling
def transcribe_batch(segment_audios: List[np.ndarray], sample_rate: int, model_name: str = PHONEME_ASR_MODEL_DEFAULT) -> tuple:
    """Transcribe audio segments to phoneme lists, sorted by duration for efficiency.

    Args:
        segment_audios: List of audio arrays (assumed 16 kHz — resampled at
            source; `sample_rate` is currently unused by this function)
        sample_rate: Audio sample rate
        model_name: Which ASR model to use ("base" or "large")

    Returns:
        (results, batch_profiling, sorting_time, batch_build_time) where
        results is List[List[str]], batch_profiling is a list of dicts with
        per-batch timing and duration stats, and the last two are seconds
        spent sorting segments and building batches.
    """
    if not segment_audios:
        return [], [], 0.0, 0.0
    model, processor = load_phoneme_asr(model_name)
    if model is None:
        # Defensive: fail soft with empty transcriptions if loading produced no model.
        return [[] for _ in segment_audios], [], 0.0, 0.0
    # Determine inference device.
    if is_user_forced_cpu():
        device = torch.device("cpu")
    else:
        device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    tokenizer = processor.tokenizer
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    # Compute durations (audio assumed to be 16kHz — resampled at source)
    durations = [len(audio) / 16000.0 for audio in segment_audios]
    # Sort indices by duration, then build dynamic batches
    t_sort = time.time()
    sorted_indices = sorted(range(len(segment_audios)), key=lambda i: durations[i])
    sorting_time = time.time() - t_sort
    t_batch_build = time.time()
    if BATCHING_STRATEGY == "dynamic":
        # CPU gets a tighter per-batch seconds cap than GPU.
        cap = MAX_BATCH_SECONDS_CPU if device.type == "cpu" else MAX_BATCH_SECONDS
        batches = build_batches(sorted_indices, durations, max_batch_seconds=cap)
    else:
        batches = build_batches_naive(sorted_indices, INFERENCE_BATCH_SIZE)
    batch_build_time = time.time() - t_batch_build
    backend = "PyTorch" + (f" ({device.type})" if device.type != "cpu" else " (CPU)")
    print(f"[PHONEME ASR] Using {backend}")
    results, batch_profiling = _transcribe_batch_pytorch(
        segment_audios, durations, batches,
        model, processor, tokenizer, pad_id, device, dtype,
    )
    # batches is non-empty here (segment_audios is non-empty), so min/max are safe.
    sizes = [p["size"] for p in batch_profiling]
    print(f"[PHONEME ASR] {len(segment_audios)} segments in {len(batch_profiling)} batches "
          f"(sizes: {min(sizes)}-{max(sizes)}, sort: {sorting_time:.3f}s, batch build: {batch_build_time:.3f}s)")
    return results, batch_profiling, sorting_time, batch_build_time