Quran-multi-aligner / src /api /session_api.py
hetchyy's picture
feat: add /process_url_session API endpoint for URL-based alignment
e67922d verified
"""Session-based API: persistence layer + endpoint wrappers.
Sessions store preprocessed audio and VAD data in /tmp so that
follow-up calls (resegment, retranscribe, realign) skip expensive
re-uploads and re-inference.
"""
import hashlib
import json
import math
import os
import pickle
import re
import shutil
import time
import uuid
import gradio as gr
import numpy as np
from config import SESSION_DIR, SESSION_EXPIRY_SECONDS, PHONEME_ASR_MODELS
from src.core.zero_gpu import QuotaExhaustedError
# ---------------------------------------------------------------------------
# Session manager
# ---------------------------------------------------------------------------
# Epoch seconds of the most recent expiry sweep (0.0 = never swept yet).
_last_cleanup_time = 0.0
# Minimum spacing between two sweeps, in seconds (30 minutes).
_CLEANUP_INTERVAL = 30 * 60
# Session IDs are 32 lowercase hex characters (the shape of uuid4().hex).
_VALID_ID = re.compile(r"^[0-9a-f]{32}$")
# Model names the API accepts: the keys of the configured ASR model table.
_VALID_MODELS = set(PHONEME_ASR_MODELS)
def _validate_model_name(model_name):
    """Check ``model_name`` against the configured set of ASR models.

    Returns None when the name is accepted; otherwise an API-style error
    payload (an ``error`` message plus an empty ``segments`` list).
    """
    if model_name in _VALID_MODELS:
        return None
    valid = ", ".join(sorted(_VALID_MODELS))
    return {
        "error": f"Invalid model_name '{model_name}'. Must be one of: {valid}",
        "segments": [],
    }
def _session_dir(audio_id: str):
    """Return the location under SESSION_DIR that holds this session's files."""
    session_root = SESSION_DIR
    return session_root / audio_id
def _validate_id(audio_id: str) -> bool:
    """Return True only if ``audio_id`` is exactly 32 lowercase hex characters.

    The ID is later joined onto SESSION_DIR to form a filesystem path, so
    anything that is not a string of the expected shape is rejected.
    """
    if not isinstance(audio_id, str):
        return False
    # fullmatch(), not match(): with match() the pattern's trailing "$" also
    # succeeds just before a final "\n", so "<32 hex>\n" would be accepted.
    return _VALID_ID.fullmatch(audio_id) is not None
def _is_expired(created_at: float) -> bool:
    """Return True once more than SESSION_EXPIRY_SECONDS have passed since ``created_at``."""
    age_s = time.time() - created_at
    return age_s > SESSION_EXPIRY_SECONDS
def _sweep_expired():
    """Delete expired session directories (runs at most every 30 min).

    Every sub-directory of SESSION_DIR is inspected. A directory is removed
    when its ``created_at`` stamp is older than the expiry window, or when
    the stamp is missing, unreadable or not a number. Calls made less than
    ``_CLEANUP_INTERVAL`` seconds after the previous sweep return immediately.
    """
    global _last_cleanup_time
    now = time.time()
    if now - _last_cleanup_time < _CLEANUP_INTERVAL:
        return
    _last_cleanup_time = now
    if not SESSION_DIR.exists():
        return
    for entry in SESSION_DIR.iterdir():
        if not entry.is_dir():
            continue
        try:
            created_at = float((entry / "created_at").read_text())
        except (OSError, ValueError):
            # Missing stamp, a stamp removed between listing and reading, or
            # corrupt contents: treat as expired instead of letting the
            # exception abort the sweep (and the create_session() calling it).
            expired = True
        else:
            expired = _is_expired(created_at)
        if expired:
            shutil.rmtree(entry, ignore_errors=True)
def _intervals_hash(intervals) -> str:
    """Return the MD5 hex digest of the JSON serialisation of ``intervals``."""
    serialized = json.dumps(intervals)
    digest = hashlib.md5()
    digest.update(serialized.encode())
    return digest.hexdigest()
def create_session(audio, speech_intervals, is_complete, intervals, model_name):
    """Persist session data and return audio_id (32-char hex UUID).

    Uses pickle for VAD artifacts (speech_intervals, is_complete) to
    preserve exact types (torch.Tensor etc.) expected by the segmenter.
    Uses np.save for the audio array (large, always float32 numpy).

    Files written into the session directory: ``created_at``, ``audio.npy``,
    ``vad.pkl`` and ``metadata.json``. If any write fails the partially
    written directory is removed and the exception is re-raised.
    """
    _sweep_expired()
    audio_id = uuid.uuid4().hex
    path = _session_dir(audio_id)
    path.mkdir(parents=True, exist_ok=True)
    try:
        # Stamp the directory FIRST. _sweep_expired() deletes any session
        # directory that has no created_at file, so writing the stamp last
        # let a sweep triggered by a concurrent request remove this session
        # while its (large) audio array was still being saved.
        (path / "created_at").write_text(str(time.time()))

        # Audio is always a float32 numpy array after preprocessing
        np.save(path / "audio.npy", audio)

        # VAD artifacts: preserve original types via pickle
        with open(path / "vad.pkl", "wb") as f:
            pickle.dump({"speech_intervals": speech_intervals,
                         "is_complete": is_complete}, f)

        # Lightweight metadata (JSON-safe types only)
        meta = {
            "intervals": intervals,
            "model_name": model_name,
            "intervals_hash": _intervals_hash(intervals),
            # Sample count divided by a fixed 16000 samples per second
            "audio_duration_s": round(len(audio) / 16000, 2),
        }
        with open(path / "metadata.json", "w") as f:
            json.dump(meta, f)
    except Exception:
        # Do not leave a half-written session behind; let the caller see the error.
        shutil.rmtree(path, ignore_errors=True)
        raise
    return audio_id
def load_session(audio_id):
    """Load session data. Returns dict or None if missing/expired/invalid.

    The returned dict has the keys ``audio``, ``speech_intervals``,
    ``is_complete``, ``intervals``, ``model_name``, ``intervals_hash`` and
    ``audio_id``. A session whose timestamp is missing, unreadable or expired
    is deleted from disk. A session whose payload files are missing or cannot
    be decoded yields None rather than raising.
    """
    if not _validate_id(audio_id):
        return None
    path = _session_dir(audio_id)
    if not path.exists():
        return None

    try:
        created_at = float((path / "created_at").read_text())
    except (OSError, ValueError):
        # No stamp, stamp vanished under us, or corrupt contents
        created_at = None
    if created_at is None or _is_expired(created_at):
        shutil.rmtree(path, ignore_errors=True)
        return None

    try:
        audio = np.load(path / "audio.npy")
        with open(path / "vad.pkl", "rb") as f:
            # NOTE(review): pickle.load can execute arbitrary code from the
            # file. This is acceptable only while vad.pkl is written solely by
            # create_session(); confirm SESSION_DIR is not writable by
            # untrusted parties.
            vad = pickle.load(f)
        with open(path / "metadata.json") as f:
            meta = json.load(f)
        return {
            "audio": audio,
            "speech_intervals": vad["speech_intervals"],
            "is_complete": vad["is_complete"],
            "intervals": meta["intervals"],
            "model_name": meta["model_name"],
            "intervals_hash": meta.get("intervals_hash", ""),
            "audio_id": audio_id,
        }
    except (OSError, ValueError, EOFError, KeyError, pickle.UnpicklingError):
        # Files removed by a concurrent sweep, truncated, or otherwise
        # undecodable: report "no session" as the docstring promises.
        return None
def update_session(audio_id, *, intervals=None, model_name=None):
    """Rewrite the mutable fields of a session's metadata.json.

    Only arguments that are not None are applied. Replacing ``intervals``
    also refreshes ``intervals_hash``. Does nothing when the session has no
    metadata file. The new content goes to a temporary file that is then
    moved over metadata.json with os.replace.
    """
    session_path = _session_dir(audio_id)
    meta_file = session_path / "metadata.json"
    if not meta_file.exists():
        return

    with open(meta_file) as src:
        meta = json.load(src)

    changes = {}
    if intervals is not None:
        changes["intervals"] = intervals
        changes["intervals_hash"] = _intervals_hash(intervals)
    if model_name is not None:
        changes["model_name"] = model_name
    meta.update(changes)

    scratch = session_path / "metadata.tmp"
    with open(scratch, "w") as dst:
        json.dump(meta, dst)
    os.replace(scratch, meta_file)
def _save_segments(audio_id, segments):
    """Write ``segments`` to the session's segments.json for later MFA use.

    Silently does nothing if the session directory does not exist. The data
    is written to segments.tmp first and then moved into place.
    """
    session_path = _session_dir(audio_id)
    if session_path.exists():
        target = session_path / "segments.json"
        scratch = session_path / "segments.tmp"
        with open(scratch, "w") as out:
            json.dump(segments, out)
        os.replace(scratch, target)
def _load_segments(audio_id):
    """Read the session's segments.json.

    Returns the decoded JSON content, or None when the ID is malformed or
    no segments file has been stored.
    """
    if _validate_id(audio_id):
        segments_file = _session_dir(audio_id) / "segments.json"
        if segments_file.exists():
            with open(segments_file) as src:
                return json.load(src)
    return None
# ---------------------------------------------------------------------------
# Response formatting
# ---------------------------------------------------------------------------
# Payload returned by the session endpoints when load_session() yields None.
# The same dict object is handed back on every such call, so it must not be mutated.
_SESSION_ERROR = {"error": "Session not found or expired", "segments": []}
# ---------------------------------------------------------------------------
# Duration estimation
# ---------------------------------------------------------------------------
# Endpoint names accepted by estimate_duration(); anything else gets an error result.
_ESTIMABLE_ENDPOINTS = {
    "process_audio_session",
    "process_url_session",
    "resegment",
    "retranscribe",
    "realign_from_timestamps",
    "timestamps",
    "timestamps_direct",
}
# Endpoints whose estimate is computed from the stored segment count, not the audio length.
_MFA_ENDPOINTS = {"timestamps", "timestamps_direct"}
# Pipeline endpoints that keep the full regression estimate; the other pipeline endpoints are halved.
_VAD_ENDPOINTS = {"process_audio_session", "process_url_session"}
def _load_session_metadata(audio_id):
    """Load only metadata.json (no audio/VAD). Returns dict or None.

    None is returned when the ID is malformed, when the metadata or
    timestamp file is missing or unreadable, when the timestamp is not a
    number, when the session has expired, or when metadata.json is not valid
    JSON. Unlike load_session(), nothing is deleted here.
    """
    if not _validate_id(audio_id):
        return None
    path = _session_dir(audio_id)
    meta_path = path / "metadata.json"
    if not meta_path.exists():
        return None
    try:
        created_at = float((path / "created_at").read_text())
    except (OSError, ValueError):
        # Missing/unreadable/corrupt stamp: treat like an absent session
        return None
    if _is_expired(created_at):
        return None
    try:
        with open(meta_path) as f:
            return json.load(f)
    except (OSError, ValueError):
        # File vanished after the exists() check, or holds invalid JSON
        return None
def estimate_duration(endpoint, audio_duration_s=None, audio_id=None,
                      model_name="Base", device="GPU"):
    """Estimate processing duration for a given endpoint.

    Uses direct wall-time regression (not sum of lease components) fitted on
    257 runs from hetchyy/quran-aligner-logs v1 dataset.

    Args:
        endpoint: One of the names in ``_ESTIMABLE_ENDPOINTS``.
        audio_duration_s: Audio length in seconds. When missing, not a
            number, or not positive, the duration stored in the session
            metadata is used instead.
        audio_id: Existing session ID; required for the MFA endpoints.
        model_name: "Large" selects the large-model coefficients; any other
            value selects the base-model coefficients.
        device: "CPU" (case-insensitive) selects the CPU coefficients; any
            other value, including None, selects the GPU coefficients.

    Returns:
        On success a dict with ``endpoint``, ``estimated_duration_s`` (a
        multiple of 5, at least 5), ``device`` and ``model_name``. On failure
        a dict with ``estimated_duration_s`` set to None and an ``error``
        message.
    """
    from config import (
        ESTIMATE_GPU_BASE_SLOPE, ESTIMATE_GPU_BASE_INTERCEPT,
        ESTIMATE_GPU_LARGE_SLOPE, ESTIMATE_GPU_LARGE_INTERCEPT,
        ESTIMATE_CPU_BASE_SLOPE, ESTIMATE_CPU_BASE_INTERCEPT,
        ESTIMATE_CPU_LARGE_SLOPE, ESTIMATE_CPU_LARGE_INTERCEPT,
        ESTIMATE_WALL_BUFFER,
        MFA_PROGRESS_SEGMENT_RATE,
    )

    _error = {"estimated_duration_s": None}

    if endpoint not in _ESTIMABLE_ENDPOINTS:
        _error["error"] = (
            f"Unknown endpoint '{endpoint}'. "
            f"Valid: {', '.join(sorted(_ESTIMABLE_ENDPOINTS))}"
        )
        return _error

    # --- Resolve audio duration ---
    meta = _load_session_metadata(audio_id) if audio_id else None

    # API callers may send the duration as a string; a value that cannot be
    # read as a number is treated as "not provided" instead of raising.
    try:
        requested_s = float(audio_duration_s) if audio_duration_s is not None else None
    except (TypeError, ValueError):
        requested_s = None

    if requested_s is not None and requested_s > 0:
        duration_s = requested_s
    elif meta and meta.get("audio_duration_s"):
        duration_s = meta["audio_duration_s"]
    else:
        _error["error"] = (
            "audio_duration_s is required (or provide audio_id with an existing session)"
        )
        return _error

    if endpoint in _MFA_ENDPOINTS:
        # --- MFA endpoints require session with stored segments ---
        if not audio_id:
            _error["error"] = "MFA estimation requires audio_id with existing segments"
            return _error
        segments = _load_segments(audio_id)
        if not segments:
            _error["error"] = "No segments found in session — run an alignment endpoint first"
            return _error
        estimate = MFA_PROGRESS_SEGMENT_RATE * len(segments)
    else:
        # --- Pipeline endpoints: direct wall-time regression ---
        minutes = duration_s / 60.0
        # (runs on CPU?, large model?) -> (slope, intercept)
        coefficients = {
            (True, True): (ESTIMATE_CPU_LARGE_SLOPE, ESTIMATE_CPU_LARGE_INTERCEPT),
            (True, False): (ESTIMATE_CPU_BASE_SLOPE, ESTIMATE_CPU_BASE_INTERCEPT),
            (False, True): (ESTIMATE_GPU_LARGE_SLOPE, ESTIMATE_GPU_LARGE_INTERCEPT),
            (False, False): (ESTIMATE_GPU_BASE_SLOPE, ESTIMATE_GPU_BASE_INTERCEPT),
        }
        on_cpu = (device or "GPU").upper() == "CPU"
        is_large = model_name == "Large"
        slope, intercept = coefficients[(on_cpu, is_large)]
        estimate = slope * minutes + intercept

        # Retranscribe/realign skip VAD — scale down by ~50% (ASR+DP only)
        if endpoint not in _VAD_ENDPOINTS:
            estimate *= 0.5

    estimate *= ESTIMATE_WALL_BUFFER
    # Round up to the next multiple of 5 seconds, never below 5
    rounded = max(5, math.ceil(estimate / 5) * 5)

    return {
        "endpoint": endpoint,
        "estimated_duration_s": rounded,
        "device": device,
        "model_name": model_name,
    }
def _format_response(audio_id, json_output, warning=None):
    """Build the documented API response from the pipeline's ``json_output``.

    Each pipeline segment is reduced to the documented fields, the resulting
    list is stored in the session via _save_segments(), and the response
    dict (``audio_id``, ``segments`` and, if given, ``warning``) is returned.
    """
    # Fields copied verbatim; a missing one raises KeyError
    copied_keys = (
        "segment", "time_from", "time_to",
        "ref_from", "ref_to", "matched_text", "confidence",
    )
    formatted = []
    for raw in json_output.get("segments", []):
        item = {key: raw[key] for key in copied_keys}
        item["has_missing_words"] = raw.get("has_missing_words", False)
        item["error"] = raw["error"]
        special = raw.get("special_type")
        if special:
            # Only present in the response when the segment carries a truthy value
            item["special_type"] = special
        formatted.append(item)

    _save_segments(audio_id, formatted)

    response = {"audio_id": audio_id, "segments": formatted}
    if warning:
        response["warning"] = warning
    return response
# ---------------------------------------------------------------------------
# Endpoint wrappers
# ---------------------------------------------------------------------------
def process_audio_session(audio_data, min_silence_ms, min_speech_ms, pad_ms,
                          model_name="Base", device="GPU",
                          request: gr.Request = None):
    """Full pipeline: preprocess -> VAD -> ASR -> alignment. Creates session.

    If the pipeline raises QuotaExhaustedError the run is repeated with the
    device forced to "CPU" and a ``warning`` is added to the response.

    Returns the formatted response (``audio_id``, ``segments``, optional
    ``warning``), or an error dict when the model name is invalid or the
    pipeline reports no output.
    """
    err = _validate_model_name(model_name)
    if err:
        return err

    from src.pipeline import _load_audio, process_audio

    vad_params = (int(min_silence_ms), int(min_speech_ms), int(pad_ms))
    quota_warning = None
    try:
        result = process_audio(
            audio_data, *vad_params,
            model_name, device, request=request, endpoint="process",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = process_audio(
            audio_data, *vad_params,
            model_name, "CPU", request=request, endpoint="process",
        )

    # result is a 9-tuple:
    # (html, json_output, speech_intervals, is_complete, audio, sr, intervals, seg_dir, log_row)
    json_output = result[1]
    if json_output is None:
        return {"error": "No speech detected in audio", "segments": []}

    speech_intervals = result[2]
    is_complete = result[3]
    audio_ref = result[4]
    intervals = result[6]

    # Resolve audio from pipeline cache (result[4] is now a cache key, not array)
    audio, _ = _load_audio(audio_ref)

    audio_id = create_session(
        audio, speech_intervals, is_complete, intervals, model_name,
    )
    return _format_response(audio_id, json_output, warning=quota_warning)
def process_url_session(url, min_silence_ms, min_speech_ms, pad_ms,
                        model_name="Base", device="GPU",
                        request: gr.Request = None):
    """Full pipeline from URL: download -> preprocess -> VAD -> ASR -> alignment.

    Downloads audio via yt-dlp, then runs the same pipeline as
    process_audio_session. Returns the same response format with an
    additional url_metadata field.

    The downloaded WAV is removed before returning on every path after a
    successful download, including the "no speech" result and exceptions
    raised by the pipeline.
    """
    err = _validate_model_name(model_name)
    if err:
        return err

    if not url or not isinstance(url, str) or not url.strip():
        return {"error": "URL is required", "segments": []}
    url = url.strip()

    # Download audio
    try:
        from src.ui.handlers import _download_url_core
        wav_path, url_meta = _download_url_core(url)
    except Exception as e:
        return {"error": f"Download failed: {e}", "segments": []}

    try:
        # Run the standard pipeline with the downloaded WAV path
        from src.pipeline import _load_audio, process_audio

        vad_params = (int(min_silence_ms), int(min_speech_ms), int(pad_ms))
        quota_warning = None
        try:
            result = process_audio(
                wav_path, *vad_params,
                model_name, device, request=request, endpoint="process_url",
            )
        except QuotaExhaustedError as e:
            reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
            quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
            result = process_audio(
                wav_path, *vad_params,
                model_name, "CPU", request=request, endpoint="process_url",
            )

        json_output = result[1]
        if json_output is None:
            return {"error": "No speech detected in audio", "segments": []}

        speech_intervals = result[2]
        is_complete = result[3]
        audio_ref = result[4]
        intervals = result[6]

        audio, _ = _load_audio(audio_ref)

        audio_id = create_session(
            audio, speech_intervals, is_complete, intervals, model_name,
        )

        response = _format_response(audio_id, json_output, warning=quota_warning)
        response["url_metadata"] = {
            "title": url_meta.get("title"),
            "duration": url_meta.get("duration"),
            "source_url": url_meta.get("source_url"),
        }
        return response
    finally:
        # Clean up the downloaded WAV on every exit path; previously it was
        # only removed after a fully successful run and leaked otherwise.
        try:
            os.remove(wav_path)
        except OSError:
            pass
def resegment(audio_id, min_silence_ms, min_speech_ms, pad_ms,
              model_name="Base", device="GPU",
              request: gr.Request = None):
    """Re-clean VAD boundaries with new params and re-run ASR + alignment.

    Works on the audio and VAD data stored in the session, so nothing is
    re-uploaded. If the pipeline raises QuotaExhaustedError the run is
    repeated with the device forced to "CPU" and a ``warning`` is added.
    On success the session's intervals and model name are updated.

    Returns the formatted response, or an error dict when the model name is
    invalid, the session cannot be loaded, or the pipeline reports no output.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err

    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR

    from src.pipeline import resegment_audio

    # Positional arguments shared by the GPU attempt and the CPU fallback
    base_args = (
        session["speech_intervals"], session["is_complete"],
        session["audio"], 16000,
        int(min_silence_ms), int(min_speech_ms), int(pad_ms),
    )
    quota_warning = None
    try:
        result = resegment_audio(
            *base_args,
            model_name, device, request=request, endpoint="resegment",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = resegment_audio(
            *base_args,
            model_name, "CPU", request=request, endpoint="resegment",
        )

    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "No segments with these settings", "segments": []}

    new_intervals = result[6]
    update_session(audio_id, intervals=new_intervals, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
def retranscribe(audio_id, model_name="Base", device="GPU",
                 request: gr.Request = None):
    """Re-run ASR with a different model on current segment boundaries.

    The request is rejected when the requested model equals the session's
    stored model and the stored intervals still match their stored hash. If
    the pipeline raises QuotaExhaustedError the run is repeated with the
    device forced to "CPU" and a ``warning`` is added. On success the
    session's model name is updated.

    Returns the formatted response, or an error dict when the model name is
    invalid, the session cannot be loaded, nothing changed, or the pipeline
    reports no output.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err

    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR

    # Guard: reject if model and boundaries unchanged.
    # create_session()/update_session() always store intervals together with
    # their hash, so the hash comparison only differs for metadata that
    # carries no "intervals_hash" entry (load_session() then yields "").
    if (model_name == session["model_name"]
            and _intervals_hash(session["intervals"]) == session["intervals_hash"]):
        return {
            "audio_id": audio_id,
            "error": "Model and boundaries unchanged. Change model_name or call /resegment first.",
            "segments": [],
        }

    from src.pipeline import retranscribe_audio

    # Positional arguments shared by the GPU attempt and the CPU fallback
    base_args = (
        session["intervals"],
        session["audio"], 16000,
        session["speech_intervals"], session["is_complete"],
    )
    quota_warning = None
    try:
        result = retranscribe_audio(
            *base_args,
            model_name, device, request=request, endpoint="retranscribe",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = retranscribe_audio(
            *base_args,
            model_name, "CPU", request=request, endpoint="retranscribe",
        )

    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "Retranscription failed", "segments": []}

    update_session(audio_id, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
def realign_from_timestamps(audio_id, timestamps, model_name="Base", device="GPU",
                            request: gr.Request = None):
    """Run ASR + alignment on caller-provided timestamp intervals.

    ``timestamps`` is a list of ``{"start": f, "end": f}`` dicts, or a JSON
    string encoding such a list. Malformed input produces an error dict
    instead of an exception. If the pipeline raises QuotaExhaustedError the
    run is repeated with the device forced to "CPU" and a ``warning`` is
    added. On success the session's intervals and model name are updated.

    Returns the formatted response, or an error dict when the model name is
    invalid, the session cannot be loaded, the timestamps cannot be parsed,
    or the pipeline reports no output.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err

    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR

    # Parse timestamps: accept list of {"start": f, "end": f} dicts.
    # Invalid JSON (ValueError), a non-iterable or non-dict entry (TypeError)
    # and a missing "start"/"end" key (KeyError) are reported to the caller
    # in the same error-dict shape the other failure paths use.
    try:
        if isinstance(timestamps, str):
            timestamps = json.loads(timestamps)
        intervals = [(ts["start"], ts["end"]) for ts in timestamps]
    except (ValueError, TypeError, KeyError) as e:
        return {"audio_id": audio_id, "error": f"Invalid timestamps: {e}", "segments": []}

    from src.pipeline import realign_audio

    # Positional arguments shared by the GPU attempt and the CPU fallback
    base_args = (
        intervals,
        session["audio"], 16000,
        session["speech_intervals"], session["is_complete"],
    )
    quota_warning = None
    try:
        result = realign_audio(
            *base_args,
            model_name, device, request=request, endpoint="realign",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = realign_audio(
            *base_args,
            model_name, "CPU", request=request, endpoint="realign",
        )

    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "Alignment failed", "segments": []}

    new_intervals = result[6]
    update_session(audio_id, intervals=new_intervals, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
# ---------------------------------------------------------------------------
# MFA timestamp helpers
# ---------------------------------------------------------------------------
def _preprocess_api_audio(audio_data):
    """Convert audio input to 16kHz mono float32 numpy array.

    Handles file path (str) and Gradio numpy tuple (sample_rate, array).
    Returns (audio_np, sample_rate).

    For the tuple form: int16 and int32 samples are scaled to [-1, 1),
    any other non-float32 dtype is cast to float32 without scaling, arrays
    with more than one dimension are averaged over axis 1, and audio at a
    different rate is resampled to 16 kHz.
    """
    import librosa
    from config import RESAMPLE_TYPE

    target_sr = 16000

    if isinstance(audio_data, str):
        audio, _ = librosa.load(audio_data, sr=target_sr, mono=True, res_type=RESAMPLE_TYPE)
        return audio, target_sr

    sample_rate, audio = audio_data
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    elif audio.dtype == np.int32:
        audio = audio.astype(np.float32) / 2147483648.0
    elif audio.dtype != np.float32:
        # e.g. float64 input: honour the documented float32 contract
        audio = audio.astype(np.float32)

    if audio.ndim > 1:
        # assumes (samples, channels) layout — TODO confirm against callers
        audio = audio.mean(axis=1)

    if sample_rate != target_sr:
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sr, res_type=RESAMPLE_TYPE)
        sample_rate = target_sr

    return audio, sample_rate
def _create_segment_wavs(audio_np, sample_rate, segments):
    """Slice audio by segment boundaries and write WAV files.

    Returns the temp directory path containing seg_0.wav, seg_1.wav, etc.
    The file index is the segment's ``segment`` number minus one.

    The directory is created with tempfile.mkdtemp and is NOT removed here
    on success; removing it is left to the caller. If writing any file
    fails, the directory is deleted and the exception is re-raised.
    """
    import tempfile
    import soundfile as sf

    seg_dir = tempfile.mkdtemp(prefix="mfa_api_")
    try:
        for seg in segments:
            seg_idx = seg.get("segment", 0) - 1
            # Clamp to valid bounds: a negative index would otherwise slice
            # from the END of the array and silently produce the wrong audio.
            start_sample = max(0, int(seg.get("time_from", 0) * sample_rate))
            end_sample = max(start_sample, int(seg.get("time_to", 0) * sample_rate))
            wav_path = os.path.join(seg_dir, f"seg_{seg_idx}.wav")
            sf.write(wav_path, audio_np[start_sample:end_sample], sample_rate)
    except Exception:
        # The caller never receives seg_dir on failure, so clean it up here.
        shutil.rmtree(seg_dir, ignore_errors=True)
        raise
    return seg_dir
# ---------------------------------------------------------------------------
# MFA segment normalization helpers
# ---------------------------------------------------------------------------
# Arabic text substituted as ``matched_text`` for segments that carry a
# ``special_type`` but no text of their own (see _normalize_segments).
#
# The literals had been stored as mojibake (UTF-8 bytes decoded with a
# single-byte codepage). They are restored here and written as escapes so the
# exact code points, including the order of combining marks, are unambiguous.
# NOTE(review): a few short vowels (kasra U+0650 / damma U+064F) were lost in
# the mis-decoding and are reconstructed from the standard vocalisation of
# each phrase — verify against the upstream file.
_SPECIAL_TEXT = {
    # bismi llahi r-rahmani r-rahim
    "Basmala": (
        "\u0628\u0650\u0633\u0652\u0645\u0650 "
        "\u0671\u0644\u0644\u064e\u0651\u0647\u0650 "
        "\u0671\u0644\u0631\u064e\u0651\u062d\u0652\u0645\u064e\u0670\u0646\u0650 "
        "\u0671\u0644\u0631\u064e\u0651\u062d\u0650\u064a\u0645"
    ),
    # a'udhu billahi mina sh-shaytani r-rajim
    "Isti'adha": (
        "\u0623\u064e\u0639\u064f\u0648\u0630\u064f "
        "\u0628\u0650\u0671\u0644\u0644\u064e\u0651\u0647\u0650 "
        "\u0645\u0650\u0646\u064e "
        "\u0627\u0644\u0634\u064e\u0651\u064a\u0652\u0637\u064e\u0627\u0646\u0650 "
        "\u0627\u0644\u0631\u064e\u0651\u062c\u0650\u064a\u0645"
    ),
    # amin
    "Amin": "\u0622\u0645\u0650\u064a\u0646",
    # allahu akbar
    "Takbir": (
        "\u0627\u0644\u0644\u064e\u0651\u0647\u064f "
        "\u0623\u064e\u0643\u0652\u0628\u064e\u0631"
    ),
    # sami'a llahu liman hamidah
    "Tahmeed": (
        "\u0633\u064e\u0645\u0650\u0639\u064e "
        "\u0627\u0644\u0644\u064e\u0651\u0647\u064f "
        "\u0644\u0650\u0645\u064e\u0646\u0652 "
        "\u062d\u064e\u0645\u0650\u062f\u064e\u0647"
    ),
    # as-salamu 'alaykum wa rahmatu llah
    "Tasleem": (
        "\u0671\u0644\u0633\u064e\u0651\u0644\u064e\u0627\u0645\u064f "
        "\u0639\u064e\u0644\u064e\u064a\u0652\u0643\u064f\u0645\u0652 "
        "\u0648\u064e\u0631\u064e\u062d\u0652\u0645\u064e\u0629\u064f "
        "\u0671\u0644\u0644\u064e\u0651\u0647"
    ),
    # sadaqa llahu l-'azim
    "Sadaqa": (
        "\u0635\u064e\u062f\u064e\u0642\u064e "
        "\u0671\u0644\u0644\u064e\u0651\u0647\u064f "
        "\u0671\u0644\u0652\u0639\u064e\u0638\u0650\u064a\u0645"
    ),
}
def _normalize_segments(segments):
    """Return copies of ``segments`` with the optional fields filled in.

    Callers may pass minimal dicts (timestamps + refs). For every entry a
    missing ``segment`` becomes its 1-based position, a missing
    ``confidence`` becomes 1.0, and a missing ``matched_text`` is looked up
    in ``_SPECIAL_TEXT`` by ``special_type`` (empty string when unknown).
    The input dicts themselves are not modified.
    """
    filled = []
    for position, raw in enumerate(segments, start=1):
        item = dict(raw)
        item.setdefault("segment", position)
        item.setdefault("confidence", 1.0)
        if "matched_text" not in item:
            item["matched_text"] = _SPECIAL_TEXT.get(item.get("special_type", ""), "")
        filled.append(item)
    return filled
# ---------------------------------------------------------------------------
# MFA timestamp endpoints
# ---------------------------------------------------------------------------
def timestamps(audio_id, segments_json=None, granularity="words"):
    """Compute MFA word/letter timestamps using session audio.

    ``segments_json`` may be a list of segment dicts or a JSON string
    encoding one. When it is omitted, empty, or a blank string, the segments
    stored in the session are used instead.

    Returns the result of compute_mfa_timestamps_api() with ``audio_id``
    added, or an error dict (``error`` plus empty ``segments``) when the
    granularity is disabled, the session cannot be loaded, the segments are
    invalid or missing, or a processing step raises.
    """
    if granularity == "words+chars":
        return {"audio_id": audio_id, "error": "chars granularity is currently disabled via API", "segments": []}

    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR

    # Parse segments: use provided or load stored
    if isinstance(segments_json, str):
        if segments_json.strip():
            try:
                segments_json = json.loads(segments_json)
            except ValueError as e:
                return {"audio_id": audio_id, "error": f"Invalid segments_json: {e}", "segments": []}
        else:
            # A blank string means "not provided"; json.loads("") would raise
            segments_json = None

    if segments_json:
        segments = _normalize_segments(segments_json)
    else:
        segments = _load_segments(audio_id)
        if not segments:
            return {"audio_id": audio_id, "error": "No segments found in session", "segments": []}

    # Create segment WAVs from session audio
    try:
        seg_dir = _create_segment_wavs(session["audio"], 16000, segments)
    except Exception as e:
        return {"audio_id": audio_id, "error": f"Failed to create segment audio: {e}", "segments": []}

    # NOTE(review): seg_dir is a fresh temp directory that is never deleted
    # here — confirm whether compute_mfa_timestamps_api still needs the files
    # after it returns, and remove the directory if it does not.
    from src.mfa import compute_mfa_timestamps_api
    try:
        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
    except Exception as e:
        return {"audio_id": audio_id, "error": f"MFA alignment failed: {e}", "segments": []}

    result["audio_id"] = audio_id
    return result
def timestamps_direct(audio_data, segments_json, granularity="words"):
    """Compute MFA word/letter timestamps with provided audio and segments.

    ``audio_data`` is a file path or a (sample_rate, array) tuple, as
    accepted by _preprocess_api_audio(). ``segments_json`` is a list of
    segment dicts or a JSON string encoding one.

    Returns the result of compute_mfa_timestamps_api(), or an error dict
    (``error`` plus empty ``segments``) when the granularity is disabled,
    the segments are invalid or empty, or a processing step raises.
    """
    if granularity == "words+chars":
        return {"error": "chars granularity is currently disabled via API", "segments": []}

    # Parse segments
    if isinstance(segments_json, str):
        if segments_json.strip():
            try:
                segments_json = json.loads(segments_json)
            except ValueError as e:
                return {"error": f"Invalid segments_json: {e}", "segments": []}
        else:
            # A blank string means "nothing provided"; json.loads("") would raise
            segments_json = None

    if not segments_json:
        return {"error": "No segments provided", "segments": []}

    segments = _normalize_segments(segments_json)

    # Preprocess audio
    try:
        audio_np, sr = _preprocess_api_audio(audio_data)
    except Exception as e:
        return {"error": f"Failed to preprocess audio: {e}", "segments": []}

    # Create segment WAVs
    try:
        seg_dir = _create_segment_wavs(audio_np, sr, segments)
    except Exception as e:
        return {"error": f"Failed to create segment audio: {e}", "segments": []}

    # NOTE(review): seg_dir is a fresh temp directory that is never deleted
    # here — confirm whether compute_mfa_timestamps_api still needs the files
    # after it returns, and remove the directory if it does not.
    from src.mfa import compute_mfa_timestamps_api
    try:
        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
    except Exception as e:
        return {"error": f"MFA alignment failed: {e}", "segments": []}

    return result