# Tajweed-AI — recitation_engine/gpu_inference.py
# (Hugging Face Spaces page header removed; commit df88bcd by hetchyy:
#  "Add GPU quota fallback to CPU with user notification")
"""
GPU-decorated inference functions for ZeroGPU optimization.
These functions wrap model inference calls with GPU leases and logging
for efficient GPU utilization on HuggingFace Spaces with ZeroGPU.
Each GPU function delegates to an _impl function that contains the actual
logic. CPU fallback versions call the _impl directly (without GPU lease).
"""
import sys
import time
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from config import GPU_DURATION_INITIAL, GPU_DURATION_WAV2VEC_FA
from recitation_engine.segment_processor import detect_vad_segments, transcribe_segments_batched
from recitation_engine.gpu_profiler import record_gpu_lease
from shared_state import get_processor, get_model
# Import spaces.GPU decorator directly for ZeroGPU static scanner detection
# The scanner looks for the literal `@spaces.GPU` decorator name via AST parsing
try:
    import spaces
except ImportError:
    # Local-development stand-in: the `spaces` package only exists on HF
    # Spaces hardware. The shim keeps `@spaces.GPU` usable as both a bare
    # decorator and a parameterized one (`@spaces.GPU(duration=...)`).
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: `@spaces.GPU` passes the function itself as args[0].
            if args and callable(args[0]):
                return args[0]
            # Parameterized form: return a no-op decorator.
            return lambda func: func
def _ensure_models_on_gpu():
    """Promote all loaded models to the GPU at the start of a GPU lease.

    On HF Spaces with ZeroGPU the models are loaded on CPU (CUDA must not
    be initialised in the main process), so every @spaces.GPU entry point
    calls this first to move them onto the freshly leased device.
    """
    from recitation_engine.model_loader import move_models_to_gpu as _move_main_models
    from recitation_engine.segment_processor import move_segment_models_to_gpu as _move_segment_models

    _move_main_models()
    _move_segment_models()
# =============================================================================
# Quota exceeded detection
# =============================================================================
def is_quota_exceeded(exc: Exception) -> bool:
    """Return True if *exc* looks like a ZeroGPU quota/abort error.

    ZeroGPU surfaces quota exhaustion as ordinary exceptions whose message
    mentions the quota or an aborted GPU task, so we match substrings
    case-insensitively. The match is deliberately loose (a bare "exceeded"
    matches), because callers only use this to decide whether to fall back
    to the CPU path with a user notification.

    Args:
        exc: The exception raised by a @spaces.GPU call.

    Returns:
        True when the message suggests GPU quota exhaustion.
    """
    msg = str(exc).lower()
    # NOTE: the former "gpu quota" keyword was removed — it was dead code,
    # already subsumed by the bare "quota" substring.
    return any(kw in msg for kw in ("quota", "exceeded", "gpu task aborted"))
# =============================================================================
# Implementation functions (device-agnostic)
# =============================================================================
def _run_transcription_impl(audio_data, processor_arg=None, model_arg=None):
    """Core transcription logic — runs on whatever device the models are on.

    ``processor_arg``/``model_arg`` are accepted for API compatibility but
    ignored; processor and model always come from shared state.
    """
    from recitation_engine.transcription import transcribe_audio

    processor = get_processor()
    model = get_model()

    t0 = time.time()
    outcome = transcribe_audio(audio_data, processor, model)
    runtime = time.time() - t0

    # Report how much of the lease window was actually used.
    lease = GPU_DURATION_INITIAL
    pct = (runtime / lease) * 100 if lease > 0 else 0
    print("[GPU STATS] TRANSCRIPTION ──────────────────────────────────────")
    print(f" Runtime: {runtime:.2f}s | Lease: {lease}s | Utilization: {pct:.1f}%")
    print("[GPU STATS] TRANSCRIPTION ──────────────────────────────────────")
    record_gpu_lease("Transcription", lease, runtime)
    return outcome
def _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence):
    """Core forced-alignment logic — runs on whatever device the models are on."""
    import torch
    from recitation_analysis.duration_analysis.fa_backend import extract_phoneme_timestamps

    processor = get_processor()
    model = get_model()
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    t0 = time.time()
    alignment = extract_phoneme_timestamps(
        audio_array=audio_array,
        sample_rate=sample_rate,
        phoneme_sequence=phoneme_sequence,
        model=model,
        processor=processor,
        device=target_device,
        return_visualization_data=True,
    )
    runtime = time.time() - t0

    # Report how much of the lease window was actually used.
    lease = GPU_DURATION_WAV2VEC_FA
    pct = (runtime / lease) * 100 if lease > 0 else 0
    print("[GPU STATS] FA EXTRACTION ──────────────────────────────")
    print(f" Runtime: {runtime:.2f}s | Lease: {lease}s | Utilization: {pct:.1f}%")
    print("[GPU STATS] FA EXTRACTION ──────────────────────────────")
    record_gpu_lease("FA Extraction", lease, runtime)
    return alignment
def _run_initial_impl(audio_data, canonical_text, verse_ref):
    """Core smart router logic — runs on whatever device models are on.

    Pipeline: VAD first, then branch on the number of detected speech
    segments:
      * 1 segment  -> trim, Wav2Vec2 transcription, forced alignment
        (everything finishes inside this call)
      * >1 segment -> extract per-segment audio and run Whisper batched;
        Wav2Vec2 + FA are deferred to a second lease

    Args:
        audio_data: (sample_rate, audio_array) tuple as produced by Gradio.
        canonical_text: Expected Arabic text for the verse.
        verse_ref: Verse reference string, e.g. "1:2".

    Returns:
        dict with 'vad_result', 'num_segments', 'segment_audios', 'error';
        plus 'transcription'/'fa_result'/'trimmed_audio' on the
        single-segment path, or 'whisper_texts' on the multi-segment path.
        When no speech is detected, only the error dict is returned.
    """
    import torch
    from recitation_engine.transcription import transcribe_audio
    from recitation_analysis.duration_analysis.fa_backend import extract_phoneme_timestamps
    start = time.time()
    # Step 1: VAD
    vad_start = time.time()
    vad_result = detect_vad_segments(audio_data, canonical_text, verse_ref)
    vad_elapsed = time.time() - vad_start
    if not vad_result or not vad_result.vad_segments:
        # Nothing voiced at all — log the (mostly idle) lease and bail out.
        elapsed = time.time() - start
        print(f"[GPU STATS] INITIAL (no speech) ──────────────────────")
        print(f" VAD: {vad_elapsed:.2f}s | Total: {elapsed:.2f}s")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "no speech")
        return {'error': "No speech detected", 'vad_result': None, 'num_segments': 0}
    num_segments = len(vad_result.vad_segments)
    # Extract segment audios (needed for both paths)
    segment_audios = []
    for vad_seg in vad_result.vad_segments:
        start_sample = int(vad_seg.start_time * vad_result.sample_rate)
        end_sample = int(vad_seg.end_time * vad_result.sample_rate)
        segment_audios.append(vad_result.audio[start_sample:end_sample])
    result = {
        'vad_result': vad_result,
        'num_segments': num_segments,
        'segment_audios': segment_audios,
        'error': None,
    }
    if num_segments == 1:
        # === SINGLE SEGMENT PATH ===
        trimmed_audio = segment_audios[0]
        if len(trimmed_audio) < 1600:  # < 0.1s at 16kHz — assumes 16 kHz audio; TODO confirm upstream
            elapsed = time.time() - start
            print(f"[GPU STATS] INITIAL (speech too short) ───────────────")
            print(f" VAD: {vad_elapsed:.2f}s | Total: {elapsed:.2f}s")
            print(f"[GPU STATS] INITIAL ──────────────────────────────────")
            record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "too short")
            result['error'] = "Speech too short"
            return result
        processor = get_processor()
        model = get_model()
        # Transcription
        trans_start = time.time()
        transcription, trans_error = transcribe_audio(
            (vad_result.sample_rate, trimmed_audio), processor, model
        )
        trans_elapsed = time.time() - trans_start
        if trans_error or not transcription:
            # Transcription failed (or produced no phonemes) — surface the
            # error but still return VAD output so callers can report it.
            elapsed = time.time() - start
            print(f"[GPU STATS] INITIAL (transcription error) ────────────")
            print(f" VAD: {vad_elapsed:.2f}s | Trans: {trans_elapsed:.2f}s | Total: {elapsed:.2f}s")
            print(f"[GPU STATS] INITIAL ──────────────────────────────────")
            record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "trans error")
            result['error'] = trans_error or "No phonemes detected"
            result['transcription'] = transcription
            return result
        # FA extraction — phoneme sequence comes from the transcription,
        # split on whitespace (transcription is a space-separated string).
        fa_start = time.time()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        fa_result = extract_phoneme_timestamps(
            audio_array=trimmed_audio,
            sample_rate=vad_result.sample_rate,
            phoneme_sequence=tuple(transcription.split()),
            model=model,
            processor=processor,
            device=device,
            return_visualization_data=True,
        )
        fa_elapsed = time.time() - fa_start
        result['transcription'] = transcription
        result['fa_result'] = fa_result
        result['trimmed_audio'] = (vad_result.sample_rate, trimmed_audio)
        elapsed = time.time() - start
        utilization = (elapsed / GPU_DURATION_INITIAL) * 100
        print(f"[GPU STATS] INITIAL (single segment) ─────────────────")
        print(f" VAD: {vad_elapsed:.2f}s | Trans: {trans_elapsed:.2f}s | FA: {fa_elapsed:.2f}s")
        print(f" Total: {elapsed:.2f}s | Lease: {GPU_DURATION_INITIAL}s | Utilization: {utilization:.1f}%")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, "1 seg: VAD+Trans+FA")
    else:
        # === MULTI SEGMENT PATH ===
        # Run Whisper on all segments; Wav2Vec2 + FA happen in Lease 2.
        whisper_start = time.time()
        whisper_texts = transcribe_segments_batched(segment_audios, vad_result.sample_rate)
        whisper_elapsed = time.time() - whisper_start
        result['whisper_texts'] = whisper_texts
        elapsed = time.time() - start
        utilization = (elapsed / GPU_DURATION_INITIAL) * 100
        print(f"[GPU STATS] INITIAL (multi segment) ──────────────────")
        print(f" VAD: {vad_elapsed:.2f}s | Whisper ({num_segments} segs): {whisper_elapsed:.2f}s")
        print(f" Total: {elapsed:.2f}s | Lease: {GPU_DURATION_INITIAL}s | Utilization: {utilization:.1f}%")
        print(f"[GPU STATS] INITIAL ──────────────────────────────────")
        record_gpu_lease("Initial", GPU_DURATION_INITIAL, elapsed, f"{num_segments} segs: VAD+Whisper")
    return result
def _run_wav2vec_and_fa_impl(segment_audios, sample_rate):
    """Core Wav2Vec2 + forced-alignment logic — device-agnostic.

    The model runs once per segment to produce both the phoneme
    transcription and the logits; those logits are reused for forced
    alignment so no second inference pass is needed.
    """
    import torch
    from recitation_engine.transcription import transcribe_audio_batched
    from recitation_analysis.duration_analysis.fa_backend import forced_align_from_logits

    t0 = time.time()
    n_segs = len(segment_audios)
    processor = get_processor()
    model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Pass 1: batched Wav2Vec2 transcription, keeping logits for FA reuse.
    w2v_t0 = time.time()
    wav2vec_results, logits_list, audio_durations = transcribe_audio_batched(
        segment_audios, sample_rate, processor, model, return_logits=True
    )
    w2v_elapsed = time.time() - w2v_t0

    # Pass 2: forced alignment straight from the cached logits — no model
    # inference here, so failures in one segment never abort the rest.
    fa_t0 = time.time()
    fa_results = []
    for seg_idx, (phonemes, logits, duration) in enumerate(
        zip(wav2vec_results, logits_list, audio_durations)
    ):
        if not phonemes or logits is None:
            fa_results.append(None)
            continue
        try:
            seq = tuple(phonemes.split()) if isinstance(phonemes, str) else tuple(phonemes)
            fa_results.append(
                forced_align_from_logits(
                    logits=logits,
                    phoneme_sequence=seq,
                    processor=processor,
                    audio_duration=duration,
                    device=device,
                    return_visualization_data=True,
                )
            )
        except Exception as exc:
            print(f"[FA] Segment {seg_idx+1} failed: {exc}")
            fa_results.append(None)
    fa_elapsed = time.time() - fa_t0

    total = time.time() - t0
    pct = (total / GPU_DURATION_WAV2VEC_FA) * 100
    print(f"[GPU STATS] WAV2VEC+FA ({n_segs} segments) ─────────────")
    print(f" Wav2Vec2: {w2v_elapsed:.2f}s | FA (logits reuse): {fa_elapsed:.2f}s")
    print(f" Total: {total:.2f}s | Lease: {GPU_DURATION_WAV2VEC_FA}s | Utilization: {pct:.1f}%")
    print("[GPU STATS] WAV2VEC+FA ───────────────────────────────────")
    record_gpu_lease("Wav2Vec+FA", GPU_DURATION_WAV2VEC_FA, total, f"{n_segs} segs")
    return wav2vec_results, fa_results
def _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref):
    """Core multi-segment pipeline logic — runs on whatever device models are on.

    Full flow in one call: VAD -> Whisper (batched, model-independent) ->
    for each model bundle: Wav2Vec2 (batched, with logits) + forced
    alignment from those logits.

    Args:
        audio_data: (sample_rate, audio_array) tuple as produced by Gradio.
        canonical_text: Expected Arabic text for the verse.
        verse_ref: Verse reference string, e.g. "1:2".

    Returns:
        dict with 'vad_result', 'num_segments', 'segment_audios',
        'whisper_texts', 'model_results' (per-model-index dicts) and
        'error'; or a minimal error dict when no speech is detected.

    NOTE(review): the lease duration is hard-coded as 120 here and in the
    @spaces.GPU(duration=120) decorator on the GPU wrapper — keep both in
    sync (ideally hoist to a config constant).
    """
    import torch
    from recitation_engine.transcription import transcribe_audio_batched
    from recitation_analysis.duration_analysis.fa_backend import forced_align_from_logits
    from shared_state import (
        set_processor as state_set_processor,
        set_model as state_set_model,
        get_model_bundles as state_get_model_bundles,
    )
    # Get model bundles from shared state (not passed as arg to avoid serialization)
    model_bundles = state_get_model_bundles()
    start = time.time()
    # Step 1: VAD
    vad_start = time.time()
    vad_result = detect_vad_segments(audio_data, canonical_text, verse_ref)
    vad_elapsed = time.time() - vad_start
    if not vad_result or not vad_result.vad_segments:
        elapsed = time.time() - start
        record_gpu_lease("MultiSegPipeline", 120, elapsed, "no speech")
        return {'error': "No speech detected", 'vad_result': None, 'num_segments': 0}
    num_segments = len(vad_result.vad_segments)
    # Extract segment audios
    segment_audios = []
    for vad_seg in vad_result.vad_segments:
        start_sample = int(vad_seg.start_time * vad_result.sample_rate)
        end_sample = int(vad_seg.end_time * vad_result.sample_rate)
        segment_audios.append(vad_result.audio[start_sample:end_sample])
    # Step 2: Whisper batch transcription (model-independent)
    whisper_start = time.time()
    whisper_texts = transcribe_segments_batched(segment_audios, vad_result.sample_rate)
    whisper_elapsed = time.time() - whisper_start
    # Step 3: For each model, run Wav2Vec2 + FA
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_results = {}
    # If no model bundles configured, use current model from shared state
    if not model_bundles:
        model_bundles = [{"processor": get_processor(), "model": get_model(), "path": "current"}]
    for idx, bundle in enumerate(model_bundles):
        processor = bundle["processor"]
        model = bundle["model"]
        # Set shared state for this model (needed by some downstream code).
        # NOTE(review): after the loop the shared state points at the LAST
        # bundle's processor/model — confirm callers expect that.
        state_set_processor(processor)
        state_set_model(model)
        # Move this model to GPU if needed (models arrive on CPU under ZeroGPU)
        if model is not None:
            current_device = next(model.parameters()).device
            if current_device.type != "cuda" and torch.cuda.is_available():
                model = model.to(device)
                bundle["model"] = model  # cache the moved model back in the bundle
        wav2vec_start = time.time()
        wav2vec_results, logits_list, audio_durations = transcribe_audio_batched(
            segment_audios, vad_result.sample_rate, processor, model, return_logits=True
        )
        wav2vec_elapsed = time.time() - wav2vec_start
        # FA using pre-computed logits (no second model inference pass)
        fa_start = time.time()
        fa_results = []
        for i, (phonemes, logits, audio_dur) in enumerate(zip(wav2vec_results, logits_list, audio_durations)):
            if not phonemes or logits is None:
                # Empty transcription or missing logits — nothing to align.
                fa_results.append(None)
                continue
            try:
                phoneme_seq = tuple(phonemes.split()) if isinstance(phonemes, str) else tuple(phonemes)
                fa_result = forced_align_from_logits(
                    logits=logits,
                    phoneme_sequence=phoneme_seq,
                    processor=processor,
                    audio_duration=audio_dur,
                    device=device,
                    return_visualization_data=True,
                )
                fa_results.append(fa_result)
            except Exception as e:
                # Best-effort: a failed segment is logged and skipped so
                # the remaining segments/models still complete.
                print(f"[FA] Model {idx+1} Segment {i+1} failed: {e}")
                fa_results.append(None)
        fa_elapsed = time.time() - fa_start
        model_results[idx] = {
            'wav2vec_results': wav2vec_results,
            'fa_results': fa_results,
            'wav2vec_elapsed': wav2vec_elapsed,
            'fa_elapsed': fa_elapsed,
        }
        print(f"[MODEL {idx+1}] Wav2Vec2: {wav2vec_elapsed:.2f}s | FA: {fa_elapsed:.2f}s")
    elapsed = time.time() - start
    utilization = (elapsed / 120) * 100
    num_models = len(model_bundles)
    print(f"[GPU STATS] MULTI-SEG PIPELINE ({num_segments} segs, {num_models} models) ─────")
    print(f" VAD: {vad_elapsed:.2f}s | Whisper: {whisper_elapsed:.2f}s")
    print(f" Total: {elapsed:.2f}s | Lease: 120s | Utilization: {utilization:.1f}%")
    print(f"[GPU STATS] MULTI-SEG PIPELINE ───────────────────────────────")
    record_gpu_lease("MultiSegPipeline", 120, elapsed, f"{num_segments} segs, {num_models} models")
    return {
        'vad_result': vad_result,
        'num_segments': num_segments,
        'segment_audios': segment_audios,
        'whisper_texts': whisper_texts,
        'model_results': model_results,
        'error': None,
    }
# =============================================================================
# GPU-decorated functions (thin wrappers)
# =============================================================================
@spaces.GPU(duration=GPU_DURATION_INITIAL)
def run_transcription_gpu(audio_data, processor_arg=None, model_arg=None):
    """Wav2Vec2 transcription inside a ZeroGPU lease.

    Prefer run_initial_gpu(), which bundles VAD + transcription + FA into
    a single lease. processor_arg/model_arg are accepted only for API
    compatibility — the models always come from shared state.

    Args:
        audio_data: (sample_rate, audio_array) tuple from Gradio.
        processor_arg: Ignored.
        model_arg: Ignored.

    Returns:
        (transcription, error) tuple.
    """
    _ensure_models_on_gpu()
    outcome = _run_transcription_impl(audio_data, processor_arg, model_arg)
    return outcome
@spaces.GPU(duration=GPU_DURATION_WAV2VEC_FA)
def run_fa_extraction_gpu(audio_array, sample_rate, phoneme_sequence):
    """CTC forced-alignment extraction inside a ZeroGPU lease.

    Prefer run_initial_gpu() or run_wav2vec_and_fa_gpu(), which make
    better use of a single lease.

    Args:
        audio_array: Audio waveform as a numpy array.
        sample_rate: Audio sample rate in Hz.
        phoneme_sequence: Tuple of phoneme strings to align.

    Returns:
        Dict with FA results (segments, timestamps, visualization data).
    """
    _ensure_models_on_gpu()
    alignment = _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence)
    return alignment
@spaces.GPU(duration=GPU_DURATION_INITIAL)
def run_initial_gpu(audio_data, canonical_text, verse_ref):
    """Smart router inside a ZeroGPU lease: VAD, then branch per segment count.

    Single segment: VAD -> Trim -> Transcription -> FA, all in one lease.
    Multi segment:  VAD -> Extract audios -> Whisper; Wav2Vec2+FA are left
    for a second lease (run_wav2vec_and_fa_gpu).

    Args:
        audio_data: (sample_rate, audio_array) tuple from Gradio.
        canonical_text: Expected Arabic text for the verse.
        verse_ref: Verse reference, e.g. "1:2".

    Returns:
        dict always carrying 'vad_result', 'num_segments',
        'segment_audios' and 'error' (str or None); plus
        'transcription' / 'fa_result' / 'trimmed_audio' when
        num_segments == 1, or 'whisper_texts' when num_segments > 1.
    """
    _ensure_models_on_gpu()
    routed = _run_initial_impl(audio_data, canonical_text, verse_ref)
    return routed
@spaces.GPU(duration=GPU_DURATION_WAV2VEC_FA)
def run_wav2vec_and_fa_gpu(segment_audios, sample_rate):
    """Combined Wav2Vec2 + FA lease for the multi-segment path.

    Uses the logits-reuse optimization: the model runs once for
    transcription and the resulting logits feed forced alignment directly,
    avoiding a duplicate inference pass.

    Args:
        segment_audios: List of per-segment audio arrays.
        sample_rate: Audio sample rate.

    Returns:
        (wav2vec_results, fa_results) tuple.
    """
    _ensure_models_on_gpu()
    pair = _run_wav2vec_and_fa_impl(segment_audios, sample_rate)
    return pair
@spaces.GPU(duration=120)  # Longer lease for the full multi-model multi-segment pipeline
def run_multi_segment_pipeline_gpu(audio_data, canonical_text, verse_ref):
    """Complete multi-segment pipeline inside a single ZeroGPU lease.

    Runs VAD -> Whisper (batched) -> per model: Wav2Vec2 + FA (batched).
    Keeping all GPU work in one lease avoids ZeroGPU token expiration.
    Model bundles are read from shared state INSIDE the lease to avoid
    serialization issues with PyTorch models on ZeroGPU.

    Args:
        audio_data: (sample_rate, audio_array) tuple from Gradio.
        canonical_text: Expected Arabic text for the verse.
        verse_ref: Verse reference, e.g. "1:2".

    Returns:
        dict with 'vad_result', 'num_segments', 'segment_audios',
        'whisper_texts', 'model_results' (model_idx -> per-model results)
        and 'error' (str or None).
    """
    _ensure_models_on_gpu()
    pipeline_output = _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref)
    return pipeline_output
# =============================================================================
# CPU fallback versions (no GPU lease, runs on CPU)
# =============================================================================
def run_transcription_cpu(audio_data, processor_arg=None, model_arg=None):
    """CPU fallback for run_transcription_gpu — same logic, no GPU lease."""
    outcome = _run_transcription_impl(audio_data, processor_arg, model_arg)
    return outcome
def run_fa_extraction_cpu(audio_array, sample_rate, phoneme_sequence):
    """CPU fallback for run_fa_extraction_gpu — same logic, no GPU lease."""
    alignment = _run_fa_extraction_impl(audio_array, sample_rate, phoneme_sequence)
    return alignment
def run_initial_cpu(audio_data, canonical_text, verse_ref):
    """CPU fallback for run_initial_gpu — same logic, no GPU lease."""
    routed = _run_initial_impl(audio_data, canonical_text, verse_ref)
    return routed
def run_wav2vec_and_fa_cpu(segment_audios, sample_rate):
    """CPU fallback for run_wav2vec_and_fa_gpu — same logic, no GPU lease."""
    pair = _run_wav2vec_and_fa_impl(segment_audios, sample_rate)
    return pair
def run_multi_segment_pipeline_cpu(audio_data, canonical_text, verse_ref):
    """CPU fallback for run_multi_segment_pipeline_gpu — same logic, no GPU lease."""
    pipeline_output = _run_multi_segment_pipeline_impl(audio_data, canonical_text, verse_ref)
    return pipeline_output