Spaces:
Running on Zero
Running on Zero
| """Session-based API: persistence layer + endpoint wrappers. | |
| Sessions store preprocessed audio and VAD data in /tmp so that | |
| follow-up calls (resegment, retranscribe, realign) skip expensive | |
| re-uploads and re-inference. | |
| """ | |
| import hashlib | |
| import json | |
| import math | |
| import os | |
| import pickle | |
| import re | |
| import shutil | |
| import time | |
| import uuid | |
| import gradio as gr | |
| import numpy as np | |
| from config import SESSION_DIR, SESSION_EXPIRY_SECONDS, PHONEME_ASR_MODELS | |
| from src.core.zero_gpu import QuotaExhaustedError | |
| # --------------------------------------------------------------------------- | |
| # Session manager | |
| # --------------------------------------------------------------------------- | |
# Wall-clock time (time.time()) of the most recent sweep; 0.0 until one runs.
_last_cleanup_time = 0.0
_CLEANUP_INTERVAL = 1800  # sweep at most every 30 min
# Session IDs are uuid4().hex values: exactly 32 lowercase hex characters.
_VALID_ID = re.compile(r"^[0-9a-f]{32}$")
# Model names accepted by the endpoint wrappers (keys of PHONEME_ASR_MODELS).
_VALID_MODELS = set(PHONEME_ASR_MODELS.keys())
def _validate_model_name(model_name):
    """Check ``model_name`` against the configured ASR model names.

    Returns ``None`` when the name is known. Otherwise returns an error
    response dict: an ``error`` message listing the valid names plus an
    empty ``segments`` list.
    """
    if model_name in _VALID_MODELS:
        return None
    allowed = ", ".join(sorted(_VALID_MODELS))
    return {
        "error": f"Invalid model_name '{model_name}'. Must be one of: {allowed}",
        "segments": [],
    }
def _session_dir(audio_id: str):
    """Return the directory under ``SESSION_DIR`` that holds this session's files."""
    session_path = SESSION_DIR / audio_id
    return session_path
def _validate_id(audio_id: str) -> bool:
    """Return True only if ``audio_id`` is a string of exactly 32 lowercase hex chars.

    ``fullmatch`` is used rather than ``match`` because ``$`` also matches
    just before a trailing newline, so ``match`` would accept
    ``"<32 hex chars>\\n"`` and let that newline flow into a filesystem path.
    """
    if not isinstance(audio_id, str):
        return False
    return _VALID_ID.fullmatch(audio_id) is not None
def _is_expired(created_at: float) -> bool:
    """Return True when more than ``SESSION_EXPIRY_SECONDS`` have passed since ``created_at``."""
    age_s = time.time() - created_at
    return age_s > SESSION_EXPIRY_SECONDS
def _sweep_expired():
    """Delete expired session directories (runs at most every 30 min).

    A directory is removed when its ``created_at`` file is missing,
    unreadable, does not contain a number, or holds an expired timestamp.
    A single corrupt timestamp file therefore no longer aborts the sweep
    (and with it the ``create_session`` call that triggered it).
    """
    global _last_cleanup_time
    now = time.time()
    if now - _last_cleanup_time < _CLEANUP_INTERVAL:
        return
    _last_cleanup_time = now
    if not SESSION_DIR.exists():
        return
    for entry in SESSION_DIR.iterdir():
        if not entry.is_dir():
            continue
        try:
            created_at = float((entry / "created_at").read_text())
        except (OSError, ValueError):
            # Missing, unreadable or corrupt timestamp: treat as stale.
            expired = True
        else:
            expired = _is_expired(created_at)
        if expired:
            shutil.rmtree(entry, ignore_errors=True)
def _intervals_hash(intervals) -> str:
    """Return the MD5 hex digest of the JSON serialisation of ``intervals``.

    The digest is stored in session metadata as a fingerprint of the
    current segment boundaries.
    """
    serialized = json.dumps(intervals)
    digest = hashlib.md5(serialized.encode())
    return digest.hexdigest()
def create_session(audio, speech_intervals, is_complete, intervals, model_name):
    """Persist session data and return audio_id (32-char hex UUID).

    Uses pickle for VAD artifacts (speech_intervals, is_complete) to
    preserve exact types (torch.Tensor etc.) expected by the segmenter.
    Uses np.save for the audio array (large, always float32 numpy).

    Files written into the session directory: ``created_at``,
    ``audio.npy``, ``vad.pkl`` and ``metadata.json``.
    """
    _sweep_expired()
    audio_id = uuid.uuid4().hex
    path = _session_dir(audio_id)
    path.mkdir(parents=True, exist_ok=True)
    # Timestamp file for cheap expiry checks during sweep. Written FIRST:
    # the sweep deletes any session directory without a created_at file, so
    # writing it last would let a sweep triggered by a concurrent request
    # remove this session while the (slow) audio write is still in progress.
    (path / "created_at").write_text(str(time.time()))
    # Audio is always a float32 numpy array after preprocessing
    np.save(path / "audio.npy", audio)
    # VAD artifacts: preserve original types via pickle
    with open(path / "vad.pkl", "wb") as f:
        pickle.dump(
            {"speech_intervals": speech_intervals, "is_complete": is_complete},
            f,
        )
    # Lightweight metadata (JSON-safe types only)
    meta = {
        "intervals": intervals,
        "model_name": model_name,
        "intervals_hash": _intervals_hash(intervals),
        # Duration assumes the 16 kHz sample rate used throughout this module.
        "audio_duration_s": round(len(audio) / 16000, 2),
    }
    with open(path / "metadata.json", "w") as f:
        json.dump(meta, f)
    return audio_id
def load_session(audio_id):
    """Load session data. Returns dict or None if missing/expired/invalid.

    ``None`` is returned when the ID is malformed, the session directory
    does not exist, or the session has expired. A missing or corrupt
    timestamp counts as expired. Expired sessions are deleted as a side
    effect. ``None`` is also returned, without deleting anything, when a
    stored artifact is missing or cannot be decoded.

    The returned dict has keys ``audio``, ``speech_intervals``,
    ``is_complete``, ``intervals``, ``model_name``, ``intervals_hash`` and
    ``audio_id``.
    """
    if not _validate_id(audio_id):
        return None
    path = _session_dir(audio_id)
    if not path.exists():
        return None
    try:
        created_at = float((path / "created_at").read_text())
    except (OSError, ValueError):
        # Missing/unreadable/corrupt timestamp: same policy as "expired".
        created_at = None
    if created_at is None or _is_expired(created_at):
        shutil.rmtree(path, ignore_errors=True)
        return None
    try:
        audio = np.load(path / "audio.npy")
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # This is acceptable only because vad.pkl is written by
        # create_session and the ID is validated above; never point this at
        # data from an untrusted writer.
        with open(path / "vad.pkl", "rb") as f:
            vad = pickle.load(f)
        with open(path / "metadata.json") as f:
            meta = json.load(f)
        return {
            "audio": audio,
            "speech_intervals": vad["speech_intervals"],
            "is_complete": vad["is_complete"],
            "intervals": meta["intervals"],
            "model_name": meta["model_name"],
            "intervals_hash": meta.get("intervals_hash", ""),
            "audio_id": audio_id,
        }
    except (OSError, ValueError, KeyError, EOFError, pickle.UnpicklingError):
        # Artifact missing (e.g. removed by a concurrent sweep) or truncated:
        # honour the documented contract instead of raising.
        return None
def update_session(audio_id, *, intervals=None, model_name=None):
    """Update mutable session fields (intervals, model_name).

    Does nothing when ``audio_id`` is not a well-formed session ID or the
    session has no ``metadata.json``. When ``intervals`` is given, the
    stored ``intervals_hash`` is refreshed with it. The file is rewritten
    through a temporary file and ``os.replace`` so readers never observe a
    partially written ``metadata.json``.
    """
    # Same ID check as the loaders: keeps arbitrary strings out of the path.
    if not _validate_id(audio_id):
        return
    path = _session_dir(audio_id)
    meta_path = path / "metadata.json"
    if not meta_path.exists():
        return
    with open(meta_path) as f:
        meta = json.load(f)
    if intervals is not None:
        meta["intervals"] = intervals
        meta["intervals_hash"] = _intervals_hash(intervals)
    if model_name is not None:
        meta["model_name"] = model_name
    tmp = path / "metadata.tmp"
    with open(tmp, "w") as f:
        json.dump(meta, f)
    os.replace(tmp, meta_path)
def _save_segments(audio_id, segments):
    """Persist alignment segments for later MFA use.

    Writes ``segments.json`` in the session directory via a temporary file
    and ``os.replace``. Does nothing when ``audio_id`` is not a well-formed
    session ID or the session directory does not exist.
    """
    # Same ID check as _load_segments: keeps arbitrary strings out of the path.
    if not _validate_id(audio_id):
        return
    path = _session_dir(audio_id)
    if not path.exists():
        return
    seg_path = path / "segments.json"
    tmp = path / "segments.tmp"
    with open(tmp, "w") as f:
        json.dump(segments, f)
    os.replace(tmp, seg_path)
def _load_segments(audio_id):
    """Return the parsed ``segments.json`` of a session.

    Returns ``None`` when the ID is malformed or no segments file exists.
    """
    if _validate_id(audio_id):
        stored = _session_dir(audio_id) / "segments.json"
        if stored.exists():
            with open(stored) as fh:
                return json.load(fh)
    return None
| # --------------------------------------------------------------------------- | |
| # Response formatting | |
| # --------------------------------------------------------------------------- | |
# Response returned by the endpoints when load_session() yields None.
_SESSION_ERROR = {"error": "Session not found or expired", "segments": []}
# ---------------------------------------------------------------------------
# Duration estimation
# ---------------------------------------------------------------------------
# Endpoint names accepted by estimate_duration().
_ESTIMABLE_ENDPOINTS = {
    "process_audio_session",
    "process_url_session",
    "resegment",
    "retranscribe",
    "realign_from_timestamps",
    "timestamps",
    "timestamps_direct",
}
# Endpoints whose estimate is derived from the stored segment count.
_MFA_ENDPOINTS = {"timestamps", "timestamps_direct"}
# Endpoints that get the full regression estimate; the other pipeline
# endpoints have theirs halved in estimate_duration().
_VAD_ENDPOINTS = {"process_audio_session", "process_url_session"}
def _load_session_metadata(audio_id):
    """Load only metadata.json (no audio/VAD). Returns dict or None.

    ``None`` is returned when the ID is malformed, ``metadata.json`` is
    missing, or the ``created_at`` timestamp is missing, unreadable,
    corrupt or expired. Unlike ``load_session`` this never deletes anything.
    """
    if not _validate_id(audio_id):
        return None
    path = _session_dir(audio_id)
    meta_path = path / "metadata.json"
    if not meta_path.exists():
        return None
    try:
        created_at = float((path / "created_at").read_text())
    except (OSError, ValueError):
        # A corrupt timestamp must not crash a cheap estimation request.
        return None
    if _is_expired(created_at):
        return None
    with open(meta_path) as f:
        return json.load(f)
def estimate_duration(endpoint, audio_duration_s=None, audio_id=None,
                      model_name="Base", device="GPU"):
    """Estimate processing duration for a given endpoint.

    Uses direct wall-time regression (not sum of lease components) fitted on
    257 runs from hetchyy/quran-aligner-logs v1 dataset.

    Args:
        endpoint: One of ``_ESTIMABLE_ENDPOINTS``.
        audio_duration_s: Audio length in seconds. Values that are not a
            positive finite number are ignored in favour of the session.
        audio_id: Optional session ID. Supplies the duration when
            ``audio_duration_s`` is unusable; required for MFA endpoints.
        model_name: ``"Large"`` selects the large-model coefficients; any
            other value uses the base coefficients.
        device: ``"CPU"`` (case-insensitive) selects the CPU coefficients;
            anything else, including ``None``, uses the GPU ones.

    Returns:
        On success, a dict with ``endpoint``, ``estimated_duration_s``
        (rounded up to a multiple of 5, minimum 5), ``device`` and
        ``model_name``. On failure, a dict with ``estimated_duration_s``
        set to ``None`` and an ``error`` message.
    """
    from config import (
        ESTIMATE_GPU_BASE_SLOPE, ESTIMATE_GPU_BASE_INTERCEPT,
        ESTIMATE_GPU_LARGE_SLOPE, ESTIMATE_GPU_LARGE_INTERCEPT,
        ESTIMATE_CPU_BASE_SLOPE, ESTIMATE_CPU_BASE_INTERCEPT,
        ESTIMATE_CPU_LARGE_SLOPE, ESTIMATE_CPU_LARGE_INTERCEPT,
        ESTIMATE_WALL_BUFFER,
        MFA_PROGRESS_SEGMENT_RATE,
    )

    def _positive_finite(value):
        """Return ``value`` as a float if it is a positive finite number, else None."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        if math.isfinite(number) and number > 0:
            return number
        return None

    _error = {"estimated_duration_s": None}
    if endpoint not in _ESTIMABLE_ENDPOINTS:
        _error["error"] = (
            f"Unknown endpoint '{endpoint}'. "
            f"Valid: {', '.join(sorted(_ESTIMABLE_ENDPOINTS))}"
        )
        return _error

    # --- Resolve audio duration ---
    meta = _load_session_metadata(audio_id) if audio_id else None
    # Strings, NaN and infinity previously raised (TypeError on the
    # comparison, OverflowError in math.ceil); they are now treated as
    # "not provided" so the caller gets the regular error response.
    duration_s = _positive_finite(audio_duration_s)
    if duration_s is None and meta:
        duration_s = _positive_finite(meta.get("audio_duration_s"))
    if duration_s is None:
        _error["error"] = (
            "audio_duration_s is required (or provide audio_id with an existing session)"
        )
        return _error

    if endpoint in _MFA_ENDPOINTS:
        # --- MFA endpoints require session with stored segments ---
        if not audio_id:
            _error["error"] = "MFA estimation requires audio_id with existing segments"
            return _error
        segments = _load_segments(audio_id)
        if not segments:
            _error["error"] = "No segments found in session — run an alignment endpoint first"
            return _error
        # MFA time scales with the number of segments, not audio length.
        estimate = MFA_PROGRESS_SEGMENT_RATE * len(segments)
    else:
        # --- Pipeline endpoints: direct wall-time regression ---
        minutes = duration_s / 60.0
        device_upper = (device or "GPU").upper()
        is_large = model_name == "Large"
        if device_upper == "CPU":
            if is_large:
                estimate = ESTIMATE_CPU_LARGE_SLOPE * minutes + ESTIMATE_CPU_LARGE_INTERCEPT
            else:
                estimate = ESTIMATE_CPU_BASE_SLOPE * minutes + ESTIMATE_CPU_BASE_INTERCEPT
        else:
            if is_large:
                estimate = ESTIMATE_GPU_LARGE_SLOPE * minutes + ESTIMATE_GPU_LARGE_INTERCEPT
            else:
                estimate = ESTIMATE_GPU_BASE_SLOPE * minutes + ESTIMATE_GPU_BASE_INTERCEPT
        # Retranscribe/realign skip VAD — scale down by ~50% (ASR+DP only).
        # This applies to every pipeline endpoint outside _VAD_ENDPOINTS.
        if endpoint not in _VAD_ENDPOINTS:
            estimate *= 0.5
        estimate *= ESTIMATE_WALL_BUFFER

    # Round up to the next multiple of 5 seconds, never below 5.
    rounded = max(5, math.ceil(estimate / 5) * 5)
    return {
        "endpoint": endpoint,
        "estimated_duration_s": rounded,
        "device": device,
        "model_name": model_name,
    }
def _format_response(audio_id, json_output, warning=None):
    """Build the API response dict from the pipeline's ``json_output``.

    Each pipeline segment is reduced to the documented fields
    (``special_type`` only when it is truthy). The resulting list is stored
    with ``_save_segments`` and returned under ``segments`` together with
    ``audio_id`` and, when given, ``warning``.
    """
    leading_keys = (
        "segment", "time_from", "time_to",
        "ref_from", "ref_to", "matched_text", "confidence",
    )
    segments = []
    for raw in json_output.get("segments", []):
        item = {key: raw[key] for key in leading_keys}
        item["has_missing_words"] = raw.get("has_missing_words", False)
        item["error"] = raw["error"]
        special = raw.get("special_type")
        if special:
            item["special_type"] = special
        segments.append(item)

    _save_segments(audio_id, segments)

    resp = {"audio_id": audio_id, "segments": segments}
    if warning:
        resp["warning"] = warning
    return resp
| # --------------------------------------------------------------------------- | |
| # Endpoint wrappers | |
| # --------------------------------------------------------------------------- | |
def process_audio_session(audio_data, min_silence_ms, min_speech_ms, pad_ms,
                          model_name="Base", device="GPU",
                          request: gr.Request = None):
    """Full pipeline: preprocess -> VAD -> ASR -> alignment. Creates session.

    If the GPU quota is exhausted the pipeline is re-run on CPU and the
    response carries a ``warning``. Returns an error dict for an invalid
    ``model_name`` or when the pipeline produces no JSON output; otherwise
    returns the formatted response for a newly created session.
    """
    err = _validate_model_name(model_name)
    if err:
        return err
    from src.pipeline import process_audio

    # Arguments shared by the GPU attempt and the CPU fallback.
    pipeline_args = (
        audio_data, int(min_silence_ms), int(min_speech_ms), int(pad_ms),
        model_name,
    )
    quota_warning = None
    try:
        result = process_audio(
            *pipeline_args, device, request=request, endpoint="process",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = process_audio(
            *pipeline_args, "CPU", request=request, endpoint="process",
        )
    # result is a 9-tuple:
    # (html, json_output, speech_intervals, is_complete, audio, sr, intervals, seg_dir, log_row)
    json_output = result[1]
    if json_output is None:
        return {"error": "No speech detected in audio", "segments": []}
    speech_intervals = result[2]
    is_complete = result[3]
    audio_ref = result[4]
    intervals = result[6]
    # Resolve audio from pipeline cache (result[4] is now a cache key, not array)
    from src.pipeline import _load_audio
    audio, _ = _load_audio(audio_ref)
    audio_id = create_session(
        audio, speech_intervals, is_complete, intervals, model_name,
    )
    return _format_response(audio_id, json_output, warning=quota_warning)
def process_url_session(url, min_silence_ms, min_speech_ms, pad_ms,
                        model_name="Base", device="GPU",
                        request: gr.Request = None):
    """Full pipeline from URL: download -> preprocess -> VAD -> ASR -> alignment.

    Downloads audio via yt-dlp, then runs the same pipeline as
    process_audio_session. Returns the same response format with an
    additional url_metadata field.

    Returns an error dict for an invalid ``model_name``, an empty URL, a
    failed download, or when the pipeline produces no JSON output. The
    downloaded WAV is removed on every exit path once the download succeeded.
    """
    err = _validate_model_name(model_name)
    if err:
        return err
    if not url or not isinstance(url, str) or not url.strip():
        return {"error": "URL is required", "segments": []}
    url = url.strip()
    # Download audio
    try:
        from src.ui.handlers import _download_url_core
        wav_path, url_meta = _download_url_core(url)
    except Exception as e:
        return {"error": f"Download failed: {e}", "segments": []}

    try:
        # Run the standard pipeline with the downloaded WAV path
        from src.pipeline import process_audio

        # Arguments shared by the GPU attempt and the CPU fallback.
        pipeline_args = (
            wav_path, int(min_silence_ms), int(min_speech_ms), int(pad_ms),
            model_name,
        )
        quota_warning = None
        try:
            result = process_audio(
                *pipeline_args, device, request=request, endpoint="process_url",
            )
        except QuotaExhaustedError as e:
            reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
            quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
            result = process_audio(
                *pipeline_args, "CPU", request=request, endpoint="process_url",
            )
        json_output = result[1]
        if json_output is None:
            return {"error": "No speech detected in audio", "segments": []}
        speech_intervals = result[2]
        is_complete = result[3]
        audio_ref = result[4]
        intervals = result[6]
        from src.pipeline import _load_audio
        audio, _ = _load_audio(audio_ref)
        audio_id = create_session(
            audio, speech_intervals, is_complete, intervals, model_name,
        )
        response = _format_response(audio_id, json_output, warning=quota_warning)
        response["url_metadata"] = {
            "title": url_meta.get("title"),
            "duration": url_meta.get("duration"),
            "source_url": url_meta.get("source_url"),
        }
        return response
    finally:
        # Clean up the downloaded WAV on success AND on early return/failure;
        # previously it leaked whenever the pipeline raised or found no speech.
        try:
            os.remove(wav_path)
        except OSError:
            pass
def resegment(audio_id, min_silence_ms, min_speech_ms, pad_ms,
              model_name="Base", device="GPU",
              request: gr.Request = None):
    """Re-clean VAD boundaries with new params and re-run ASR + alignment.

    Works on the audio and VAD data stored in the session. If the GPU quota
    is exhausted the work is re-run on CPU and the response carries a
    ``warning``. On success the session's intervals and model name are
    updated and the formatted response is returned; otherwise an error dict.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err
    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR
    from src.pipeline import resegment_audio

    # Arguments shared by the GPU attempt and the CPU fallback.
    pipeline_args = (
        session["speech_intervals"], session["is_complete"],
        session["audio"], 16000,
        int(min_silence_ms), int(min_speech_ms), int(pad_ms),
        model_name,
    )
    quota_warning = None
    try:
        result = resegment_audio(
            *pipeline_args, device, request=request, endpoint="resegment",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = resegment_audio(
            *pipeline_args, "CPU", request=request, endpoint="resegment",
        )
    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "No segments with these settings", "segments": []}
    # presumably the same tuple layout as process_audio (index 6 = intervals)
    # -- TODO confirm against src.pipeline.resegment_audio
    new_intervals = result[6]
    update_session(audio_id, intervals=new_intervals, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
def retranscribe(audio_id, model_name="Base", device="GPU",
                 request: gr.Request = None):
    """Re-run ASR with a different model on current segment boundaries.

    Rejects the call when the requested model equals the session's model
    and the stored boundaries hash matches the stored boundaries. If the
    GPU quota is exhausted the work is re-run on CPU and the response
    carries a ``warning``. On success the session's model name is updated.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err
    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR
    # Guard: reject if model and boundaries unchanged
    # NOTE(review): intervals and intervals_hash are written together by
    # create_session/update_session, so the hash comparison appears to hold
    # whenever a hash was stored -- i.e. the guard effectively reduces to
    # "same model", even right after /resegment. Confirm this is intended.
    if (model_name == session["model_name"]
            and _intervals_hash(session["intervals"]) == session["intervals_hash"]):
        return {
            "audio_id": audio_id,
            "error": "Model and boundaries unchanged. Change model_name or call /resegment first.",
            "segments": [],
        }
    from src.pipeline import retranscribe_audio

    # Arguments shared by the GPU attempt and the CPU fallback.
    pipeline_args = (
        session["intervals"],
        session["audio"], 16000,
        session["speech_intervals"], session["is_complete"],
        model_name,
    )
    quota_warning = None
    try:
        result = retranscribe_audio(
            *pipeline_args, device, request=request, endpoint="retranscribe",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = retranscribe_audio(
            *pipeline_args, "CPU", request=request, endpoint="retranscribe",
        )
    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "Retranscription failed", "segments": []}
    update_session(audio_id, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
def realign_from_timestamps(audio_id, timestamps, model_name="Base", device="GPU",
                            request: gr.Request = None):
    """Run ASR + alignment on caller-provided timestamp intervals.

    ``timestamps`` is a list of ``{"start": f, "end": f}`` dicts, or a JSON
    string encoding such a list. Malformed input yields an error dict
    instead of an unhandled exception. If the GPU quota is exhausted the
    work is re-run on CPU and the response carries a ``warning``. On success
    the session's intervals and model name are updated.
    """
    err = _validate_model_name(model_name)
    if err:
        err["audio_id"] = audio_id
        return err
    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR
    # Parse timestamps: accept list of {"start": f, "end": f} dicts
    try:
        if isinstance(timestamps, str):
            timestamps = json.loads(timestamps)
        intervals = [(ts["start"], ts["end"]) for ts in timestamps]
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError; KeyError/TypeError cover
        # entries without start/end or a payload that is not a list of dicts.
        return {
            "audio_id": audio_id,
            "error": f"Invalid timestamps: {e!r}",
            "segments": [],
        }
    from src.pipeline import realign_audio

    # Arguments shared by the GPU attempt and the CPU fallback.
    pipeline_args = (
        intervals,
        session["audio"], 16000,
        session["speech_intervals"], session["is_complete"],
        model_name,
    )
    quota_warning = None
    try:
        result = realign_audio(
            *pipeline_args, device, request=request, endpoint="realign",
        )
    except QuotaExhaustedError as e:
        reset_msg = f" Resets in {e.reset_time}." if e.reset_time else ""
        quota_warning = f"GPU quota reached — processed on CPU (slower).{reset_msg}"
        result = realign_audio(
            *pipeline_args, "CPU", request=request, endpoint="realign",
        )
    json_output = result[1]
    if json_output is None:
        return {"audio_id": audio_id, "error": "Alignment failed", "segments": []}
    # presumably the same tuple layout as process_audio (index 6 = intervals)
    # -- TODO confirm against src.pipeline.realign_audio
    new_intervals = result[6]
    update_session(audio_id, intervals=new_intervals, model_name=model_name)
    return _format_response(audio_id, json_output, warning=quota_warning)
| # --------------------------------------------------------------------------- | |
| # MFA timestamp helpers | |
| # --------------------------------------------------------------------------- | |
def _preprocess_api_audio(audio_data):
    """Convert audio input to 16kHz mono float32 numpy array.

    Handles file path (str) and Gradio numpy tuple (sample_rate, array).
    Returns (audio_np, sample_rate).

    For the tuple form, integer samples of any width are scaled to
    [-1, 1) (signed types by their magnitude range, unsigned types after
    centring), other non-float32 dtypes are cast to float32, multi-channel
    audio is averaged over axis 1, and the result is resampled to 16 kHz
    when needed.
    """
    if isinstance(audio_data, str):
        # librosa/config are imported only on the paths that need them.
        import librosa
        from config import RESAMPLE_TYPE
        audio, _ = librosa.load(audio_data, sr=16000, mono=True, res_type=RESAMPLE_TYPE)
        return audio, 16000

    sample_rate, audio = audio_data
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.signedinteger):
        # int16 -> /32768, int32 -> /2147483648, and likewise for int8/int64.
        full_scale = float(-int(np.iinfo(audio.dtype).min))
        audio = audio.astype(np.float32) / full_scale
    elif np.issubdtype(audio.dtype, np.unsignedinteger):
        # Unsigned PCM is offset-binary: centre on the midpoint, then scale.
        half_range = (int(np.iinfo(audio.dtype).max) + 1) / 2.0
        audio = (audio.astype(np.float32) - half_range) / half_range
    elif audio.dtype != np.float32:
        # e.g. float64 input: honour the documented float32 contract.
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sample_rate != 16000:
        import librosa
        from config import RESAMPLE_TYPE
        audio = librosa.resample(
            audio, orig_sr=sample_rate, target_sr=16000, res_type=RESAMPLE_TYPE,
        )
        sample_rate = 16000
    return audio, sample_rate
def _create_segment_wavs(audio_np, sample_rate, segments):
    """Slice audio by segment boundaries and write WAV files.

    Returns the temp directory path containing seg_0.wav, seg_1.wav, etc.

    File names use the segment's 1-based ``segment`` number minus one.
    Sample bounds are clamped to ``[0, len(audio_np)]`` so a negative time
    cannot wrap around and slice from the end of the array. If writing
    fails, the temp directory is removed before the exception propagates.
    The caller owns the returned directory.
    """
    import tempfile
    import soundfile as sf

    total_samples = len(audio_np)
    seg_dir = tempfile.mkdtemp(prefix="mfa_api_")
    try:
        for seg in segments:
            seg_idx = seg.get("segment", 0) - 1
            start_sample = int(seg.get("time_from", 0) * sample_rate)
            end_sample = int(seg.get("time_to", 0) * sample_rate)
            # Clamp: negative indices would otherwise count from the end.
            start_sample = min(max(start_sample, 0), total_samples)
            end_sample = min(max(end_sample, start_sample), total_samples)
            wav_path = os.path.join(seg_dir, f"seg_{seg_idx}.wav")
            sf.write(wav_path, audio_np[start_sample:end_sample], sample_rate)
    except Exception:
        # Do not leak a half-filled temp directory on failure.
        shutil.rmtree(seg_dir, ignore_errors=True)
        raise
    return seg_dir
| # --------------------------------------------------------------------------- | |
| # MFA timestamp helpers | |
| # --------------------------------------------------------------------------- | |
# Text used as ``matched_text`` for special (non-verse) segments, keyed by
# ``special_type``; consumed by _normalize_segments().
# NOTE(review): the values below look mis-encoded (UTF-8 Arabic decoded with
# a Thai code page). Restore them from the upstream source and confirm --
# they are left untouched here because the damage is not safely reversible.
_SPECIAL_TEXT = {
    "Basmala": "ุจูุณูู ู ูฑูููููู ูฑูุฑููุญูู ููฐูู ูฑูุฑููุญููู ",
    "Isti'adha": "ุฃูุนููุฐู ุจููฑูููููู ู ููู ุงูุดููููุทูุงูู ุงูุฑููุฌููู ",
    "Amin": "ุขู ููู",
    "Takbir": "ุงูููููู ุฃูููุจูุฑ",
    "Tahmeed": "ุณูู ูุนู ุงูููููู ููู ููู ุญูู ูุฏูู",
    "Tasleem": "ูฑูุณููููุงู ู ุนูููููููู ู ููุฑูุญูู ูุฉู ูฑููููู",
    "Sadaqa": "ุตูุฏููู ูฑูููููู ูฑููุนูุธููู ",
}
def _normalize_segments(segments):
    """Return copies of ``segments`` with the optional fields filled in.

    Callers may pass minimal dicts (timestamps + refs). For each entry a
    missing ``segment`` becomes its 1-based position, a missing
    ``confidence`` becomes 1.0, and a missing ``matched_text`` is looked up
    in ``_SPECIAL_TEXT`` by ``special_type`` (empty string if unknown).
    The input dicts are not modified.
    """
    filled = []
    for position, original in enumerate(segments, start=1):
        item = dict(original)
        item.setdefault("segment", position)
        item.setdefault("confidence", 1.0)
        if "matched_text" not in item:
            kind = item.get("special_type", "")
            item["matched_text"] = _SPECIAL_TEXT.get(kind, "")
        filled.append(item)
    return filled
| # --------------------------------------------------------------------------- | |
| # MFA timestamp endpoints | |
| # --------------------------------------------------------------------------- | |
def timestamps(audio_id, segments_json=None, granularity="words"):
    """Compute MFA word/letter timestamps using session audio.

    ``segments_json`` may be a list of segment dicts or a JSON string
    encoding one; when empty, the segments stored in the session are used.
    Returns the MFA result with ``audio_id`` added, or an error dict
    (disabled granularity, unknown session, malformed or missing segments,
    or a failure while writing segment audio or running MFA).
    """
    if granularity == "words+chars":
        return {"audio_id": audio_id, "error": "chars granularity is currently disabled via API", "segments": []}
    session = load_session(audio_id)
    if session is None:
        return _SESSION_ERROR
    # Parse segments: use provided or load stored
    try:
        if isinstance(segments_json, str):
            segments_json = json.loads(segments_json)
        if segments_json:
            segments = _normalize_segments(segments_json)
        else:
            segments = _load_segments(audio_id)
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and entries that cannot be
        # turned into dicts; TypeError covers non-iterable payloads.
        return {"audio_id": audio_id, "error": f"Invalid segments_json: {e}", "segments": []}
    if not segments:
        return {"audio_id": audio_id, "error": "No segments found in session", "segments": []}
    # Create segment WAVs from session audio
    try:
        seg_dir = _create_segment_wavs(session["audio"], 16000, segments)
    except Exception as e:
        return {"audio_id": audio_id, "error": f"Failed to create segment audio: {e}", "segments": []}
    # NOTE(review): seg_dir is a temp directory that is never removed here;
    # confirm whether compute_mfa_timestamps_api cleans it up.
    from src.mfa import compute_mfa_timestamps_api
    try:
        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
    except Exception as e:
        return {"audio_id": audio_id, "error": f"MFA alignment failed: {e}", "segments": []}
    result["audio_id"] = audio_id
    return result
def timestamps_direct(audio_data, segments_json, granularity="words"):
    """Compute MFA word/letter timestamps with provided audio and segments.

    ``audio_data`` is a file path or a Gradio ``(sample_rate, array)``
    tuple; ``segments_json`` is a list of segment dicts or a JSON string
    encoding one. Returns the MFA result, or an error dict (disabled
    granularity, malformed or empty segments, or a failure while
    preprocessing audio, writing segment audio or running MFA).
    """
    if granularity == "words+chars":
        return {"error": "chars granularity is currently disabled via API", "segments": []}
    # Parse segments
    try:
        if isinstance(segments_json, str):
            segments_json = json.loads(segments_json)
        if not segments_json:
            return {"error": "No segments provided", "segments": []}
        segments = _normalize_segments(segments_json)
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and entries that cannot be
        # turned into dicts; TypeError covers non-iterable payloads.
        return {"error": f"Invalid segments_json: {e}", "segments": []}
    # Preprocess audio
    try:
        audio_np, sr = _preprocess_api_audio(audio_data)
    except Exception as e:
        return {"error": f"Failed to preprocess audio: {e}", "segments": []}
    # Create segment WAVs
    try:
        seg_dir = _create_segment_wavs(audio_np, sr, segments)
    except Exception as e:
        return {"error": f"Failed to create segment audio: {e}", "segments": []}
    # NOTE(review): seg_dir is a temp directory that is never removed here;
    # confirm whether compute_mfa_timestamps_api cleans it up.
    from src.mfa import compute_mfa_timestamps_api
    try:
        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
    except Exception as e:
        return {"error": f"MFA alignment failed: {e}", "segments": []}
    return result