""" GPU-decorated inference functions for ZeroGPU optimization. These functions wrap model inference calls with GPU leases and logging for efficient GPU utilization on HuggingFace Spaces with ZeroGPU. Each GPU function delegates to an _impl function that contains the actual logic. CPU fallback versions call the _impl directly (without GPU lease). """ import sys import time from pathlib import Path # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) from config import GPU_DURATION_INITIAL, GPU_DURATION_WAV2VEC_FA from recitation_engine.segment_processor import detect_vad_segments, transcribe_segments_batched from recitation_engine.gpu_profiler import record_gpu_lease from shared_state import get_processor, get_model # Import spaces.GPU decorator directly for ZeroGPU static scanner detection # The scanner looks for the literal `@spaces.GPU` decorator name via AST parsing try: import spaces except ImportError: # Fallback no-op decorator for local development without spaces package class spaces: @staticmethod def GPU(*args, **kwargs): def decorator(func): return func if args and callable(args[0]): return args[0] return decorator def _ensure_models_on_gpu(): """Move all models to GPU at the start of a GPU lease. On HF Spaces with ZeroGPU, models are loaded on CPU to avoid CUDA init in the main process. This function moves them to GPU when we have an active GPU lease (inside a @spaces.GPU function). """ from recitation_engine.model_loader import move_models_to_gpu from recitation_engine.segment_processor import move_segment_models_to_gpu move_models_to_gpu() move_segment_models_to_gpu() # ============================================================================= # Quota exceeded detection # ============================================================================= def is_quota_exceeded(exc: Exception) -> bool: """Check if exception is a ZeroGPU quota exceeded error.""" msg = str(exc).lower() return any(kw in msg for kw in ("quota", "exceeded", "gpu quota", "gpu task aborted")) # ============================================================================= # Implementation functions (device-agnostic) # ============================================================================= def _run_transcription_impl(audio_data, processor_arg=None, model_arg=None): """Core transcription logic — runs on whatever device models are on.""" from recitation_engine.transcription import transcribe_audio processor = get_processor() model = get_model() start = time.time() result = transcribe_audio(audio_data, processor, model) elapsed = time.time() - start lease_duration = GPU_DURATION_INITIAL utilization = (elapsed / lease_duration) * 100 if lease_duration > 0 else 0 print(f"[GPU STATS] TRANSCRIPTION ──────────────────────────────────────") print(f" Runtime: {elapsed:.2f}s | Lease: {lease_duration}s | Utilization: {utilization:.1f}%") print(f"[GPU STATS] TRANSCRIPTION ──────────────────────────────────────") record_gpu_lease("Transcription", lease_duration, elapsed) return result def _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence): """Core forced alignment logic — runs on whatever device models are on.""" import torch from recitation_analysis.duration_analysis.fa_backend import extract_phoneme_timestamps processor = get_processor() model = get_model() device = "cuda" if torch.cuda.is_available() else "cpu" start = time.time() fa_result = extract_phoneme_timestamps( audio_array=audio_array, sample_rate=sample_rate, 
def _ensure_models_on_gpu():
    """Move all models to GPU at the start of a GPU lease.

    On HF Spaces with ZeroGPU, models are loaded on CPU to avoid CUDA init
    in the main process. This function moves them to GPU when we have an
    active GPU lease (inside a @spaces.GPU function).
    """
    from recitation_engine.model_loader import move_models_to_gpu
    from recitation_engine.segment_processor import move_segment_models_to_gpu

    move_models_to_gpu()
    move_segment_models_to_gpu()


# =============================================================================
# Quota exceeded detection
# =============================================================================

def is_quota_exceeded(exc: Exception) -> bool:
    """Check if exception is a ZeroGPU quota exceeded error."""
    msg = str(exc).lower()
    return any(kw in msg for kw in ("quota", "exceeded", "gpu quota", "gpu task aborted"))
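
# A minimal usage sketch for is_quota_exceeded(): try the GPU-decorated entry
# point first, then fall back to the CPU implementation when the ZeroGPU
# quota is exhausted. `transcribe_with_fallback` is a hypothetical helper for
# illustration, not part of this module's public API.
def transcribe_with_fallback(audio_data):
    try:
        return run_transcription_gpu(audio_data)
    except Exception as exc:  # ZeroGPU quota errors surface as generic exceptions
        if is_quota_exceeded(exc):
            print("[GPU] Quota exceeded; falling back to CPU")
            return run_transcription_cpu(audio_data)
        raise
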
# =============================================================================
# Implementation functions (device-agnostic)
# =============================================================================

def _run_transcription_impl(audio_data, processor_arg=None, model_arg=None):
    """Core transcription logic — runs on whatever device models are on."""
    from recitation_engine.transcription import transcribe_audio

    processor = get_processor()
    model = get_model()

    start = time.time()
    result = transcribe_audio(audio_data, processor, model)
    elapsed = time.time() - start

    lease_duration = GPU_DURATION_INITIAL
    utilization = (elapsed / lease_duration) * 100 if lease_duration > 0 else 0
    print(f"[GPU STATS] TRANSCRIPTION ──────────────────────────────────────")
    print(f"  Runtime: {elapsed:.2f}s | Lease: {lease_duration}s | Utilization: {utilization:.1f}%")
    print(f"[GPU STATS] TRANSCRIPTION ──────────────────────────────────────")
    record_gpu_lease("Transcription", lease_duration, elapsed)

    return result


def _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence):
    """Core forced alignment logic — runs on whatever device models are on."""
    import torch
    from recitation_analysis.duration_analysis.fa_backend import extract_phoneme_timestamps

    processor = get_processor()
    model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    start = time.time()
    fa_result = extract_phoneme_timestamps(
        audio_array=audio_array,
        sample_rate=sample_rate,
        phoneme_sequence=phoneme_sequence,
        model=model,
        processor=processor,
        device=device,
        return_visualization_data=True,
    )
    elapsed = time.time() - start

    lease_duration = GPU_DURATION_WAV2VEC_FA
    utilization = (elapsed / lease_duration) * 100 if lease_duration > 0 else 0
    print(f"[GPU STATS] FA EXTRACTION ──────────────────────────────")
    print(f"  Runtime: {elapsed:.2f}s | Lease: {lease_duration}s | Utilization: {utilization:.1f}%")
    print(f"[GPU STATS] FA EXTRACTION ──────────────────────────────")
    record_gpu_lease("FA Extraction", lease_duration, elapsed)

    return fa_result


def _run_initial_impl(audio_data, canonical_text, verse_ref):
    """Core smart router logic — runs on whatever device models are on."""
    import torch
    from recitation_engine.transcription import transcribe_audio
    from recitation_analysis.duration_analysis.fa_backend import extract_phoneme_timestamps

    start = time.time()

    # Step 1: VAD
    vad_start = time.time()
    vad_result = detect_vad_segments(audio_data, canonical_text, verse_ref)
    vad_elapsed = time.time() - vad_start

    if not vad_result or not vad_result.vad_segments:
        elapsed = time.time() - start
        print(f"[GPU STATS] INITIAL (no speech) ──────────────────────")
        print(f"  VAD: {vad_elapsed:.2f}s | Total: {elapsed:.2f}s")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "no speech")
        return {'error': "No speech detected", 'vad_result': None, 'num_segments': 0}

    num_segments = len(vad_result.vad_segments)

    # Extract segment audios (needed for both paths)
    segment_audios = []
    for vad_seg in vad_result.vad_segments:
        start_sample = int(vad_seg.start_time * vad_result.sample_rate)
        end_sample = int(vad_seg.end_time * vad_result.sample_rate)
        segment_audios.append(vad_result.audio[start_sample:end_sample])

    result = {
        'vad_result': vad_result,
        'num_segments': num_segments,
        'segment_audios': segment_audios,
        'error': None,
    }

    if num_segments == 1:
        # === SINGLE SEGMENT PATH ===
        trimmed_audio = segment_audios[0]

        if len(trimmed_audio) < 1600:  # < 0.1s at 16kHz
            elapsed = time.time() - start
            print(f"[GPU STATS] INITIAL (speech too short) ───────────────")
            print(f"  VAD: {vad_elapsed:.2f}s | Total: {elapsed:.2f}s")
            print(f"[GPU STATS] INITIAL ──────────────────────────────────")
            record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "too short")
            result['error'] = "Speech too short"
            return result

        processor = get_processor()
        model = get_model()

        # Transcription
        trans_start = time.time()
        transcription, trans_error = transcribe_audio(
            (vad_result.sample_rate, trimmed_audio), processor, model
        )
        trans_elapsed = time.time() - trans_start

        if trans_error or not transcription:
            elapsed = time.time() - start
            print(f"[GPU STATS] INITIAL (transcription error) ────────────")
            print(f"  VAD: {vad_elapsed:.2f}s | Trans: {trans_elapsed:.2f}s | Total: {elapsed:.2f}s")
            print(f"[GPU STATS] INITIAL ──────────────────────────────────")
            record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "trans error")
            result['error'] = trans_error or "No phonemes detected"
            result['transcription'] = transcription
            return result

        # FA extraction
        fa_start = time.time()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        fa_result = extract_phoneme_timestamps(
            audio_array=trimmed_audio,
            sample_rate=vad_result.sample_rate,
            phoneme_sequence=tuple(transcription.split()),
            model=model,
            processor=processor,
            device=device,
            return_visualization_data=True,
        )
        fa_elapsed = time.time() - fa_start

        result['transcription'] = transcription
        result['fa_result'] = fa_result
        result['trimmed_audio'] = (vad_result.sample_rate, trimmed_audio)

        elapsed = time.time() - start
        utilization = (elapsed / GPU_DURATION_INITIAL) * 100
        print(f"[GPU STATS] INITIAL (single segment) ─────────────────")
        print(f"  VAD: {vad_elapsed:.2f}s | Trans: {trans_elapsed:.2f}s | FA: {fa_elapsed:.2f}s")
        print(f"  Total: {elapsed:.2f}s | Lease: {GPU_DURATION_INITIAL}s | Utilization: {utilization:.1f}%")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "1 seg: VAD+Trans+FA")
    else:
        # === MULTI SEGMENT PATH ===
        # Run Whisper on all segments
        whisper_start = time.time()
        whisper_texts = transcribe_segments_batched(segment_audios, vad_result.sample_rate)
        whisper_elapsed = time.time() - whisper_start

        result['whisper_texts'] = whisper_texts

        elapsed = time.time() - start
        utilization = (elapsed / GPU_DURATION_INITIAL) * 100
        print(f"[GPU STATS] INITIAL (multi segment) ──────────────────")
        print(f"  VAD: {vad_elapsed:.2f}s | Whisper ({num_segments} segs): {whisper_elapsed:.2f}s")
        print(f"  Total: {elapsed:.2f}s | Lease: {GPU_DURATION_INITIAL}s | Utilization: {utilization:.1f}%")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, f"{num_segments} segs: VAD+Whisper")

    return result
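
# A minimal sketch of how a caller might chain the two leases on the
# multi-segment path. `analyze_recitation` is hypothetical; the real
# orchestration lives in the app code, not in this module.
def analyze_recitation(audio_data, canonical_text, verse_ref):
    # Lease 1: VAD + routing (single-segment input also gets Trans + FA here)
    initial = run_initial_gpu(audio_data, canonical_text, verse_ref)
    if initial.get('error') or initial['num_segments'] == 1:
        return initial  # error, or transcription/fa_result already populated
    # Lease 2: Wav2Vec2 + FA over the per-segment audios from Lease 1
    wav2vec_results, fa_results = run_wav2vec_and_fa_gpu(
        initial['segment_audios'], initial['vad_result'].sample_rate
    )
    initial['wav2vec_results'] = wav2vec_results
    initial['fa_results'] = fa_results
    return initial
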
def _run_wav2vec_and_fa_impl(segment_audios, sample_rate):
    """Core Wav2Vec2 + FA logic — runs on whatever device models are on."""
    import torch
    from recitation_engine.transcription import transcribe_audio_batched
    from recitation_analysis.duration_analysis.fa_backend import forced_align_from_logits

    start = time.time()
    num_segments = len(segment_audios)

    processor = get_processor()
    model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Step 1: Wav2Vec2 batch transcription with logits for FA reuse
    wav2vec_start = time.time()
    wav2vec_results, logits_list, audio_durations = transcribe_audio_batched(
        segment_audios, sample_rate, processor, model, return_logits=True
    )
    wav2vec_elapsed = time.time() - wav2vec_start

    # Step 2: FA using pre-computed logits (no model inference)
    fa_start = time.time()
    fa_results = []
    for i, (phonemes, logits, audio_dur) in enumerate(zip(wav2vec_results, logits_list, audio_durations)):
        if not phonemes or logits is None:
            fa_results.append(None)
            continue
        try:
            phoneme_seq = tuple(phonemes.split()) if isinstance(phonemes, str) else tuple(phonemes)
            fa_result = forced_align_from_logits(
                logits=logits,
                phoneme_sequence=phoneme_seq,
                processor=processor,
                audio_duration=audio_dur,
                device=device,
                return_visualization_data=True,
            )
            fa_results.append(fa_result)
        except Exception as e:
            print(f"[FA] Segment {i+1} failed: {e}")
            fa_results.append(None)
    fa_elapsed = time.time() - fa_start

    elapsed = time.time() - start
    utilization = (elapsed / GPU_DURATION_WAV2VEC_FA) * 100
    print(f"[GPU STATS] WAV2VEC+FA ({num_segments} segments) ─────────────")
    print(f"  Wav2Vec2: {wav2vec_elapsed:.2f}s | FA (logits reuse): {fa_elapsed:.2f}s")
    print(f"  Total: {elapsed:.2f}s | Lease: {GPU_DURATION_WAV2VEC_FA}s | Utilization: {utilization:.1f}%")
    print(f"[GPU STATS] WAV2VEC+FA ───────────────────────────────────")
    record_gpu_lease("Wav2Vec+FA", GPU_DURATION_WAV2VEC_FA, elapsed, f"{num_segments} segs")

    return wav2vec_results, fa_results
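
# Note on the contract assumed by the logits-reuse step above: with
# return_logits=True, transcribe_audio_batched() is expected to return three
# parallel lists (decoded phoneme strings, per-segment CTC logits, per-segment
# audio durations in seconds), so forced_align_from_logits() can align against
# the already-computed logits instead of running the acoustic model a second
# time. Roughly speaking, this halves the model-inference cost of the
# multi-segment path.
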
def _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref):
    """Core multi-segment pipeline logic — runs on whatever device models are on."""
    import torch
    from recitation_engine.transcription import transcribe_audio_batched
    from recitation_analysis.duration_analysis.fa_backend import forced_align_from_logits
    from shared_state import (
        set_processor as state_set_processor,
        set_model as state_set_model,
        get_model_bundles as state_get_model_bundles,
    )

    # Get model bundles from shared state (not passed as arg to avoid serialization)
    model_bundles = state_get_model_bundles()

    start = time.time()

    # Step 1: VAD
    vad_start = time.time()
    vad_result = detect_vad_segments(audio_data, canonical_text, verse_ref)
    vad_elapsed = time.time() - vad_start

    if not vad_result or not vad_result.vad_segments:
        elapsed = time.time() - start
        record_gpu_lease("MultiSegPipeline", 120, elapsed, "no speech")
        return {'error': "No speech detected", 'vad_result': None, 'num_segments': 0}

    num_segments = len(vad_result.vad_segments)

    # Extract segment audios
    segment_audios = []
    for vad_seg in vad_result.vad_segments:
        start_sample = int(vad_seg.start_time * vad_result.sample_rate)
        end_sample = int(vad_seg.end_time * vad_result.sample_rate)
        segment_audios.append(vad_result.audio[start_sample:end_sample])

    # Step 2: Whisper batch transcription (model-independent)
    whisper_start = time.time()
    whisper_texts = transcribe_segments_batched(segment_audios, vad_result.sample_rate)
    whisper_elapsed = time.time() - whisper_start

    # Step 3: For each model, run Wav2Vec2 + FA
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_results = {}

    # If no model bundles configured, use current model from shared state
    if not model_bundles:
        model_bundles = [{"processor": get_processor(), "model": get_model(), "path": "current"}]

    for idx, bundle in enumerate(model_bundles):
        processor = bundle["processor"]
        model = bundle["model"]

        # Set shared state for this model (needed by some downstream code)
        state_set_processor(processor)
        state_set_model(model)

        # Move this model to GPU if needed
        if model is not None:
            current_device = next(model.parameters()).device
            if current_device.type != "cuda" and torch.cuda.is_available():
                model = model.to(device)
                bundle["model"] = model

        wav2vec_start = time.time()
        wav2vec_results, logits_list, audio_durations = transcribe_audio_batched(
            segment_audios, vad_result.sample_rate, processor, model, return_logits=True
        )
        wav2vec_elapsed = time.time() - wav2vec_start

        # FA using pre-computed logits
        fa_start = time.time()
        fa_results = []
        for i, (phonemes, logits, audio_dur) in enumerate(zip(wav2vec_results, logits_list, audio_durations)):
            if not phonemes or logits is None:
                fa_results.append(None)
                continue
            try:
                phoneme_seq = tuple(phonemes.split()) if isinstance(phonemes, str) else tuple(phonemes)
                fa_result = forced_align_from_logits(
                    logits=logits,
                    phoneme_sequence=phoneme_seq,
                    processor=processor,
                    audio_duration=audio_dur,
                    device=device,
                    return_visualization_data=True,
                )
                fa_results.append(fa_result)
            except Exception as e:
                print(f"[FA] Model {idx+1} Segment {i+1} failed: {e}")
                fa_results.append(None)
        fa_elapsed = time.time() - fa_start

        model_results[idx] = {
            'wav2vec_results': wav2vec_results,
            'fa_results': fa_results,
            'wav2vec_elapsed': wav2vec_elapsed,
            'fa_elapsed': fa_elapsed,
        }
        print(f"[MODEL {idx+1}] Wav2Vec2: {wav2vec_elapsed:.2f}s | FA: {fa_elapsed:.2f}s")

    elapsed = time.time() - start
    utilization = (elapsed / 120) * 100
    num_models = len(model_bundles)
    print(f"[GPU STATS] MULTI-SEG PIPELINE ({num_segments} segs, {num_models} models) ─────")
    print(f"  VAD: {vad_elapsed:.2f}s | Whisper: {whisper_elapsed:.2f}s")
    print(f"  Total: {elapsed:.2f}s | Lease: 120s | Utilization: {utilization:.1f}%")
    print(f"[GPU STATS] MULTI-SEG PIPELINE ───────────────────────────────")
    record_gpu_lease("MultiSegPipeline", 120, elapsed, f"{num_segments} segs, {num_models} models")

    return {
        'vad_result': vad_result,
        'num_segments': num_segments,
        'segment_audios': segment_audios,
        'whisper_texts': whisper_texts,
        'model_results': model_results,
        'error': None,
    }
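
# Hypothetical consumer of run_multi_segment_pipeline_*() output, shown for
# illustration only; the shapes follow the result dict built above.
def summarize_pipeline_result(result):
    if result.get('error'):
        return result['error']
    lines = []
    for idx, per_model in result['model_results'].items():
        aligned = sum(1 for fa in per_model['fa_results'] if fa is not None)
        lines.append(
            f"model {idx + 1}: {aligned}/{len(per_model['fa_results'])} segments aligned "
            f"(wav2vec {per_model['wav2vec_elapsed']:.2f}s, fa {per_model['fa_elapsed']:.2f}s)"
        )
    return "\n".join(lines)
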
# =============================================================================
# GPU-decorated functions (thin wrappers)
# =============================================================================

@spaces.GPU(duration=GPU_DURATION_INITIAL)
def run_transcription_gpu(audio_data, processor_arg=None, model_arg=None):
    """
    Run Wav2Vec2 transcription with GPU lease.

    Note: Prefer run_initial_gpu() for optimized lease usage.
    processor_arg/model_arg are ignored - we use shared state.

    Args:
        audio_data: Tuple of (sample_rate, audio_array) from Gradio
        processor_arg: Ignored (kept for API compatibility)
        model_arg: Ignored (kept for API compatibility)

    Returns:
        Tuple of (transcription, error)
    """
    _ensure_models_on_gpu()
    return _run_transcription_impl(audio_data, processor_arg, model_arg)


@spaces.GPU(duration=GPU_DURATION_WAV2VEC_FA)
def run_fa_extraction_gpu(audio_array, sample_rate, phoneme_sequence):
    """
    Run CTC forced alignment extraction with GPU lease.

    Note: Prefer run_initial_gpu() or run_wav2vec_and_fa_gpu() for optimized
    lease usage.

    Args:
        audio_array: Audio waveform as numpy array
        sample_rate: Audio sample rate in Hz
        phoneme_sequence: Tuple of phoneme strings to align

    Returns:
        Dict with FA results (segments, timestamps, visualization data)
    """
    _ensure_models_on_gpu()
    return _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence)


@spaces.GPU(duration=GPU_DURATION_INITIAL)
def run_initial_gpu(audio_data, canonical_text, verse_ref):
    """
    Smart router: VAD → branch based on segment count.

    Single segment: VAD → Trim → Transcription → FA (all in one lease)
    Multi segment: VAD → Extract audios → Whisper (returns for Lease 2)

    Args:
        audio_data: Tuple of (sample_rate, audio_array) from Gradio
        canonical_text: Expected Arabic text for the verse
        verse_ref: Verse reference (e.g., "1:2")

    Returns:
        dict with keys:
        - 'vad_result': VadResult
        - 'num_segments': int
        - 'segment_audios': list (always populated)
        For single segment (num_segments == 1):
        - 'transcription': str
        - 'fa_result': dict
        - 'trimmed_audio': tuple (sample_rate, array)
        For multi segment (num_segments > 1):
        - 'whisper_texts': list
        - 'error': str or None
    """
    _ensure_models_on_gpu()
    return _run_initial_impl(audio_data, canonical_text, verse_ref)


@spaces.GPU(duration=GPU_DURATION_WAV2VEC_FA)
def run_wav2vec_and_fa_gpu(segment_audios, sample_rate):
    """
    Combined Wav2Vec2 + FA for multi-segment path.

    Uses logits reuse optimization: model runs once for transcription,
    then logits are passed directly to FA (no duplicate model inference).

    Args:
        segment_audios: List of audio arrays per segment
        sample_rate: Audio sample rate

    Returns:
        tuple: (wav2vec_results, fa_results)
    """
    _ensure_models_on_gpu()
    return _run_wav2vec_and_fa_impl(segment_audios, sample_rate)


@spaces.GPU(duration=120)  # Longer lease for full multi-model multi-segment pipeline
def run_multi_segment_pipeline_gpu(audio_data, canonical_text, verse_ref):
    """
    Complete multi-segment pipeline in a single GPU lease.

    Runs: VAD → Whisper (batched) → For each model: Wav2Vec2 + FA (batched)
    This avoids ZeroGPU token expiration by keeping all GPU work in one lease.

    Note: Model bundles are accessed from shared state INSIDE this function
    to avoid serialization issues with PyTorch models on ZeroGPU.

    Args:
        audio_data: Tuple of (sample_rate, audio_array) from Gradio
        canonical_text: Expected Arabic text for the verse
        verse_ref: Verse reference (e.g., "1:2")

    Returns:
        dict with keys:
        - 'vad_result': VadResult
        - 'num_segments': int
        - 'segment_audios': list
        - 'whisper_texts': list
        - 'model_results': dict mapping model index to a dict with
          'wav2vec_results', 'fa_results', and per-stage timings
        - 'error': str or None
    """
    _ensure_models_on_gpu()
    return _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref)


# =============================================================================
# CPU fallback versions (no GPU lease, runs on CPU)
# =============================================================================

def run_transcription_cpu(audio_data, processor_arg=None, model_arg=None):
    """CPU fallback for run_transcription_gpu."""
    return _run_transcription_impl(audio_data, processor_arg, model_arg)


def run_fa_extraction_cpu(audio_array, sample_rate, phoneme_sequence):
    """CPU fallback for run_fa_extraction_gpu."""
    return _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence)


def run_initial_cpu(audio_data, canonical_text, verse_ref):
    """CPU fallback for run_initial_gpu."""
    return _run_initial_impl(audio_data, canonical_text, verse_ref)


def run_wav2vec_and_fa_cpu(segment_audios, sample_rate):
    """CPU fallback for run_wav2vec_and_fa_gpu."""
    return _run_wav2vec_and_fa_impl(segment_audios, sample_rate)


def run_multi_segment_pipeline_cpu(audio_data, canonical_text, verse_ref):
    """CPU fallback for run_multi_segment_pipeline_gpu."""
    return _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref)
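
if __name__ == "__main__":
    # Local smoke test (a sketch, not part of the app): synthesize one second
    # of silence and run the CPU path end to end. Assumes models were already
    # loaded into shared_state by the app's startup code; with silence as
    # input, the router should exit early on the "No speech detected" branch.
    import numpy as np

    sr = 16000
    silence = np.zeros(sr, dtype=np.float32)  # 1s of silence at 16 kHz
    out = run_initial_cpu((sr, silence), canonical_text="", verse_ref="1:1")
    print(f"num_segments={out.get('num_segments')} error={out.get('error')!r}")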