Spaces:
Running on Zero
Running on Zero
| """Phoneme ASR using wav2vec2 CTC model.""" | |
| import os | |
| import time | |
| import torch | |
| import numpy as np | |
| from typing import List | |
| from config import ( | |
| PHONEME_ASR_MODELS, PHONEME_ASR_MODEL_DEFAULT, DTYPE, CPU_DTYPE, | |
| IS_HF_SPACE, TORCH_COMPILE, | |
| BATCHING_STRATEGY, INFERENCE_BATCH_SIZE, | |
| MAX_BATCH_SECONDS, MAX_BATCH_SECONDS_CPU, MAX_PAD_WASTE, MIN_BATCH_SIZE, | |
| ) | |
| from ..core.zero_gpu import ZERO_GPU_AVAILABLE, is_user_forced_cpu, model_device_lock | |
_cache = {}  # model_name -> {"model": Model, "processor": Processor, "device": str}
# GPU inference dtype, chosen via the DTYPE config value (fp16 or fp32).
_TORCH_DTYPE = torch.float16 if DTYPE == "float16" else torch.float32
# CPU inference dtype, chosen via the CPU_DTYPE config value:
# "bfloat16"/"bf16" -> bf16, "float16" -> fp16, anything else -> fp32.
# See _get_device_and_dtype's docstring for why fp16-on-CPU is risky.
_CPU_TORCH_DTYPE = (
    torch.bfloat16 if CPU_DTYPE in ("bfloat16", "bf16")
    else torch.float16 if CPU_DTYPE == "float16"
    else torch.float32
)
| def _get_hf_token(): | |
| """Get HF token from env var or stored login.""" | |
| token = os.environ.get("HF_TOKEN") | |
| if not token: | |
| try: | |
| from huggingface_hub import HfFolder | |
| token = HfFolder.get_token() | |
| except Exception: | |
| pass | |
| return token | |
def _get_device_and_dtype():
    """Pick the inference device and the matching torch dtype.

    CPU dtype is governed by the CPU_DTYPE env var (default fp32). fp16 is
    only safe on CPUs with AVX512_FP16 (e.g. zero-a10g AMD EPYC) — on
    cpu-basic workers without AVX512_FP16 it falls back to scalar/copy-cast
    ops that can be 10-100× slower than fp32. The GPU path re-casts to
    _TORCH_DTYPE when the model transitions to CUDA.

    On HF Spaces with ZeroGPU this always reports CPU, deferring CUDA
    initialisation until code runs inside a @gpu_decorator function.
    """
    defer_to_cpu = IS_HF_SPACE or ZERO_GPU_AVAILABLE
    # Short-circuit keeps torch.cuda.is_available() from running (and
    # initialising CUDA) on ZeroGPU workers.
    if not defer_to_cpu and torch.cuda.is_available():
        return torch.device("cuda"), _TORCH_DTYPE
    return torch.device("cpu"), _CPU_TORCH_DTYPE
def load_phoneme_asr(model_name=PHONEME_ASR_MODEL_DEFAULT):
    """Load a phoneme ASR model on CPU. Returns (model, processor).

    Each model_name is loaded at most once and kept in the module cache, so
    base and large can be resident simultaneously. On HF Spaces, call
    move_phoneme_asr_to_gpu() from inside a GPU-decorated function to move a
    cached model to CUDA.

    Thread-safe: model_device_lock with double-checked locking.
    """
    cached = _cache.get(model_name)
    if cached is not None:
        return cached["model"], cached["processor"]
    with model_device_lock:
        # Double-check: another thread may have finished loading while we
        # were waiting on the lock.
        cached = _cache.get(model_name)
        if cached is not None:
            return cached["model"], cached["processor"]
        import logging
        from transformers import AutoModelForCTC, AutoProcessor
        # Suppress verbose transformers logging during load
        logging.getLogger("transformers").setLevel(logging.WARNING)
        model_path = PHONEME_ASR_MODELS[model_name]
        print(f"Loading phoneme ASR: {model_path} ({model_name})")
        # Use HF_TOKEN for private model access
        hf_token = _get_hf_token()
        device, dtype = _get_device_and_dtype()
        model = AutoModelForCTC.from_pretrained(
            model_path, token=hf_token, attn_implementation="sdpa"
        )
        model.to(device, dtype=dtype)
        model.eval()
        # torch.compile is skipped on Spaces/ZeroGPU where the model may
        # later hop between devices.
        if TORCH_COMPILE and not (IS_HF_SPACE or ZERO_GPU_AVAILABLE):
            model = torch.compile(model, mode="reduce-overhead")
        processor = AutoProcessor.from_pretrained(model_path, token=hf_token)
        _cache[model_name] = {
            "model": model,
            "processor": processor,
            "device": device.type,
        }
        print(f"Phoneme ASR ({model_name}) loaded on {device}")
        return model, processor
def move_phoneme_asr_to_gpu(model_name=None):
    """Move cached phoneme ASR model(s) to GPU.

    Args:
        model_name: Move only this model. If None, move all cached models.

    Call this inside @gpu_decorator functions on HF Spaces. Idempotent —
    a model already on CUDA is left alone. No-op when the user forced CPU
    mode or CUDA is unavailable.
    """
    if is_user_forced_cpu() or not torch.cuda.is_available():
        return
    targets = [model_name] if model_name else list(_cache.keys())
    cuda_dev = torch.device("cuda")
    with model_device_lock:
        for name in targets:
            entry = _cache.get(name)
            if entry is None:
                continue
            model = entry["model"]
            # Check current placement first so repeated calls are cheap.
            if next(model.parameters()).device.type == "cuda":
                continue
            try:
                entry["model"] = model.to(cuda_dev, dtype=_TORCH_DTYPE)
                entry["device"] = "cuda"
                print(f"[PHONEME ASR] Moved '{name}' to CUDA")
            except RuntimeError as e:
                print(f"[PHONEME ASR] CUDA move failed for '{name}', staying on CPU: {e}")
def invalidate_asr_cache(model_name=None):
    """Drop cached ASR model(s) so the next load_phoneme_asr() creates fresh ones.

    Args:
        model_name: Invalidate only this model. If None, invalidate all.

    Called from _drain_stale_models() inside a GPU lease. No CUDA ops —
    just removes references and lets GC reclaim tensors.
    """
    if not model_name:
        if _cache:
            dropped = list(_cache.keys())
            _cache.clear()
            print(f"[PHONEME ASR] Cache invalidated: {dropped}")
        return
    if _cache.pop(model_name, None) is not None:
        print(f"[PHONEME ASR] Cache invalidated for '{model_name}'")
def ids_to_phoneme_list(ids: List[int], tokenizer, pad_id: int) -> List[str]:
    """Convert raw CTC token IDs into a collapsed phoneme list.

    Standard CTC greedy decode:
      1. Drop pad/blank tokens.
      2. Collapse runs of consecutive identical tokens.
      3. Drop the word delimiter "|".

    Pad and "|" still reset the previous-token tracker, so a phoneme that
    repeats across a blank/delimiter boundary is kept, not collapsed.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids)
    if not tokens:
        return []
    # Resolve the pad token's string form; fall back to the conventional
    # "[PAD]" literal when no pad id is configured.
    pad_tok = "[PAD]" if pad_id is None else tokenizer.convert_ids_to_tokens([pad_id])[0]
    phonemes: List[str] = []
    last = None
    for tok in tokens:
        if tok == pad_tok or tok == "|":
            # Blank / word delimiter: emit nothing, but break any repeat run.
            last = tok
            continue
        if tok != last:
            phonemes.append(tok)
            last = tok
    return phonemes
def build_batches_naive(sorted_indices: List[int], batch_size: int) -> List[List[int]]:
    """Split indices into fixed-count batches (original batching behavior)."""
    batches = []
    for start in range(0, len(sorted_indices), batch_size):
        batches.append(sorted_indices[start:start + batch_size])
    return batches
def build_batches(sorted_indices: List[int], durations: List[float],
                  max_batch_seconds: float = MAX_BATCH_SECONDS) -> List[List[int]]:
    """Group duration-sorted indices into dynamic batches.

    Constraints per batch:
    - total audio seconds <= max_batch_seconds
    - pad-waste fraction (1 - sum/[n*max], the share of tensor compute spent
      on padding) <= MAX_PAD_WASTE
    - a batch is never cut below MIN_BATCH_SIZE items (avoids underutilization)

    Assumes sorted_indices is ordered by ascending duration, so each new
    candidate is the longest element of its batch.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_seconds = 0.0
    for idx in sorted_indices:
        dur = durations[idx]
        if not current:
            current = [idx]
            current_seconds = dur
            continue
        # Stats as if idx were appended; dur is the new longest duration
        # because the indices arrive sorted ascending.
        cand_seconds = current_seconds + dur
        cand_size = len(current) + 1
        waste = 1.0 - cand_seconds / (cand_size * dur) if dur > 0 else 0.0
        over_budget = cand_seconds > max_batch_seconds or waste > MAX_PAD_WASTE
        if over_budget and len(current) >= MIN_BATCH_SIZE:
            batches.append(current)
            current = [idx]
            current_seconds = dur
        else:
            current.append(idx)
            current_seconds = cand_seconds
    if current:
        batches.append(current)
    return batches
def _transcribe_batch_pytorch(
    segment_audios: List[np.ndarray],
    durations: List[float],
    batches: List[List[int]],
    model,
    processor,
    tokenizer,
    pad_id: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple:
    """PyTorch inference path (GPU or CPU fallback).

    Args:
        segment_audios: All audio arrays; `batches` holds indices into it.
        durations: Per-segment durations in seconds, parallel to segment_audios.
        batches: Lists of segment indices, one inner list per inference batch.
        model: CTC model (assumed already on `device` in `dtype` — confirm at caller).
        processor: Feature extractor used to pad/normalize each batch.
        tokenizer: Passed to ids_to_phoneme_list for id -> token conversion.
        pad_id: CTC blank/pad token id used during decoding.
        device: Target device for the input tensors.
        dtype: Target dtype for the input tensors.

    Returns:
        (results, batch_profiling): results is a List[List[str]] aligned with
        segment_audios; batch_profiling is one dict per batch with timing,
        duration, pad-waste and QK^T-size stats.
    """
    results: List[List[str]] = [[] for _ in segment_audios]
    batch_profiling = []
    # Read once for the QK^T-bytes estimate logged per batch. Both wav2vec2
    # base and xls-r downsample audio by 320× (16 kHz → 50 fps), so seq_len in
    # frames is `padded_audio_seconds * 50`. The QK^T tensor materialised by
    # the CPU SDPA `math` backend (fp16 path on PyTorch 2.8) is sized
    # `(batch, heads, seq, seq) * dtype_bytes` — when this exceeds the host's
    # L3 the attention layers go DRAM-bound and per-layer cost spikes 10–20×.
    num_heads = getattr(model.config, "num_attention_heads", 0)
    # element_size() of an empty tensor gives bytes-per-element for `dtype`.
    dtype_bytes = torch.tensor([], dtype=dtype).element_size()
    FRAMES_PER_SEC = 50
    for batch_num_idx, batch_idx in enumerate(batches):
        batch_audios = [segment_audios[i] for i in batch_idx]
        batch_durations = [durations[i] for i in batch_idx]
        batch_num = batch_num_idx + 1  # 1-based batch number for profiling
        t0 = time.time()
        # Feature extraction + GPU transfer
        t_feat_start = time.time()
        inputs = processor(
            batch_audios,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(device=device, dtype=dtype)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device=device)
        feat_time = time.time() - t_feat_start
        # Model inference
        t_infer_start = time.time()
        with torch.no_grad():
            outputs = model(input_values, attention_mask=attention_mask)
            logits = outputs.logits
        if device.type == "cuda":
            try:
                # Synchronize so infer_time measures finished GPU work,
                # not just asynchronous kernel launches.
                torch.cuda.synchronize()
            except RuntimeError:
                pass  # GPU may have been deallocated at lease boundary
        infer_time = time.time() - t_infer_start
        # CTC greedy decode
        t_decode_start = time.time()
        predicted_ids = torch.argmax(logits, dim=-1)
        for j in range(predicted_ids.shape[0]):
            ids_list = predicted_ids[j].cpu().tolist()
            phoneme_list = ids_to_phoneme_list(ids_list, tokenizer, pad_id)
            # Map batch-local row j back to the original segment index.
            results[batch_idx[j]] = phoneme_list
        decode_time = time.time() - t_decode_start
        # Drop tensor references promptly so memory is reclaimable before
        # the next (possibly larger) batch.
        del input_values, attention_mask, outputs, logits, predicted_ids
        batch_time = time.time() - t0
        max_dur = max(batch_durations)
        seq_len = int(round(max_dur * FRAMES_PER_SEC))
        qk_bytes_per_head = len(batch_audios) * seq_len * seq_len * dtype_bytes
        qk_bytes_all_heads = qk_bytes_per_head * num_heads
        total_seconds = sum(batch_durations)
        batch_profiling.append({
            "batch_num": batch_num,
            "size": len(batch_audios),
            "time": round(batch_time, 3),
            "feat_time": round(feat_time, 3),
            "infer_time": round(infer_time, 3),
            "decode_time": round(decode_time, 3),
            "min_dur": round(min(batch_durations), 3),
            "max_dur": round(max_dur, 3),
            "total_seconds": round(total_seconds, 3),
            "pad_waste": round(1.0 - total_seconds / (len(batch_durations) * max_dur), 4) if max_dur > 0 else 0.0,
            "seq_len": seq_len,
            "qk_mb_per_head": round(qk_bytes_per_head / (1024 * 1024), 2),
            "qk_mb_all_heads": round(qk_bytes_all_heads / (1024 * 1024), 2),
        })
    return results, batch_profiling
def transcribe_batch(segment_audios: List[np.ndarray], sample_rate: int, model_name: str = PHONEME_ASR_MODEL_DEFAULT) -> tuple:
    """Transcribe audio segments to phoneme lists, sorted by duration for efficiency.

    Args:
        segment_audios: List of audio arrays (assumed already 16 kHz — resampled at source)
        sample_rate: Audio sample rate (currently unused: durations are computed
            with a hard-coded 16 kHz below)
        model_name: Which ASR model to use ("base" or "large")

    Returns:
        (results, batch_profiling, sorting_time, batch_build_time) where results
        is List[List[str]] (one phoneme list per input segment), batch_profiling
        is a list of dicts with per-batch timing and duration stats, and the
        last two entries are the seconds spent sorting indices and building
        batches.
    """
    if not segment_audios:
        return [], [], 0.0, 0.0
    model, processor = load_phoneme_asr(model_name)
    if model is None:
        return [[] for _ in segment_audios], [], 0.0, 0.0
    # Determine inference device.
    if is_user_forced_cpu():
        device = torch.device("cpu")
    else:
        device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    tokenizer = processor.tokenizer
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    # Compute durations (audio assumed to be 16kHz — resampled at source)
    durations = [len(audio) / 16000.0 for audio in segment_audios]
    # Sort indices by duration, then build dynamic batches
    t_sort = time.time()
    sorted_indices = sorted(range(len(segment_audios)), key=lambda i: durations[i])
    sorting_time = time.time() - t_sort
    t_batch_build = time.time()
    if BATCHING_STRATEGY == "dynamic":
        # CPU gets its own per-batch seconds cap (MAX_BATCH_SECONDS_CPU).
        cap = MAX_BATCH_SECONDS_CPU if device.type == "cpu" else MAX_BATCH_SECONDS
        batches = build_batches(sorted_indices, durations, max_batch_seconds=cap)
    else:
        batches = build_batches_naive(sorted_indices, INFERENCE_BATCH_SIZE)
    batch_build_time = time.time() - t_batch_build
    backend = "PyTorch" + (f" ({device.type})" if device.type != "cpu" else " (CPU)")
    print(f"[PHONEME ASR] Using {backend}")
    results, batch_profiling = _transcribe_batch_pytorch(
        segment_audios, durations, batches,
        model, processor, tokenizer, pad_id, device, dtype,
    )
    # `sizes` is non-empty here: segment_audios is non-empty, so at least
    # one batch was built and profiled.
    sizes = [p["size"] for p in batch_profiling]
    print(f"[PHONEME ASR] {len(segment_audios)} segments in {len(batch_profiling)} batches "
          f"(sizes: {min(sizes)}-{max(sizes)}, sort: {sorting_time:.3f}s, batch build: {batch_build_time:.3f}s)")
    return results, batch_profiling, sorting_time, batch_build_time