hetchyy committed
Commit 2ce56b1 · 1 Parent(s): fb6ec07

Add timestamps API endpoints
.gitignore CHANGED
@@ -50,5 +50,6 @@ models/
 captures/
 
 docs/api.md
+docs/lease_duration_history.md
 scripts/
 tests/
docs/client_api.md CHANGED
@@ -32,6 +32,15 @@ result = client.predict(
     "Base", "GPU",
     api_name="/realign_from_timestamps"
 )
+
+# Compute MFA word timestamps (uses stored session segments)
+mfa = client.predict(audio_id, None, "words", api_name="/mfa_timestamps_session")
+
+# Compute MFA word + letter timestamps
+mfa = client.predict(audio_id, None, "words+chars", api_name="/mfa_timestamps_session")
+
+# Direct MFA timestamps (no session needed)
+mfa = client.predict("recitation.mp3", result["segments"], "words", api_name="/mfa_timestamps_direct")
 ```
 
 ---
@@ -48,6 +57,7 @@ The first call returns an `audio_id` (32-character hex string). Pass it to subse
 | Raw VAD speech intervals | Disk (pickle) | No |
 | Cleaned segment boundaries | Disk (JSON) | Yes (resegment / realign) |
 | Model name | Disk (JSON) | Yes (retranscribe) |
+| Alignment segments | Disk (JSON) | Yes (any alignment call) |
 
 If `audio_id` is missing, expired, or invalid:
 ```json
@@ -173,6 +183,9 @@ All errors follow the same shape: `{"error": "...", "segments": []}`. Endpoints
 | Retranscribe with same model | `"Model and boundaries unchanged. Change model_name or call /resegment_session first."` | Yes |
 | Retranscription failed | `"Retranscription failed"` | Yes |
 | Realignment failed | `"Alignment failed"` | Yes |
+| No segments in session (MFA) | `"No segments found in session"` | Yes |
+| MFA alignment failed | `"MFA alignment failed: ..."` | Yes (session) / No (direct) |
+| No segments provided (MFA direct) | `"No segments provided"` | No |
 
 ---
 
@@ -237,3 +250,129 @@ Accepts arbitrary `(start, end)` timestamp pairs and runs ASR + alignment on eac
 **Response:** Same shape as `/process_audio_session`. Session boundaries are replaced with the provided timestamps.
 
 This endpoint subsumes split, merge, and boundary adjustment — the client computes the desired timestamps locally and sends them in one call.
+
+---
+
+### `POST /mfa_timestamps_session`
+
+Compute word-level (and optionally letter-level) MFA timestamps using session audio. Segments come from the stored session or can be overridden.
+
+| Parameter | Type | Default | Description |
+|---|---|---|---|
+| `audio_id` | str | required | Session ID from a previous alignment call |
+| `segments` | list? | `null` | Segment list to align. `null` uses stored segments from the session |
+| `granularity` | str | `"words"` | `"words"` (word timestamps only) or `"words+chars"` (word + letter timestamps) |
+
+**Example — using stored segments:**
+```python
+result = client.predict(
+    "a1b2c3d4e5f67890a1b2c3d4e5f67890",  # audio_id
+    None,                                # segments (null = use stored)
+    "words",                             # granularity
+    api_name="/mfa_timestamps_session",
+)
+```
+
+**Example — with segments override (minimal):**
+```python
+result = client.predict(
+    "a1b2c3d4e5f67890a1b2c3d4e5f67890",  # audio_id
+    [                                    # segments override
+        {"time_from": 0.48, "time_to": 2.88, "ref_from": "112:1:1", "ref_to": "112:1:4"},
+        {"time_from": 3.12, "time_to": 5.44, "ref_from": "112:2:1", "ref_to": "112:2:3"},
+    ],
+    "words+chars",                       # granularity
+    api_name="/mfa_timestamps_session",
+)
+```
+
+**Example — passing alignment results directly:**
+```python
+# Segments from /process_audio_session can be passed as-is
+proc = client.predict("recitation.mp3", 200, 1000, 100, "Base", "CPU", api_name="/process_audio_session")
+mfa = client.predict(proc["audio_id"], proc["segments"], "words+chars", api_name="/mfa_timestamps_session")
+```
+
+**Example — special segment (Basmala):**
+```python
+# Special segments use empty ref_from/ref_to and carry a special_type field
+{"time_from": 0.0, "time_to": 2.1, "ref_from": "", "ref_to": "", "special_type": "Basmala"}
+```
+
+**Segment input fields:**
+
+| Field | Type | Required | Description |
+|---|---|---|---|
+| `time_from` | float | yes | Start time in seconds (used to slice audio) |
+| `time_to` | float | yes | End time in seconds (used to slice audio) |
+| `ref_from` | str | yes | First word as `"surah:ayah:word"`. Empty for special segments |
+| `ref_to` | str | yes | Last word as `"surah:ayah:word"`. Empty for special segments |
+| `segment` | int | no | 1-indexed segment number. Auto-assigned from position if omitted |
+| `confidence` | float | no | Defaults to 1.0. Segments with confidence ≤ 0 are skipped |
+| `special_type` | str | no | Only for special segments (`"Basmala"`, `"Isti'adha"`, etc.) |
+| `matched_text` | str | no | Quran text. Used for fused Basmala/Isti'adha prefix detection |
+
+> **Tip:** You can pass the `segments` array from any alignment endpoint directly — all extra fields are preserved and echoed back in the response.
+
+**Response:**
+```json
+{
+  "audio_id": "a1b2c3d4e5f67890a1b2c3d4e5f67890",
+  "segments": [
+    {
+      "segment": 1,
+      "words": [
+        ["112:1:1", 0.0, 0.32],
+        ["112:1:2", 0.32, 0.58],
+        ["112:1:3", 0.58, 1.12],
+        ["112:1:4", 1.12, 1.68]
+      ]
+    }
+  ]
+}
+```
+
+With `granularity="words+chars"`, each word includes a 4th element — letter timestamps:
+```json
+["112:1:1", 0.0, 0.32, [["ق", 0.0, 0.15], ["ل", 0.15, 0.32]]]
+```
+
+**Word array:** `[location, start, end]` or `[location, start, end, letters]`
+
+| Index | Type | Description |
+|---|---|---|
+| 0 | str | Word position as `"surah:ayah:word"` |
+| 1 | float | Start time relative to segment (seconds) |
+| 2 | float | End time relative to segment (seconds) |
+| 3 | list? | Only present when `granularity="words+chars"`. Array of `[char, start, end]` tuples |
+
+> **Note:** All timestamps are **relative to the segment** (not to the full recording). Add `time_from` to convert to absolute times.
+
+---
+
+### `POST /mfa_timestamps_direct`
+
+Compute MFA timestamps with a provided audio file and segments. No session required — standalone endpoint.
+
+| Parameter | Type | Default | Description |
+|---|---|---|---|
+| `audio` | file | required | Audio file (any format) |
+| `segments` | list | required | Segment list with `time_from`/`time_to` boundaries |
+| `granularity` | str | `"words"` | `"words"` or `"words+chars"` |
+
+**Response:** Same shape as `/mfa_timestamps_session` but without `audio_id`.
+
+**Example (minimal):**
+```python
+result = client.predict(
+    "recitation.mp3",
+    [
+        {"time_from": 0.48, "time_to": 2.88, "ref_from": "112:1:1", "ref_to": "112:1:4"},
+        {"time_from": 3.12, "time_to": 5.44, "ref_from": "112:2:1", "ref_to": "112:2:3"},
+    ],
+    "words+chars",
+    api_name="/mfa_timestamps_direct",
+)
+```
+
+Segment input format is the same as for `/mfa_timestamps_session` — see [segment input fields](#segment-input-fields) above.
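The note above says MFA timestamps are segment-relative. As a quick illustration, a small client-side helper can shift them to absolute recording times — a hypothetical sketch, not part of the API: `to_absolute` and its argument names are invented here, and it assumes the response segments come back in the same order as the submitted list.

```python
def to_absolute(sent_segments, mfa_result):
    """Shift segment-relative MFA times to absolute recording times.

    sent_segments: the segment list you submitted (carries time_from offsets).
    mfa_result: the endpoint response, assumed to be in the same segment order.
    """
    shifted = []
    for seg_in, seg_out in zip(sent_segments, mfa_result["segments"]):
        offset = seg_in["time_from"]
        words = []
        for loc, start, end, *rest in seg_out["words"]:
            entry = [loc, start + offset, end + offset]
            if rest:  # letter tuples from "words+chars" are shifted too
                entry.append([[c, s + offset, e + offset] for c, s, e in rest[0]])
            words.append(entry)
        shifted.append({"segment": seg_out["segment"], "words": words})
    return shifted
```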
src/api/session_api.py CHANGED
@@ -149,6 +149,30 @@ def update_session(audio_id, *, intervals=None, model_name=None):
     os.replace(tmp, meta_path)
 
 
+def _save_segments(audio_id, segments):
+    """Persist alignment segments for later MFA use."""
+    path = _session_dir(audio_id)
+    if not path.exists():
+        return
+    seg_path = path / "segments.json"
+    tmp = path / "segments.tmp"
+    with open(tmp, "w") as f:
+        json.dump(segments, f)
+    os.replace(tmp, seg_path)
+
+
+def _load_segments(audio_id):
+    """Load stored segments. Returns list or None."""
+    if not _validate_id(audio_id):
+        return None
+    path = _session_dir(audio_id)
+    seg_path = path / "segments.json"
+    if not seg_path.exists():
+        return None
+    with open(seg_path) as f:
+        return json.load(f)
+
+
 # ---------------------------------------------------------------------------
 # Response formatting
 # ---------------------------------------------------------------------------
@@ -174,6 +198,7 @@ def _format_response(audio_id, json_output, warning=None):
         if seg.get("special_type"):
             entry["special_type"] = seg["special_type"]
         segments.append(entry)
+    _save_segments(audio_id, segments)
     resp = {"audio_id": audio_id, "segments": segments}
     if warning:
         resp["warning"] = warning
@@ -338,3 +363,146 @@ def realign_from_timestamps(audio_id, timestamps, model_name="Base", device="GPU
     new_intervals = result[6]
     update_session(audio_id, intervals=new_intervals, model_name=model_name)
     return _format_response(audio_id, json_output, warning=quota_warning)
+
+
+# ---------------------------------------------------------------------------
+# MFA timestamp helpers
+# ---------------------------------------------------------------------------
+
+def _preprocess_api_audio(audio_data):
+    """Convert audio input to 16kHz mono float32 numpy array.
+
+    Handles file path (str) and Gradio numpy tuple (sample_rate, array).
+    Returns (audio_np, sample_rate).
+    """
+    import librosa
+    from config import RESAMPLE_TYPE
+
+    if isinstance(audio_data, str):
+        audio, sr = librosa.load(audio_data, sr=16000, mono=True, res_type=RESAMPLE_TYPE)
+        return audio, 16000
+
+    sample_rate, audio = audio_data
+    if audio.dtype == np.int16:
+        audio = audio.astype(np.float32) / 32768.0
+    elif audio.dtype == np.int32:
+        audio = audio.astype(np.float32) / 2147483648.0
+    if len(audio.shape) > 1:
+        audio = audio.mean(axis=1)
+    if sample_rate != 16000:
+        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000, res_type=RESAMPLE_TYPE)
+        sample_rate = 16000
+    return audio, sample_rate
+
+
+def _create_segment_wavs(audio_np, sample_rate, segments):
+    """Slice audio by segment boundaries and write WAV files.
+
+    Returns the temp directory path containing seg_0.wav, seg_1.wav, etc.
+    """
+    import tempfile
+    import soundfile as sf
+
+    seg_dir = tempfile.mkdtemp(prefix="mfa_api_")
+    for seg in segments:
+        seg_idx = seg.get("segment", 0) - 1
+        time_from = seg.get("time_from", 0)
+        time_to = seg.get("time_to", 0)
+        start_sample = int(time_from * sample_rate)
+        end_sample = int(time_to * sample_rate)
+        segment_audio = audio_np[start_sample:end_sample]
+        wav_path = os.path.join(seg_dir, f"seg_{seg_idx}.wav")
+        sf.write(wav_path, segment_audio, sample_rate)
+    return seg_dir
+
+
+def _normalize_segments(segments):
+    """Fill defaults so callers can pass minimal segment dicts (timestamps + refs).
+
+    Auto-assigns ``segment`` numbers and defaults ``confidence`` to 1.0 so
+    segments are not accidentally skipped by ``_build_mfa_refs``.
+    """
+    normalized = []
+    for i, seg in enumerate(segments):
+        entry = dict(seg)
+        if "segment" not in entry:
+            entry["segment"] = i + 1
+        if "confidence" not in entry:
+            entry["confidence"] = 1.0
+        if "matched_text" not in entry:
+            entry["matched_text"] = ""
+        normalized.append(entry)
+    return normalized
+
+
+# ---------------------------------------------------------------------------
+# MFA timestamp endpoints
+# ---------------------------------------------------------------------------
+
+def mfa_timestamps_session(audio_id, segments_json=None, granularity="words"):
+    """Compute MFA word/letter timestamps using session audio."""
+    session = load_session(audio_id)
+    if session is None:
+        return _SESSION_ERROR
+
+    # Parse segments: use provided or load stored
+    if isinstance(segments_json, str):
+        segments_json = json.loads(segments_json)
+
+    if segments_json:
+        segments = _normalize_segments(segments_json)
+    else:
+        segments = _load_segments(audio_id)
+        if not segments:
+            return {"audio_id": audio_id, "error": "No segments found in session", "segments": []}
+
+    # Create segment WAVs from session audio
+    try:
+        seg_dir = _create_segment_wavs(session["audio"], 16000, segments)
+    except Exception as e:
+        return {"audio_id": audio_id, "error": f"Failed to create segment audio: {e}", "segments": []}
+
+    from src.mfa import compute_mfa_timestamps_api
+    try:
+        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
+    except Exception as e:
+        return {"audio_id": audio_id, "error": f"MFA alignment failed: {e}", "segments": []}
+
+    result["audio_id"] = audio_id
+    return result
+
+
+def mfa_timestamps_direct(audio_data, segments_json, granularity="words"):
+    """Compute MFA word/letter timestamps with provided audio and segments."""
+    # Parse segments
+    if isinstance(segments_json, str):
+        segments_json = json.loads(segments_json)
+
+    if not segments_json:
+        return {"error": "No segments provided", "segments": []}
+
+    segments = _normalize_segments(segments_json)
+
+    # Preprocess audio
+    try:
+        audio_np, sr = _preprocess_api_audio(audio_data)
+    except Exception as e:
+        return {"error": f"Failed to preprocess audio: {e}", "segments": []}
+
+    # Create segment WAVs
+    try:
+        seg_dir = _create_segment_wavs(audio_np, sr, segments)
+    except Exception as e:
+        return {"error": f"Failed to create segment audio: {e}", "segments": []}
+
+    from src.mfa import compute_mfa_timestamps_api
+    try:
+        result = compute_mfa_timestamps_api(segments, seg_dir, granularity or "words")
+    except Exception as e:
+        return {"error": f"MFA alignment failed: {e}", "segments": []}
+
+    return result
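Every endpoint in this file reports failures in-band as `{"error": "...", "segments": []}` rather than raising. A thin client-side wrapper can promote those payloads to exceptions — a hypothetical sketch: `predict_or_raise` is not part of this codebase, and `client` stands for any object with a `gradio_client`-style `predict` method.

```python
def predict_or_raise(client, *args, api_name):
    """Call an endpoint and raise if it returned an in-band error payload."""
    result = client.predict(*args, api_name=api_name)
    if isinstance(result, dict) and result.get("error"):
        raise RuntimeError(f"{api_name}: {result['error']}")
    return result
```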
src/mfa.py CHANGED
@@ -5,6 +5,10 @@ from config import MFA_SPACE_URL, MFA_TIMEOUT, MFA_PROGRESS_SEGMENT_RATE
 # Lowercase special ref names for case-insensitive matching
 _SPECIAL_REFS = {"basmala", "isti'adha", "isti'adha+basmala"}
 
 
 def _mfa_upload_and_submit(refs, audio_paths):
     """Upload audio files and submit alignment batch to the MFA Space.
@@ -95,6 +99,395 @@ def _mfa_wait_result(event_id, headers, base):
     return parsed["results"]
 
 
 def _ts_progress_bar_html(total_segments, rate, animated=True):
     """Return HTML for a progress bar showing Segment x/N.
 
@@ -149,6 +542,10 @@ def _ts_progress_bar_html(total_segments, rate, animated=True):
     </div>'''
 
 
 def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_row=None):
     """Compute word-level timestamps via MFA forced alignment and inject into HTML.
 
@@ -169,61 +566,11 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
         yield current_html, gr.update(), gr.update(), gr.update(), gr.update()
         return
 
-    # Build refs and audio paths from structured JSON output
     segments = json_output.get("segments", []) if json_output else []
     print(f"[MFA_TS] {len(segments)} segments in JSON")
-    refs = []
-    audio_paths = []
-    seg_to_result_idx = {}  # Maps segment index (0-based) → result index
-
-    _BASMALA_TEXT = "بِسْمِ ٱللَّهِ ٱلرَّحْمَٰنِ ٱلرَّحِيم"
-    _ISTIATHA_TEXT = "أَعُوذُ بِٱللَّهِ مِنَ الشَّيْطَانِ الرَّجِيم"
-    _COMBINED_PREFIX = _ISTIATHA_TEXT + " ۝ " + _BASMALA_TEXT
-
-    for seg in segments:
-        ref_from = seg.get("ref_from", "")
-        ref_to = seg.get("ref_to", "")
-        seg_idx = seg.get("segment", 0) - 1  # 0-indexed
-        confidence = seg.get("confidence", 0)
-
-        # For special segments (Basmala/Isti'adha), ref_from is empty but
-        # special_type carries the ref name needed for MFA
-        if not ref_from:
-            ref_from = seg.get("special_type", "")
-            ref_to = ref_from  # Special segments use same ref for both
-        if not ref_from or confidence <= 0:
-            continue
-
-        # Build MFA ref
-        if ref_from == ref_to:
-            mfa_ref = ref_from
-        else:
-            mfa_ref = f"{ref_from}-{ref_to}"
-
-        # Detect fused special prefix and build compound ref
-        # (skip when the ref itself is already a special like "Basmala")
-        _is_special_ref = ref_from.strip().lower() in _SPECIAL_REFS
-        if not _is_special_ref:
-            matched_text = seg.get("matched_text", "")
-            if matched_text.startswith(_COMBINED_PREFIX):
-                mfa_ref = f"Isti'adha+Basmala+{mfa_ref}"
-            elif matched_text.startswith(_ISTIATHA_TEXT):
-                mfa_ref = f"Isti'adha+{mfa_ref}"
-            elif matched_text.startswith(_BASMALA_TEXT):
-                mfa_ref = f"Basmala+{mfa_ref}"
-
-        # Check audio file exists
-        audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None
-        if not audio_path or not os.path.exists(audio_path):
-            print(f"[MFA_TS] Skipping seg {seg_idx}: audio not found at {audio_path}")
-            continue
-
-        # Track mapping from segment index to result index
-        seg_to_result_idx[seg_idx] = len(refs)
-        refs.append(mfa_ref)
-        audio_paths.append(audio_path)
-
-    print(f"[MFA_TS] {len(refs)} refs to align: {refs[:5]}{'...' if len(refs) > 5 else ''}")
 
     if not refs:
         print("[MFA_TS] Early return: no valid refs/audio pairs")
@@ -282,217 +629,29 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
         )
         raise
 
-    # Build lookup: "result_idx:location" → (start, end) from all successful results
-    # Using result_idx prefix ensures each segment has its own timestamps even for shared words
-    word_timestamps = {}  # "result_idx:location" → (start, end)
-    letter_timestamps = {}  # "result_idx:location" → list of letter dicts with group_id
-    word_to_all_results = {}  # word_pos → [result_idx, ...] (all occurrences)
-
-    def _assign_letter_groups(letters, word_location):
-        """Assign group_id to letters sharing identical (start, end) timestamps."""
-        if not letters:
-            return []
-        result = []
-        group_id = 0
-        prev_ts = None
-        for letter in letters:
-            ts = (letter.get("start"), letter.get("end"))
-            if ts != prev_ts:
-                group_id += 1
-                prev_ts = ts
-            result.append({
-                "char": letter.get("char", ""),
-                "start": letter.get("start"),
-                "end": letter.get("end"),
-                "group_id": f"{word_location}:{group_id}",  # Unique across words
-            })
-        return result
-
-    for result_idx, result in enumerate(results):
-        if result.get("status") != "ok":
-            print(f"[MFA_TS] Segment failed: ref={result.get('ref')} error={result.get('error')}")
-            continue
-        ref = result.get("ref", "")
-        is_special = ref.strip().lower() in _SPECIAL_REFS
-        is_fused = "+" in ref
-        for word in result.get("words", []):
-            loc = word.get("location", "")
-            if loc:
-                if is_special:
-                    base_key = f"{ref}:{loc}"
-                elif is_fused and loc.startswith("0:0:"):
-                    base_key = f"{ref}:{loc}"
-                else:
-                    base_key = loc
-                key = f"{result_idx}:{base_key}"  # Prefix with result index
-                word_timestamps[key] = (word["start"], word["end"])
-                # Extract letter timestamps if available
-                letters = word.get("letters")
-                if letters:
-                    letter_timestamps[key] = _assign_letter_groups(letters, loc)
-                # Track word→result_idx mapping for lookup (regular words only)
-                if not is_special and not (is_fused and loc.startswith("0:0:")):
-                    if loc not in word_to_all_results:
-                        word_to_all_results[loc] = []
-                    word_to_all_results[loc].append(result_idx)
-
-    print(f"[MFA_TS] {len(word_timestamps)} word timestamps collected, {len(letter_timestamps)} with letter-level data")
-
-    # Build cross-word overlap groups for simultaneous highlighting
-    def _build_crossword_groups(results_list, letter_ts_dict):
-        """
-        Build mapping of (key, letter_idx) -> cross-word group_id.
-        Only checks word boundaries: last letter(s) of word N vs first letter(s) of word N+1.
-        """
-        crossword_groups = {}  # (key, idx) -> group_id
-
-        for result_idx, result in enumerate(results_list):
-            if result.get("status") != "ok":
-                continue
-            ref = result.get("ref", "")
-            is_special = ref.strip().lower() in _SPECIAL_REFS
-            is_fused = "+" in ref
-            words = result.get("words", [])
-
-            # Iterate through consecutive word pairs
-            for word_i in range(len(words) - 1):
-                word_a = words[word_i]
-                word_b = words[word_i + 1]
-
-                loc_a = word_a.get("location", "")
-                loc_b = word_b.get("location", "")
-                if not loc_a or not loc_b:
-                    continue
-
-                # Build keys for letter_timestamps lookup
-                def make_key(loc):
-                    if is_special:
-                        base_key = f"{ref}:{loc}"
-                    elif is_fused and loc.startswith("0:0:"):
-                        base_key = f"{ref}:{loc}"
-                    else:
-                        base_key = loc
-                    return f"{result_idx}:{base_key}"
-
-                key_a = make_key(loc_a)
-                key_b = make_key(loc_b)
-                letters_a = letter_ts_dict.get(key_a, [])
-                letters_b = letter_ts_dict.get(key_b, [])
-
-                if not letters_a or not letters_b:
-                    continue
-
-                # Compare last letter(s) of word A with first letter(s) of word B
-                # Check last few letters of A against first few letters of B
-                for idx_a in range(len(letters_a) - 1, max(len(letters_a) - 3, -1), -1):
-                    letter_a = letters_a[idx_a]
-                    if letter_a.get("start") is None or letter_a.get("end") is None:
-                        continue
-                    for idx_b in range(min(3, len(letters_b))):
-                        letter_b = letters_b[idx_b]
-                        if letter_b.get("start") is None or letter_b.get("end") is None:
-                            continue
-                        # Check for exact timestamp match (MFA marks simultaneous letters identically)
-                        if letter_a["start"] == letter_b["start"] and letter_a["end"] == letter_b["end"]:
-                            group_id = f"xword-{result_idx}-{word_i}"
-                            crossword_groups[(key_a, idx_a)] = group_id
-                            crossword_groups[(key_b, idx_b)] = group_id
-
-        if crossword_groups:
-            print(f"[MFA_TS] Found {len(crossword_groups)} cross-word overlapping letters")
-
-        return crossword_groups
 
     crossword_groups = _build_crossword_groups(results, letter_timestamps)
 
-    # Post-process: extend each word's end to the start of the next word
-    # so words don't disappear between timestamps during animation.
-    import wave
-    for seg in segments:
-        ref_from = seg.get("ref_from", "")
-        ref_to = seg.get("ref_to", "")
-        seg_idx = seg.get("segment", 0) - 1
-        confidence = seg.get("confidence", 0)
-        if not ref_from:
-            ref_from = seg.get("special_type", "")
-            ref_to = ref_from  # Special segments use same ref for both
-        if not ref_from or confidence <= 0:
-            continue
-        # Get result_idx for this segment (may not exist if segment was skipped)
-        result_idx = seg_to_result_idx.get(seg_idx)
-        if result_idx is None:
-            continue
-        # Find the matching MFA result and collect word locations in order
-        ref_key = f"{ref_from}-{ref_to}" if ref_from != ref_to else ref_from
-        is_special = ref_from.strip().lower() in _SPECIAL_REFS
-        # Reconstruct compound ref for fused segments
-        # (skip when the ref itself is already a special like "Basmala")
-        if not is_special:
-            matched_text = seg.get("matched_text", "")
-            if matched_text.startswith(_COMBINED_PREFIX):
-                ref_key = f"Isti'adha+Basmala+{ref_key}"
-            elif matched_text.startswith(_ISTIATHA_TEXT):
-                ref_key = f"Isti'adha+{ref_key}"
-            elif matched_text.startswith(_BASMALA_TEXT):
-                ref_key = f"Basmala+{ref_key}"
-        is_fused = "+" in ref_key
-        seg_word_locs = []
-        for result in results:
-            if result.get("ref") == ref_key and result.get("status") == "ok":
-                for w in result.get("words", []):
-                    loc = w.get("location", "")
-                    if loc:
-                        if is_special:
-                            base_key = f"{ref_key}:{loc}"
-                        elif is_fused and loc.startswith("0:0:"):
-                            base_key = f"{ref_key}:{loc}"
-                        else:
-                            base_key = loc
-                        key = f"{result_idx}:{base_key}"  # Use result_idx prefix
-                        if key in word_timestamps:
-                            seg_word_locs.append(key)
-                break
-        if not seg_word_locs:
-            continue
-        # Extend each word's end to the next word's start
-        for i in range(len(seg_word_locs) - 1):
-            cur_start, cur_end = word_timestamps[seg_word_locs[i]]
-            nxt_start, _ = word_timestamps[seg_word_locs[i + 1]]
-            if nxt_start > cur_end:
-                word_timestamps[seg_word_locs[i]] = (cur_start, nxt_start)
-        # Extend first word back to time 0 so highlight starts immediately
-        first_loc = seg_word_locs[0]
-        first_start, first_end = word_timestamps[first_loc]
-        if first_start > 0:
-            word_timestamps[first_loc] = (0, first_end)
-        # Extend last word to segment audio duration
-        last_loc = seg_word_locs[-1]
-        last_start, last_end = word_timestamps[last_loc]
-        audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None
-        if audio_path and os.path.exists(audio_path):
-            with wave.open(audio_path, 'rb') as wf:
-                seg_duration = wf.getnframes() / wf.getframerate()
-            if seg_duration > last_end:
-                word_timestamps[last_loc] = (last_start, seg_duration)
 
-    print(f"[MFA_TS] Post-processed timestamps: extended word ends to fill gaps")
 
     # Inject timestamps into word spans, using segment boundaries to determine result_idx
-    # Step 1: Find all segment boundaries (position → seg_idx)
-    seg_boundaries = []  # [(position, seg_idx), ...]
     for m in re.finditer(r'data-segment-idx="(\d+)"', current_html):
         seg_boundaries.append((m.start(), int(m.group(1))))
     seg_boundaries.sort(key=lambda x: x[0])
 
-    # Build segment offset lookup: seg_idx → time_from (for absolute timestamp conversion)
-    seg_offset_map = {}  # seg_idx (0-based) → time_from
     for seg in segments:
-        idx = seg.get("segment", 0) - 1  # Convert to 0-based
         seg_offset_map[idx] = seg.get("time_from", 0)
 
-    # Step 2: For each word span, find which segment it belongs to
     def _get_seg_idx_at_pos(pos):
-        """Find the segment index for a position in the HTML."""
         seg_idx = None
         for boundary_pos, idx in seg_boundaries:
             if boundary_pos > pos:
@@ -508,13 +667,10 @@ compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
         if not pos_m:
             return orig
         pos = pos_m.group(1)
-        # Find which segment this word belongs to
         seg_idx = _get_seg_idx_at_pos(m.start())
         if seg_idx is None:
            return orig
-        # Get expected result_idx for this segment
        expected_result_idx = seg_to_result_idx.get(seg_idx)
-        # For regular words, use word-based mapping to find correct result_idx
        result_idx = None
        if pos and not pos.startswith("0:0:"):
            candidates = word_to_all_results.get(pos, [])
@@ -529,16 +685,13 @@ compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
                result_idx = expected_result_idx
        if result_idx is None:
            return orig
-        # Use result_idx prefix to get segment-specific timestamp
        key = f"{result_idx}:{pos}"
        ts = word_timestamps.get(key)
        if not ts:
            return orig
-        # Convert relative timestamps to absolute by adding segment offset
        seg_offset = seg_offset_map.get(seg_idx, 0)
        abs_start = ts[0] + seg_offset
        abs_end = ts[1] + seg_offset
-        # Include result_idx so char-level injection can find letter timestamps
        return orig[:-1] + f' data-result-idx="{result_idx}" data-start="{abs_start:.4f}" data-end="{abs_end:.4f}">'
 
    html = re.sub(word_open_re, _inject_word_ts, current_html)
@@ -551,19 +704,16 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro

  def _stamp_chars_with_mfa(word_m):
  word_open = word_m.group(1)
- word_abs_start = float(word_m.group(2)) # data-start (already correctly injected)
  inner = word_m.group(4)

- # Extract data-pos from word tag
  pos_m = re.search(r'data-pos="([^"]+)"', word_open)
  word_pos = pos_m.group(1) if pos_m else None

- # Find result_idx from word tag's data-result-idx if available, else use mapping
  result_idx_m = re.search(r'data-result-idx="(\d+)"', word_open)
  if result_idx_m:
  result_idx = int(result_idx_m.group(1))
  else:
- # Fallback: use word-based mapping to find correct result_idx
  result_idx = None
  if word_pos and not word_pos.startswith("0:0:"):
  candidates = word_to_all_results.get(word_pos, [])
@@ -571,51 +721,42 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
  if len(candidates) == 1:
  result_idx = candidates[0]
  else:
- # Without position info, just take the first candidate
  result_idx = candidates[0]

  key = f"{result_idx}:{word_pos}" if result_idx is not None and word_pos else None

- # Look up word's relative start from MFA to calculate offset
  word_ts = word_timestamps.get(key) if key else None
  mfa_letters = letter_timestamps.get(key) if key else None
  if not mfa_letters or not word_ts:
  return word_m.group(0)

- word_rel_start = word_ts[0] # Word's relative start from MFA

  char_matches = list(re.finditer(r'<span class="char">([^<]*)</span>', inner))
  if not char_matches:
  return word_m.group(0)

- # Match MFA letters to HTML chars (no NFC — base-char comparison instead)
  mfa_chars = [l["char"] for l in mfa_letters]
  html_chars = [m.group(1).replace('\u0640', '') for m in char_matches]

- # Allowed character mappings (MFA char → HTML char)
- # ى (alef maksura) ↔ ي (ya) are visually similar and interchangeable
  CHAR_EQUIVALENTS = {
- 'ى': 'ي', # alef maksura → ya
- 'ي': 'ى', # ya → alef maksura
  }

  def _first_base(s):
- """First non-combining character after NFD decomposition."""
  for c in unicodedata.normalize("NFD", s):
  if not unicodedata.category(c).startswith('M'):
  return c
  return s[0] if s else ''

  def chars_match(mfa_c, html_c, log_substitution=False):
- """Check if MFA char matches HTML char, including allowed equivalents."""
  if mfa_c == html_c or html_c in mfa_c or mfa_c in html_c:
  return True
- # Check allowed equivalents
  if CHAR_EQUIVALENTS.get(mfa_c) == html_c:
  if log_substitution:
  print(f"[MFA_TS] Char substitution: MFA '{mfa_c}' → HTML '{html_c}' (key={key})")
  return True
- # Base-char comparison (handles decomposed↔precomposed without NFC)
  mb, hb = _first_base(mfa_c), _first_base(html_c)
  if mb and hb and (mb == hb or CHAR_EQUIVALENTS.get(mb) == hb):
  if log_substitution:
@@ -634,26 +775,19 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
  mfa_char = mfa_chars[mfa_idx]
  if chars_match(mfa_char, html_char, log_substitution=True):
  letter = mfa_letters[mfa_idx]
- # Skip letters without valid timestamps
  if letter["start"] is None or letter["end"] is None:
  print(f"[MFA_TS] Skipping letter with missing timestamp: char='{letter.get('char')}' key={key} mfa_idx={mfa_idx}")
  if chars_match(mfa_char, html_char) or len(html_char) >= len(mfa_char):
  mfa_idx += 1
  continue
- # Convert letter timestamps using word anchor
- # word_abs_start is already correct from word-level injection
- # letter times are relative to segment, so offset by (letter_start - word_rel_start)
  abs_start = word_abs_start + (letter["start"] - word_rel_start)
  abs_end = word_abs_start + (letter["end"] - word_rel_start)
- # Determine group_id: prefer cross-word group if exists, else use MFA's
  crossword_gid = crossword_groups.get((key, mfa_idx), "")
  final_group_id = crossword_gid or letter.get("group_id", "")
  char_replacements.append((
  cm.start(), cm.end(),
  f'<span class="char" data-start="{abs_start:.4f}" data-end="{abs_end:.4f}" data-group-id="{final_group_id}">{cm.group(1)}</span>'
  ))
- # Lookahead: stamp combining continuations with same MFA timestamp
- # (handles precomposed MFA char like ئ split into [يْـ, ٔ] in HTML)
  mfa_nfd = unicodedata.normalize("NFD", letter["char"])
  peek = html_idx + 1
  while peek < len(char_matches):
@@ -671,7 +805,6 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
  if chars_match(mfa_char, html_char) or len(html_char) >= len(mfa_char):
  mfa_idx += 1

- # Apply replacements in reverse order
  stamped_inner = inner
  for start, end, replacement in reversed(char_replacements):
  stamped_inner = stamped_inner[:start] + replacement + stamped_inner[end:]
@@ -703,7 +836,6 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
  for w in result.get("words", []) if w.get("start") is not None and w.get("end") is not None
  ],
  })
- # Collect char-level timestamps
  _char_ts_log.append({
  "ref": result.get("ref", ""),
  "words": [
@@ -726,81 +858,11 @@ def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_ro
  except Exception as e:
  print(f"[USAGE_LOG] Failed to log word timestamps: {e}")

- # Build enriched JSON with word/letter timestamps (relative to segment)
- from src.core.quran_index import get_quran_index
- index = get_quran_index()
-
- def _get_word_text(location: str) -> str:
- """Look up word text from Quran index by location (surah:ayah:word)."""
- if not location or location.startswith("0:0:"):
- return "" # Special segments (Basmala/Isti'adha) use 0:0:N
- try:
- parts = location.split(":")
- if len(parts) >= 3:
- key = (int(parts[0]), int(parts[1]), int(parts[2]))
- idx = index.word_lookup.get(key)
- if idx is not None:
- return index.words[idx].display_text
- except (ValueError, IndexError):
- pass
- return ""
-
- enriched_segments = []
- for seg in segments:
- seg_idx = seg.get("segment", 0) - 1
- result_idx = seg_to_result_idx.get(seg_idx)
-
- segment_data = dict(seg) # Copy original segment data
-
- if result_idx is not None:
- # For special segments (Basmala/Isti'adha), get words from matched_text
- _ref = seg.get("ref_from", "") or seg.get("special_type", "")
- is_special = _ref.lower() in _SPECIAL_REFS
- special_words = seg.get("matched_text", "").replace(" \u06dd ", " ").split() if is_special else []
-
- # Find matching MFA result for this segment
- for i, result in enumerate(results):
- if i != result_idx or result.get("status") != "ok":
- continue
- words_with_ts = []
- for word_idx, word in enumerate(result.get("words", [])):
- if word.get("start") is None or word.get("end") is None:
- continue
-
- location = word.get("location", "")
-
- # Get word text: from matched_text for special, from index for regular
- if is_special or location.startswith("0:0:"):
- word_text = special_words[word_idx] if word_idx < len(special_words) else ""
- else:
- word_text = _get_word_text(location)
-
- word_data = {
- "word": word_text,
- "location": location,
- "start": round(word["start"], 4), # Relative to segment
- "end": round(word["end"], 4),
- }
- # Add letter timestamps if available
- if word.get("letters"):
- word_data["letters"] = [
- {
- "char": lt.get("char", ""),
- "start": round(lt["start"], 4),
- "end": round(lt["end"], 4),
- }
- for lt in word.get("letters", [])
- if lt.get("start") is not None
- ]
- words_with_ts.append(word_data)
-
- if words_with_ts:
- segment_data["words"] = words_with_ts
- break
-
- enriched_segments.append(segment_data)
-
- enriched_json = {"segments": enriched_segments}

  # Final yield: updated HTML, hide progress bar, show Animate All, enriched JSON
  animate_all_btn_html = '<button class="animate-all-btn">Animate All</button>'
 
  # Lowercase special ref names for case-insensitive matching
  _SPECIAL_REFS = {"basmala", "isti'adha", "isti'adha+basmala"}

+ _BASMALA_TEXT = "بِسْمِ ٱللَّهِ ٱلرَّحْمَٰنِ ٱلرَّحِيم"
+ _ISTIATHA_TEXT = "أَعُوذُ بِٱللَّهِ مِنَ الشَّيْطَانِ الرَّجِيم"
+ _COMBINED_PREFIX = _ISTIATHA_TEXT + " ۝ " + _BASMALA_TEXT
+

  def _mfa_upload_and_submit(refs, audio_paths):
  """Upload audio files and submit alignment batch to the MFA Space.
 
  return parsed["results"]


+ # ---------------------------------------------------------------------------
+ # Reusable helpers (shared by UI generator and API function)
+ # ---------------------------------------------------------------------------
+
+ def _make_ts_key(result_idx, ref, loc):
+ """Build the composite key used in word/letter timestamp dicts."""
+ is_special = ref.strip().lower() in _SPECIAL_REFS
+ is_fused = "+" in ref
+ if is_special:
+ base_key = f"{ref}:{loc}"
+ elif is_fused and loc.startswith("0:0:"):
+ base_key = f"{ref}:{loc}"
+ else:
+ base_key = loc
+ return f"{result_idx}:{base_key}"
+
+
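The keying scheme above can be exercised standalone. This is an illustrative sketch mirroring `_make_ts_key`: regular words are keyed by `result_idx:location`, while special segments (Basmala/Isti'adha) and fused-prefix `0:0:N` words fold the ref into the key so identical `0:0:N` locations from different segments cannot collide. The example refs and locations below are made-up values, not data from this repository.

```python
# Standalone copy of the _make_ts_key logic from the diff above.
_SPECIAL_REFS = {"basmala", "isti'adha", "isti'adha+basmala"}

def make_ts_key(result_idx, ref, loc):
    is_special = ref.strip().lower() in _SPECIAL_REFS
    is_fused = "+" in ref
    # Special segments and fused 0:0:N words need the ref in the key.
    if is_special or (is_fused and loc.startswith("0:0:")):
        base_key = f"{ref}:{loc}"
    else:
        base_key = loc
    return f"{result_idx}:{base_key}"

print(make_ts_key(0, "2:1-2:5", "2:3:1"))          # → 0:2:3:1
print(make_ts_key(1, "Basmala", "0:0:2"))          # → 1:Basmala:0:0:2
print(make_ts_key(2, "Basmala+2:1-2:5", "0:0:1"))  # → 2:Basmala+2:1-2:5:0:0:1
```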
+ def _build_mfa_ref(seg):
+ """Build the MFA ref string for a single segment. Returns None to skip."""
+ ref_from = seg.get("ref_from", "")
+ ref_to = seg.get("ref_to", "")
+ confidence = seg.get("confidence", 0)
+
+ if not ref_from:
+ ref_from = seg.get("special_type", "")
+ ref_to = ref_from
+ if not ref_from or confidence <= 0:
+ return None
+
+ if ref_from == ref_to:
+ mfa_ref = ref_from
+ else:
+ mfa_ref = f"{ref_from}-{ref_to}"
+
+ _is_special_ref = ref_from.strip().lower() in _SPECIAL_REFS
+ if not _is_special_ref:
+ matched_text = seg.get("matched_text", "")
+ if matched_text.startswith(_COMBINED_PREFIX):
+ mfa_ref = f"Isti'adha+Basmala+{mfa_ref}"
+ elif matched_text.startswith(_ISTIATHA_TEXT):
+ mfa_ref = f"Isti'adha+{mfa_ref}"
+ elif matched_text.startswith(_BASMALA_TEXT):
+ mfa_ref = f"Basmala+{mfa_ref}"
+
+ return mfa_ref
+
+
+ def _build_mfa_refs(segments, segment_dir):
+ """Build MFA refs and audio paths from segments.
+
+ Returns (refs, audio_paths, seg_to_result_idx).
+ """
+ refs = []
+ audio_paths = []
+ seg_to_result_idx = {}
+
+ for seg in segments:
+ seg_idx = seg.get("segment", 0) - 1
+ mfa_ref = _build_mfa_ref(seg)
+ if mfa_ref is None:
+ continue
+
+ audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None
+ if not audio_path or not os.path.exists(audio_path):
+ print(f"[MFA_TS] Skipping seg {seg_idx}: audio not found at {audio_path}")
+ continue
+
+ seg_to_result_idx[seg_idx] = len(refs)
+ refs.append(mfa_ref)
+ audio_paths.append(audio_path)
+
+ print(f"[MFA_TS] {len(refs)} refs to align: {refs[:5]}{'...' if len(refs) > 5 else ''}")
+ return refs, audio_paths, seg_to_result_idx
+
+
+ def _assign_letter_groups(letters, word_location):
+ """Assign group_id to letters sharing identical (start, end) timestamps."""
+ if not letters:
+ return []
+ result = []
+ group_id = 0
+ prev_ts = None
+ for letter in letters:
+ ts = (letter.get("start"), letter.get("end"))
+ if ts != prev_ts:
+ group_id += 1
+ prev_ts = ts
+ result.append({
+ "char": letter.get("char", ""),
+ "start": letter.get("start"),
+ "end": letter.get("end"),
+ "group_id": f"{word_location}:{group_id}",
+ })
+ return result
+
+
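The grouping rule in `_assign_letter_groups` above is simple enough to demonstrate standalone: consecutive letters that share an identical `(start, end)` span get the same `group_id`, so the UI can highlight them together. The sample letters and the `"1:1:1"` location below are invented for illustration.

```python
# Standalone copy of the grouping rule from _assign_letter_groups above.
def assign_letter_groups(letters, word_location):
    result = []
    group_id = 0
    prev_ts = None
    for letter in letters:
        ts = (letter.get("start"), letter.get("end"))
        if ts != prev_ts:        # a new time span starts a new group
            group_id += 1
            prev_ts = ts
        result.append({**letter, "group_id": f"{word_location}:{group_id}"})
    return result

letters = [
    {"char": "ب", "start": 0.10, "end": 0.18},
    {"char": "ِ", "start": 0.10, "end": 0.18},   # same span as previous letter
    {"char": "س", "start": 0.18, "end": 0.30},
]
grouped = assign_letter_groups(letters, "1:1:1")
print([l["group_id"] for l in grouped])  # → ['1:1:1:1', '1:1:1:1', '1:1:1:2']
```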
+ def _build_timestamp_lookups(results):
+ """Build timestamp lookup dicts from MFA results.
+
+ Returns (word_timestamps, letter_timestamps, word_to_all_results).
+ """
+ word_timestamps = {}
+ letter_timestamps = {}
+ word_to_all_results = {}
+
+ for result_idx, result in enumerate(results):
+ if result.get("status") != "ok":
+ print(f"[MFA_TS] Segment failed: ref={result.get('ref')} error={result.get('error')}")
+ continue
+ ref = result.get("ref", "")
+ is_special = ref.strip().lower() in _SPECIAL_REFS
+ is_fused = "+" in ref
+ for word in result.get("words", []):
+ loc = word.get("location", "")
+ if loc:
+ key = _make_ts_key(result_idx, ref, loc)
+ word_timestamps[key] = (word["start"], word["end"])
+ letters = word.get("letters")
+ if letters:
+ letter_timestamps[key] = _assign_letter_groups(letters, loc)
+ if not is_special and not (is_fused and loc.startswith("0:0:")):
+ if loc not in word_to_all_results:
+ word_to_all_results[loc] = []
+ word_to_all_results[loc].append(result_idx)
+
+ print(f"[MFA_TS] {len(word_timestamps)} word timestamps collected, {len(letter_timestamps)} with letter-level data")
+ return word_timestamps, letter_timestamps, word_to_all_results
+
+
+ def _build_crossword_groups(results, letter_ts_dict):
+ """Build mapping of (key, letter_idx) -> cross-word group_id.
+
+ Only checks word boundaries: last letter(s) of word N vs first
+ letter(s) of word N+1.
+ """
+ crossword_groups = {}
+
+ for result_idx, result in enumerate(results):
+ if result.get("status") != "ok":
+ continue
+ ref = result.get("ref", "")
+ words = result.get("words", [])
+
+ for word_i in range(len(words) - 1):
+ word_a = words[word_i]
+ word_b = words[word_i + 1]
+
+ loc_a = word_a.get("location", "")
+ loc_b = word_b.get("location", "")
+ if not loc_a or not loc_b:
+ continue
+
+ key_a = _make_ts_key(result_idx, ref, loc_a)
+ key_b = _make_ts_key(result_idx, ref, loc_b)
+ letters_a = letter_ts_dict.get(key_a, [])
+ letters_b = letter_ts_dict.get(key_b, [])
+
+ if not letters_a or not letters_b:
+ continue
+
+ for idx_a in range(len(letters_a) - 1, max(len(letters_a) - 3, -1), -1):
+ letter_a = letters_a[idx_a]
+ if letter_a.get("start") is None or letter_a.get("end") is None:
+ continue
+ for idx_b in range(min(3, len(letters_b))):
+ letter_b = letters_b[idx_b]
+ if letter_b.get("start") is None or letter_b.get("end") is None:
+ continue
+ if letter_a["start"] == letter_b["start"] and letter_a["end"] == letter_b["end"]:
+ group_id = f"xword-{result_idx}-{word_i}"
+ crossword_groups[(key_a, idx_a)] = group_id
+ crossword_groups[(key_b, idx_b)] = group_id
+
+ if crossword_groups:
+ print(f"[MFA_TS] Found {len(crossword_groups)} cross-word overlapping letters")
+
+ return crossword_groups
+
+
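The boundary check in `_build_crossword_groups` above can be sketched in a compact standalone form. This simplified version drops the `None`-timestamp guards of the real helper and takes a key function as a parameter; the sample result and letter data are invented. It compares up to three trailing letters of word N against up to three leading letters of word N+1 and tags exact-span matches with a shared `xword-…` id.

```python
# Simplified sketch of the cross-word overlap detection from the diff above.
def build_crossword_groups(results, letter_ts, make_key):
    groups = {}
    for ridx, result in enumerate(results):
        words = result.get("words", [])
        for wi in range(len(words) - 1):
            ka = make_key(ridx, result["ref"], words[wi]["location"])
            kb = make_key(ridx, result["ref"], words[wi + 1]["location"])
            la, lb = letter_ts.get(ka, []), letter_ts.get(kb, [])
            # Compare up to 3 trailing letters of word N with 3 leading of N+1.
            for ia in range(len(la) - 1, max(len(la) - 3, -1), -1):
                for ib in range(min(3, len(lb))):
                    if (la[ia]["start"], la[ia]["end"]) == (lb[ib]["start"], lb[ib]["end"]):
                        gid = f"xword-{ridx}-{wi}"
                        groups[(ka, ia)] = groups[(kb, ib)] = gid
    return groups

results = [{"ref": "1:1", "status": "ok",
            "words": [{"location": "1:1:1"}, {"location": "1:1:2"}]}]
letter_ts = {
    "0:1:1:1": [{"char": "م", "start": 0.5, "end": 0.6}],
    "0:1:1:2": [{"char": "م", "start": 0.5, "end": 0.6},
                {"char": "ا", "start": 0.6, "end": 0.7}],
}
groups = build_crossword_groups(results, letter_ts,
                                lambda r, ref, loc: f"{r}:{loc}")
print(groups)  # → {('0:1:1:1', 0): 'xword-0-0', ('0:1:1:2', 0): 'xword-0-0'}
```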
+ def _reconstruct_ref_key(seg):
+ """Reconstruct the MFA ref key for a segment (for result matching)."""
+ ref_from = seg.get("ref_from", "")
+ ref_to = seg.get("ref_to", "")
+ if not ref_from:
+ ref_from = seg.get("special_type", "")
+ ref_to = ref_from
+ ref_key = f"{ref_from}-{ref_to}" if ref_from != ref_to else ref_from
+ is_special = ref_from.strip().lower() in _SPECIAL_REFS
+ if not is_special:
+ matched_text = seg.get("matched_text", "")
+ if matched_text.startswith(_COMBINED_PREFIX):
+ ref_key = f"Isti'adha+Basmala+{ref_key}"
+ elif matched_text.startswith(_ISTIATHA_TEXT):
+ ref_key = f"Isti'adha+{ref_key}"
+ elif matched_text.startswith(_BASMALA_TEXT):
+ ref_key = f"Basmala+{ref_key}"
+ return ref_key
+
+
+ def _extend_word_timestamps(word_timestamps, segments, seg_to_result_idx,
+ results, segment_dir):
+ """Extend word ends to fill gaps between consecutive words.
+
+ Mutates word_timestamps in place.
+ """
+ import wave
+ for seg in segments:
+ ref_from = seg.get("ref_from", "")
+ confidence = seg.get("confidence", 0)
+ if not ref_from:
+ ref_from = seg.get("special_type", "")
+ if not ref_from or confidence <= 0:
+ continue
+ seg_idx = seg.get("segment", 0) - 1
+ result_idx = seg_to_result_idx.get(seg_idx)
+ if result_idx is None:
+ continue
+ ref_key = _reconstruct_ref_key(seg)
+ seg_word_locs = []
+ for result in results:
+ if result.get("ref") == ref_key and result.get("status") == "ok":
+ for w in result.get("words", []):
+ loc = w.get("location", "")
+ if loc:
+ key = _make_ts_key(result_idx, ref_key, loc)
+ if key in word_timestamps:
+ seg_word_locs.append(key)
+ break
+ if not seg_word_locs:
+ continue
+ # Extend each word's end to the next word's start
+ for i in range(len(seg_word_locs) - 1):
+ cur_start, cur_end = word_timestamps[seg_word_locs[i]]
+ nxt_start, _ = word_timestamps[seg_word_locs[i + 1]]
+ if nxt_start > cur_end:
+ word_timestamps[seg_word_locs[i]] = (cur_start, nxt_start)
+ # Extend first word back to time 0 so highlight starts immediately
+ first_loc = seg_word_locs[0]
+ first_start, first_end = word_timestamps[first_loc]
+ if first_start > 0:
+ word_timestamps[first_loc] = (0, first_end)
+ # Extend last word to segment audio duration
+ last_loc = seg_word_locs[-1]
+ last_start, last_end = word_timestamps[last_loc]
+ audio_path = os.path.join(segment_dir, f"seg_{seg_idx}.wav") if segment_dir else None
+ if audio_path and os.path.exists(audio_path):
+ with wave.open(audio_path, 'rb') as wf:
+ seg_duration = wf.getnframes() / wf.getframerate()
+ if seg_duration > last_end:
+ word_timestamps[last_loc] = (last_start, seg_duration)
+
+ print(f"[MFA_TS] Post-processed timestamps: extended word ends to fill gaps")
+
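The gap-filling done by `_extend_word_timestamps` above is easy to see on a toy example. This sketch reproduces only the core dict manipulation (no WAV files or segment matching): each word's end is pulled forward to the next word's start, the first word is extended back to 0, and the last word is extended to the segment duration. The keys `w1`–`w3` and the times are invented values.

```python
# Toy demonstration of the gap-filling logic from _extend_word_timestamps above.
word_ts = {
    "w1": (0.00, 0.40),
    "w2": (0.55, 0.90),   # 0.15 s silence after w1
    "w3": (0.90, 1.20),
}
order = ["w1", "w2", "w3"]   # words in recitation order
seg_duration = 1.50          # would come from the segment WAV in the real code

# Extend each word's end to the next word's start.
for i in range(len(order) - 1):
    cur_start, cur_end = word_ts[order[i]]
    nxt_start, _ = word_ts[order[i + 1]]
    if nxt_start > cur_end:
        word_ts[order[i]] = (cur_start, nxt_start)
# Extend the first word back to time 0.
first_start, first_end = word_ts[order[0]]
if first_start > 0:
    word_ts[order[0]] = (0, first_end)
# Extend the last word to the segment duration.
last_start, last_end = word_ts[order[-1]]
if seg_duration > last_end:
    word_ts[order[-1]] = (last_start, seg_duration)

print(word_ts)  # → {'w1': (0.0, 0.55), 'w2': (0.55, 0.9), 'w3': (0.9, 1.5)}
```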
356
+ def _build_enriched_json(segments, results, seg_to_result_idx,
357
+ word_timestamps, letter_timestamps, granularity,
358
+ *, minimal=False):
359
+ """Build enriched segments with word (and optionally letter) timestamps.
360
+
361
+ When *minimal* is True (API path), each segment only contains
362
+ ``segment`` number + ``words`` array. When False (UI path), all
363
+ original segment fields are preserved.
364
+
365
+ Returns dict with "segments" key.
366
+ """
367
+ from src.core.quran_index import get_quran_index
368
+ index = get_quran_index()
369
+ include_letters = (granularity == "words+chars")
370
+
371
+ def _get_word_text(location):
372
+ if not location or location.startswith("0:0:"):
373
+ return ""
374
+ try:
375
+ parts = location.split(":")
376
+ if len(parts) >= 3:
377
+ key = (int(parts[0]), int(parts[1]), int(parts[2]))
378
+ idx = index.word_lookup.get(key)
379
+ if idx is not None:
380
+ return index.words[idx].display_text
381
+ except (ValueError, IndexError):
382
+ pass
383
+ return ""
384
+
385
+ enriched_segments = []
386
+ for seg in segments:
387
+ seg_idx = seg.get("segment", 0) - 1
388
+ result_idx = seg_to_result_idx.get(seg_idx)
389
+
390
+ if minimal:
391
+ segment_data = {"segment": seg.get("segment", 0)}
392
+ else:
393
+ segment_data = dict(seg)
394
+
395
+ if result_idx is not None:
396
+ _ref = seg.get("ref_from", "") or seg.get("special_type", "")
397
+ is_special = _ref.lower() in _SPECIAL_REFS
398
+ special_words = seg.get("matched_text", "").replace(" \u06dd ", " ").split() if is_special else []
399
+
400
+ for i, result in enumerate(results):
401
+ if i != result_idx or result.get("status") != "ok":
402
+ continue
403
+ words_with_ts = []
404
+ for word_idx, word in enumerate(result.get("words", [])):
405
+ if word.get("start") is None or word.get("end") is None:
406
+ continue
407
+
408
+ location = word.get("location", "")
409
+
410
+ if minimal:
411
+ # API: compact — [location, start, end] or [location, start, end, letters]
412
+ word_entry = [location, round(word["start"], 4), round(word["end"], 4)]
413
+ if include_letters and word.get("letters"):
414
+ word_entry.append([
415
+ [lt.get("char", ""), round(lt["start"], 4), round(lt["end"], 4)]
416
+ for lt in word.get("letters", [])
417
+ if lt.get("start") is not None
418
+ ])
419
+ words_with_ts.append(word_entry)
420
+ else:
421
+ # UI: keyed objects with display text
422
+ if is_special or location.startswith("0:0:"):
423
+ word_text = special_words[word_idx] if word_idx < len(special_words) else ""
424
+ else:
425
+ word_text = _get_word_text(location)
426
+
427
+ word_data = {
428
+ "word": word_text,
429
+ "location": location,
430
+ "start": round(word["start"], 4),
431
+ "end": round(word["end"], 4),
432
+ }
433
+ if include_letters and word.get("letters"):
434
+ word_data["letters"] = [
435
+ {
436
+ "char": lt.get("char", ""),
437
+ "start": round(lt["start"], 4),
438
+ "end": round(lt["end"], 4),
439
+ }
440
+ for lt in word.get("letters", [])
441
+ if lt.get("start") is not None
442
+ ]
443
+ words_with_ts.append(word_data)
444
+
445
+ if words_with_ts:
446
+ segment_data["words"] = words_with_ts
447
+ break
448
+
449
+ enriched_segments.append(segment_data)
450
+
451
+ return {"segments": enriched_segments}
452
+
453
+
454
+ # ---------------------------------------------------------------------------
455
+ # Synchronous API function
456
+ # ---------------------------------------------------------------------------
457
+
458
+ def compute_mfa_timestamps_api(segments, segment_dir, granularity="words"):
459
+ """Run MFA forced alignment and return enriched segments (no UI/HTML).
460
+
461
+ Args:
462
+ segments: List of segment dicts (same format as alignment response).
463
+ segment_dir: Path to directory containing per-segment WAV files.
464
+ granularity: "words" or "words+chars".
465
+
466
+ Returns:
467
+ Dict with "segments" key containing enriched segment data.
468
+ """
469
+ if not granularity or granularity not in ("words", "words+chars"):
470
+ granularity = "words"
471
+
472
+ refs, audio_paths, seg_to_result_idx = _build_mfa_refs(segments, segment_dir)
473
+ if not refs:
474
+ return {"segments": segments}
475
+
476
+ event_id, headers, base = _mfa_upload_and_submit(refs, audio_paths)
477
+ results = _mfa_wait_result(event_id, headers, base)
478
+ print(f"[MFA_TS] Got {len(results)} results from MFA API")
479
+
480
+ word_ts, letter_ts, _ = _build_timestamp_lookups(results)
481
+ _build_crossword_groups(results, letter_ts)
482
+ _extend_word_timestamps(word_ts, segments, seg_to_result_idx, results, segment_dir)
483
+ return _build_enriched_json(segments, results, seg_to_result_idx,
484
+ word_ts, letter_ts, granularity, minimal=True)
485
+
486
+
487
+ # ---------------------------------------------------------------------------
+ # UI progress bar
+ # ---------------------------------------------------------------------------
+
  def _ts_progress_bar_html(total_segments, rate, animated=True):
  """Return HTML for a progress bar showing Segment x/N.

  </div>'''


+ # ---------------------------------------------------------------------------
+ # UI generator (Gradio — yields progress, injects HTML timestamps)
+ # ---------------------------------------------------------------------------
+
  def compute_mfa_timestamps(current_html, json_output, segment_dir, cached_log_row=None):
  """Compute word-level timestamps via MFA forced alignment and inject into HTML.

  yield current_html, gr.update(), gr.update(), gr.update(), gr.update()
  return

+ # Build refs and audio paths using shared helper
  segments = json_output.get("segments", []) if json_output else []
  print(f"[MFA_TS] {len(segments)} segments in JSON")

+ refs, audio_paths, seg_to_result_idx = _build_mfa_refs(segments, segment_dir)

  if not refs:
  print("[MFA_TS] Early return: no valid refs/audio pairs")

  )
  raise

+ # Build timestamp lookups using shared helper
+ word_timestamps, letter_timestamps, word_to_all_results = _build_timestamp_lookups(results)

+ # Build cross-word groups using shared helper
  crossword_groups = _build_crossword_groups(results, letter_timestamps)

+ # Extend word timestamps using shared helper
+ _extend_word_timestamps(word_timestamps, segments, seg_to_result_idx, results, segment_dir)

+ # --- HTML injection (UI-only, not shared with API) ---

  # Inject timestamps into word spans, using segment boundaries to determine result_idx
+ seg_boundaries = []
  for m in re.finditer(r'data-segment-idx="(\d+)"', current_html):
  seg_boundaries.append((m.start(), int(m.group(1))))
  seg_boundaries.sort(key=lambda x: x[0])

+ seg_offset_map = {}
  for seg in segments:
+ idx = seg.get("segment", 0) - 1
  seg_offset_map[idx] = seg.get("time_from", 0)

  def _get_seg_idx_at_pos(pos):
  seg_idx = None
  for boundary_pos, idx in seg_boundaries:
  if boundary_pos > pos:

  if not pos_m:
  return orig
  pos = pos_m.group(1)
  seg_idx = _get_seg_idx_at_pos(m.start())
  if seg_idx is None:
  return orig
  expected_result_idx = seg_to_result_idx.get(seg_idx)
  result_idx = None
  if pos and not pos.startswith("0:0:"):
  candidates = word_to_all_results.get(pos, [])

  result_idx = expected_result_idx
  if result_idx is None:
  return orig
  key = f"{result_idx}:{pos}"
  ts = word_timestamps.get(key)
  if not ts:
  return orig
  seg_offset = seg_offset_map.get(seg_idx, 0)
  abs_start = ts[0] + seg_offset
  abs_end = ts[1] + seg_offset
  return orig[:-1] + f' data-result-idx="{result_idx}" data-start="{abs_start:.4f}" data-end="{abs_end:.4f}">'

  html = re.sub(word_open_re, _inject_word_ts, current_html)
 

  def _stamp_chars_with_mfa(word_m):
  word_open = word_m.group(1)
+ word_abs_start = float(word_m.group(2))
  inner = word_m.group(4)

  pos_m = re.search(r'data-pos="([^"]+)"', word_open)
  word_pos = pos_m.group(1) if pos_m else None

  result_idx_m = re.search(r'data-result-idx="(\d+)"', word_open)
  if result_idx_m:
  result_idx = int(result_idx_m.group(1))
  else:
  result_idx = None
  if word_pos and not word_pos.startswith("0:0:"):
  candidates = word_to_all_results.get(word_pos, [])

  if len(candidates) == 1:
  result_idx = candidates[0]
  else:
  result_idx = candidates[0]

  key = f"{result_idx}:{word_pos}" if result_idx is not None and word_pos else None

  word_ts = word_timestamps.get(key) if key else None
  mfa_letters = letter_timestamps.get(key) if key else None
  if not mfa_letters or not word_ts:
  return word_m.group(0)

+ word_rel_start = word_ts[0]

  char_matches = list(re.finditer(r'<span class="char">([^<]*)</span>', inner))
  if not char_matches:
  return word_m.group(0)

  mfa_chars = [l["char"] for l in mfa_letters]
  html_chars = [m.group(1).replace('\u0640', '') for m in char_matches]

  CHAR_EQUIVALENTS = {
+ 'ى': 'ي',
+ 'ي': 'ى',
  }

  def _first_base(s):
  for c in unicodedata.normalize("NFD", s):
  if not unicodedata.category(c).startswith('M'):
  return c
  return s[0] if s else ''

  def chars_match(mfa_c, html_c, log_substitution=False):
  if mfa_c == html_c or html_c in mfa_c or mfa_c in html_c:
  return True
  if CHAR_EQUIVALENTS.get(mfa_c) == html_c:
  if log_substitution:
  print(f"[MFA_TS] Char substitution: MFA '{mfa_c}' → HTML '{html_c}' (key={key})")
  return True
  mb, hb = _first_base(mfa_c), _first_base(html_c)
  if mb and hb and (mb == hb or CHAR_EQUIVALENTS.get(mb) == hb):
  if log_substitution:

  mfa_char = mfa_chars[mfa_idx]
  if chars_match(mfa_char, html_char, log_substitution=True):
  letter = mfa_letters[mfa_idx]
  if letter["start"] is None or letter["end"] is None:
  print(f"[MFA_TS] Skipping letter with missing timestamp: char='{letter.get('char')}' key={key} mfa_idx={mfa_idx}")
  if chars_match(mfa_char, html_char) or len(html_char) >= len(mfa_char):
  mfa_idx += 1
  continue

  abs_start = word_abs_start + (letter["start"] - word_rel_start)
  abs_end = word_abs_start + (letter["end"] - word_rel_start)
  crossword_gid = crossword_groups.get((key, mfa_idx), "")
  final_group_id = crossword_gid or letter.get("group_id", "")
  char_replacements.append((
  cm.start(), cm.end(),
  f'<span class="char" data-start="{abs_start:.4f}" data-end="{abs_end:.4f}" data-group-id="{final_group_id}">{cm.group(1)}</span>'
  ))

  mfa_nfd = unicodedata.normalize("NFD", letter["char"])
  peek = html_idx + 1
  while peek < len(char_matches):

  if chars_match(mfa_char, html_char) or len(html_char) >= len(mfa_char):
  mfa_idx += 1

  stamped_inner = inner
  for start, end, replacement in reversed(char_replacements):
  stamped_inner = stamped_inner[:start] + replacement + stamped_inner[end:]

  for w in result.get("words", []) if w.get("start") is not None and w.get("end") is not None
  ],
  })
  _char_ts_log.append({
  "ref": result.get("ref", ""),
  "words": [

  except Exception as e:
  print(f"[USAGE_LOG] Failed to log word timestamps: {e}")

+ # Build enriched JSON using shared helper (UI always includes letters)
+ enriched_json = _build_enriched_json(
+ segments, results, seg_to_result_idx,
+ word_timestamps, letter_timestamps, "words+chars",
+ )

  # Final yield: updated HTML, hide progress bar, show Animate All, enriched JSON
  animate_all_btn_html = '<button class="animate-all-btn">Animate All</button>'
src/ui/event_wiring.py CHANGED
@@ -9,6 +9,7 @@ from src.pipeline import (
  from src.api.session_api import (
  process_audio_session, resegment_session,
  retranscribe_session, realign_from_timestamps,
+ mfa_timestamps_session, mfa_timestamps_direct,
  )
  from src.mfa import compute_mfa_timestamps
  from src.ui.handlers import (
@@ -483,3 +484,15 @@ def _wire_api_endpoint(c):
  outputs=[c.api_result],
  api_name="realign_from_timestamps",
  )
+ gr.Button(visible=False).click(
+ fn=mfa_timestamps_session,
+ inputs=[c.api_audio_id, c.api_mfa_segments, c.api_mfa_granularity],
+ outputs=[c.api_result],
+ api_name="mfa_timestamps_session",
+ )
+ gr.Button(visible=False).click(
+ fn=mfa_timestamps_direct,
+ inputs=[c.api_audio, c.api_mfa_segments, c.api_mfa_granularity],
+ outputs=[c.api_result],
+ api_name="mfa_timestamps_direct",
+ )
src/ui/interface.py CHANGED
@@ -78,6 +78,8 @@ def build_interface():
  c.api_model = gr.Textbox(visible=False)
  c.api_device = gr.Textbox(visible=False)
  c.api_timestamps = gr.JSON(visible=False)
+ c.api_mfa_segments = gr.JSON(visible=False)
+ c.api_mfa_granularity = gr.Textbox(visible=False)
  c.api_result = gr.JSON(visible=False)

  wire_events(app, c)
@@ -110,7 +112,7 @@ def _build_left_column(c):
  choices=["Base", "Large"],
  value="Base",
  label="ASR Model",
- info="Large: more robust to noisy/non-studio recitations but much slower (10x bigger)"
+ info="Large: more robust to noisy/non-studio recitations but slower"
  )
  c.device_radio = gr.Radio(
  choices=["GPU", "CPU"],
tests/test_session_api.py CHANGED
@@ -263,6 +263,156 @@ class TestWorkflow:
  # 6. Error handling
  # ---------------------------------------------------------------------------

+ # ---------------------------------------------------------------------------
+ # 7. MFA timestamps — session-based
+ # ---------------------------------------------------------------------------
+
+ class TestMfaTimestampsSession:
+     def test_basic_words_only(self, client, session):
+         """Session endpoint with stored segments, words granularity."""
+         result = client.predict(
+             session["audio_id"], None, "words",
+             api_name="/mfa_timestamps_session",
+         )
+         assert result["audio_id"] == session["audio_id"]
+         assert len(result["segments"]) > 0
+         has_words = any("words" in seg for seg in result["segments"])
+         assert has_words, "Expected at least one segment with words"
+         # Words-only: each word is [location, start, end] (3 elements)
+         for seg in result["segments"]:
+             for word in seg.get("words", []):
+                 assert len(word) == 3, f"words granularity should give 3-element arrays, got {len(word)}"
+
+     def test_words_plus_chars(self, client, session):
+         """Session endpoint with words+chars granularity."""
+         result = client.predict(
+             session["audio_id"], None, "words+chars",
+             api_name="/mfa_timestamps_session",
+         )
+         has_letters = any(
+             len(word) == 4
+             for seg in result["segments"]
+             for word in seg.get("words", [])
+         )
+         assert has_letters, "words+chars should include letter arrays (4th element)"
+
+     def test_with_segments_override(self, client, session):
+         """Session endpoint with explicit segments (override stored)."""
+         segments_override = session["segments"][:2]
+         result = client.predict(
+             session["audio_id"], segments_override, "words",
+             api_name="/mfa_timestamps_session",
+         )
+         assert result["audio_id"] == session["audio_id"]
+         assert len(result["segments"]) == 2
+
+     def test_word_timestamp_fields(self, client, session):
+         """Verify word arrays have correct structure: [location, start, end, ?letters]."""
+         result = client.predict(
+             session["audio_id"], None, "words+chars",
+             api_name="/mfa_timestamps_session",
+         )
+         for seg in result["segments"]:
+             for word in seg.get("words", []):
+                 assert isinstance(word[0], str), "word[0] should be location string"
+                 assert isinstance(word[1], (int, float)), "word[1] should be start time"
+                 assert isinstance(word[2], (int, float)), "word[2] should be end time"
+                 assert word[2] > word[1], "end should be > start"
+                 if len(word) == 4:
+                     # Letters: list of [char, start, end]
+                     for letter in word[3]:
+                         assert len(letter) == 3
+                         assert isinstance(letter[0], str)
+
+     def test_invalid_session(self, client):
+         result = client.predict(
+             FAKE_ID, None, "words",
+             api_name="/mfa_timestamps_session",
+         )
+         assert "error" in result
+         assert result["segments"] == []
+
+     def test_default_granularity(self, client, session):
+         """Empty granularity should default to words."""
+         result = client.predict(
+             session["audio_id"], None, "",
+             api_name="/mfa_timestamps_session",
+         )
+         assert len(result["segments"]) > 0
+         for seg in result["segments"]:
+             for word in seg.get("words", []):
+                 assert len(word) == 3, "default granularity should not include letters"
+
+
+ # ---------------------------------------------------------------------------
+ # 8. MFA timestamps — direct
+ # ---------------------------------------------------------------------------
+
+ class TestMfaTimestampsDirect:
+     def test_basic(self, client, session):
+         """Direct endpoint with audio file and segments."""
+         result = client.predict(
+             AUDIO_FILE, session["segments"], "words",
+             api_name="/mfa_timestamps_direct",
+         )
+         assert "segments" in result
+         assert len(result["segments"]) > 0
+         has_words = any("words" in seg for seg in result["segments"])
+         assert has_words
+
+     def test_words_plus_chars(self, client, session):
+         result = client.predict(
+             AUDIO_FILE, session["segments"], "words+chars",
+             api_name="/mfa_timestamps_direct",
+         )
+         has_letters = any(
+             len(word) == 4
+             for seg in result["segments"]
+             for word in seg.get("words", [])
+         )
+         assert has_letters
+
+     def test_no_audio_id_in_response(self, client, session):
+         """Direct endpoint should not return audio_id."""
+         result = client.predict(
+             AUDIO_FILE, session["segments"], "words",
+             api_name="/mfa_timestamps_direct",
+         )
+         assert "audio_id" not in result
+
+     def test_empty_segments_error(self, client):
+         result = client.predict(
+             AUDIO_FILE, [], "words",
+             api_name="/mfa_timestamps_direct",
+         )
+         assert "error" in result
+         assert result["segments"] == []
+
+
+ # ---------------------------------------------------------------------------
+ # 9. Segments stored in session after alignment
+ # ---------------------------------------------------------------------------
+
+ class TestSegmentStorage:
+     def test_segments_stored_after_process(self, client):
+         """process_audio_session should store segments for later MFA use."""
+         proc = client.predict(
+             AUDIO_FILE, 200, 1000, 100, "Base", "CPU",
+             api_name="/process_audio_session",
+         )
+         # MFA session endpoint should find stored segments
+         result = client.predict(
+             proc["audio_id"], None, "words",
+             api_name="/mfa_timestamps_session",
+         )
+         assert "error" not in result or result.get("segments")
+         assert result["audio_id"] == proc["audio_id"]
+
+
+ # ---------------------------------------------------------------------------
+ # 10. Error handling
+ # ---------------------------------------------------------------------------
+
  class TestErrorHandling:
      def test_invalid_audio_id_retranscribe(self, client):
          result = client.predict(
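The word-array shape these tests pin down (`[location, start, end]`, with an optional fourth element holding `[char, start, end]` letter triples under `words+chars`) is easy to consume with a small helper. A sketch; the function name and the location strings in the sample payload are illustrative, not taken from the app:

```python
def iter_words(response):
    """Yield (location, start, end, letters) for every word in a response.

    `letters` is a list of [char, start, end] triples, or [] when the
    response was produced with plain "words" granularity.
    """
    for seg in response.get("segments", []):
        for word in seg.get("words", []):
            location, start, end = word[0], word[1], word[2]
            letters = word[3] if len(word) == 4 else []
            yield location, start, end, letters


# Hand-built payload matching the shape the tests assert:
sample = {
    "segments": [
        {"words": [["1:1:1", 0.10, 0.55,
                    [["b", 0.10, 0.30], ["a", 0.30, 0.55]]]]},
        {"words": [["1:1:2", 0.60, 0.90]]},
    ]
}
for loc, start, end, letters in iter_words(sample):
    assert end > start
```

Because the helper only touches list positions, it works identically on responses from `/mfa_timestamps_session` and `/mfa_timestamps_direct`.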