hetchyy committed on
Commit
058f17e
·
1 Parent(s): d5f2531

Add estimate duration API

Browse files
.gitignore CHANGED
@@ -52,4 +52,5 @@ captures/
52
  docs/api.md
53
  docs/lease_duration_history.md
54
  scripts/
55
- tests/
 
 
52
  docs/api.md
53
  docs/lease_duration_history.md
54
  scripts/
55
+ tests/
56
+ align_config.py
align_config.py CHANGED
@@ -4,13 +4,31 @@ Only params that differ from the quran_aligner defaults.
4
  """
5
 
6
  # Window sizes
7
- LOOKBACK_WORDS = 10
8
- LOOKAHEAD_WORDS = 5
9
 
10
  # Retry windows
11
  RETRY_LOOKBACK_WORDS = 80
12
  RETRY_LOOKAHEAD_WORDS = 60
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Debug/profiling -- off for batch CLI
15
  ANCHOR_DEBUG = False
16
  PHONEME_ALIGNMENT_DEBUG = False
 
4
  """
5
 
6
  # Window sizes
7
+ LOOKBACK_WORDS = 30
8
+ LOOKAHEAD_WORDS = 8
9
 
10
  # Retry windows
11
  RETRY_LOOKBACK_WORDS = 80
12
  RETRY_LOOKAHEAD_WORDS = 60
13
 
14
+ # Inference settings
15
+ DTYPE = "float16"
16
+ TORCH_COMPILE = False # Skip torch.compile() overhead for batch jobs
17
+
18
+ # Download parallelism
19
+ DOWNLOAD_WORKERS = 16 # Parallel download+decode threads (I/O-bound, safe to oversubscribe CPUs)
20
+
21
+ # VAD batching (number of audio files to VAD together)
22
+ VAD_BATCH_SIZE_AYAH = 256
23
+ VAD_BATCH_SIZE_SURA = 4
24
+
25
+ # ASR batching
26
+ BATCHING_STRATEGY = "dynamic"
27
+ MAX_BATCH_SECONDS = 800
28
+ MAX_PAD_WASTE = 0.25
29
+ MIN_BATCH_SIZE = 16
30
+ INFERENCE_BATCH_SIZE = 32 # Only used when BATCHING_STRATEGY="naive"
31
+
32
  # Debug/profiling -- off for batch CLI
33
  ANCHOR_DEBUG = False
34
  PHONEME_ALIGNMENT_DEBUG = False
app.py CHANGED
@@ -54,6 +54,7 @@ if __name__ == "__main__":
54
  parser = argparse.ArgumentParser()
55
  parser.add_argument("--share", action="store_true", help="Create public link")
56
  parser.add_argument("--port", type=int, default=PORT, help="Port to run on")
 
57
  args = parser.parse_args()
58
 
59
  port = 7860
@@ -61,22 +62,25 @@ if __name__ == "__main__":
61
  print(f"ZeroGPU available: {ZERO_GPU_AVAILABLE}")
62
  print(f"Launching Gradio on port {port}")
63
 
64
- # Preload models and caches at startup so first request is fast
65
- print("Preloading models...")
66
- load_segmenter()
67
- load_phoneme_asr("Base")
68
- load_phoneme_asr("Large")
69
- print("Models preloaded.")
70
- print("Preloading caches...")
71
- get_ngram_index()
72
- preload_all_chapters()
73
- print("Caches preloaded.")
 
 
 
74
 
75
- # Warm up soxr resampler so first request doesn't pay initialization cost
76
- _dummy = librosa.resample(np.zeros(1600, dtype=np.float32),
77
- orig_sr=44100, target_sr=16000, res_type=RESAMPLE_TYPE)
78
- del _dummy
79
- print("Resampler warmed up.")
80
 
81
  # AoT compilation for VAD model (requires GPU lease)
82
  if IS_HF_SPACE and ZERO_GPU_AVAILABLE:
 
54
  parser = argparse.ArgumentParser()
55
  parser.add_argument("--share", action="store_true", help="Create public link")
56
  parser.add_argument("--port", type=int, default=PORT, help="Port to run on")
57
+ parser.add_argument("--dev", action="store_true", help="Dev mode: skip model preloading for fast startup")
58
  args = parser.parse_args()
59
 
60
  port = 7860
 
62
  print(f"ZeroGPU available: {ZERO_GPU_AVAILABLE}")
63
  print(f"Launching Gradio on port {port}")
64
 
65
+ if args.dev:
66
+ print("Dev mode: skipping model preloading (models load on first request)")
67
+ else:
68
+ # Preload models and caches at startup so first request is fast
69
+ print("Preloading models...")
70
+ load_segmenter()
71
+ load_phoneme_asr("Base")
72
+ load_phoneme_asr("Large")
73
+ print("Models preloaded.")
74
+ print("Preloading caches...")
75
+ get_ngram_index()
76
+ preload_all_chapters()
77
+ print("Caches preloaded.")
78
 
79
+ # Warm up soxr resampler so first request doesn't pay initialization cost
80
+ _dummy = librosa.resample(np.zeros(1600, dtype=np.float32),
81
+ orig_sr=44100, target_sr=16000, res_type=RESAMPLE_TYPE)
82
+ del _dummy
83
+ print("Resampler warmed up.")
84
 
85
  # AoT compilation for VAD model (requires GPU lease)
86
  if IS_HF_SPACE and ZERO_GPU_AVAILABLE:
config.py CHANGED
@@ -64,36 +64,21 @@ NGRAM_INDEX_PATH = DATA_PATH / f"phoneme_ngram_index_{NGRAM_SIZE}.pkl"
64
  # Inference settings
65
  # =============================================================================
66
 
 
67
  def get_vad_duration(minutes):
68
- """GPU seconds needed for VAD based on audio minutes.
69
-
70
- VAD GPU time scales linearly at ~0.28s per audio minute.
71
- Tuned from 50-run log analysis (Feb 2026): previous leases were tight
72
- at 30-60 min (15s lease vs 17s actual) and 60-120 min (25s vs 26s).
73
- """
74
- if minutes > 180:
75
- return 60
76
- elif minutes > 120:
77
- return 45 # was 40 — 137 min audio hit 38.3s (95% of old lease)
78
- elif minutes > 60:
79
- return 30 # was 25 — 89 min audio hit 25.8s (exceeded old lease)
80
- elif minutes > 30:
81
- return 20 # was 15 — 58 min audio hit 17s (exceeded old lease)
82
- elif minutes > 15:
83
- return 10
84
- else:
85
- return 5
86
 
87
  def get_asr_duration(minutes, model_name="Base"):
88
  """GPU seconds needed for ASR.
89
-
90
- ASR GPU time is nearly constant regardless of audio length due to batch
91
- processing — no range tiers needed. Tuned from 50-run log analysis
92
- (Feb 2026): Base uses 0.2-2.5s (warm), Large uses 0.8-5.6s (warm).
93
  """
94
  if model_name == "Large":
95
- return 10 # max warm 5.6s, cold start 10.4s
96
- return 3 # max warm 2.5s, cold start 5.6s
 
 
 
97
 
98
  # Batching strategy
99
  BATCHING_STRATEGY = "dynamic" # "naive" (fixed count) or "dynamic" (seconds + pad waste)
@@ -195,7 +180,7 @@ MFA_TIMEOUT = 240
195
  MFA_METHOD = "kalpy" # "kalpy", "align_one", "python_api", "cli"
196
  MFA_BEAM = 10 # Viterbi beam width
197
  MFA_RETRY_BEAM = 40 # Retry beam width (used when initial alignment fails)
198
- MFA_SHARED_CMVN = True # Compute shared CMVN across batch (kalpy only)
199
 
200
  # =============================================================================
201
  # Usage logging (pushed to HF Hub via ParquetScheduler)
 
64
  # Inference settings
65
  # =============================================================================
66
 
67
def get_vad_duration(minutes):
    """Return the GPU seconds to lease for VAD, given audio length in minutes.

    Linear model fitted on 121 GPU runs (R²=0.992), plus a fixed safety
    buffer over the regression line, floored at 3 seconds.
    """
    safety_buffer_s = 3  # margin over the regression estimate (seconds)
    lease_s = 0.282 * minutes + safety_buffer_s
    return lease_s if lease_s > 3 else 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
def get_asr_duration(minutes, model_name="Base"):
    """Return the GPU seconds to lease for ASR.

    Only the model size matters here; *minutes* is accepted for interface
    symmetry with get_vad_duration but does not affect the lease.
    """
    return 7 if model_name == "Large" else 3
79
+
80
+ ESTIMATE_ALIGNMENT_OVERHEAD_S = 3 # DP alignment + result building
81
+ ESTIMATE_CPU_MULTIPLIER = 50
82
 
83
  # Batching strategy
84
  BATCHING_STRATEGY = "dynamic" # "naive" (fixed count) or "dynamic" (seconds + pad waste)
 
180
  MFA_METHOD = "kalpy" # "kalpy", "align_one", "python_api", "cli"
181
  MFA_BEAM = 10 # Viterbi beam width
182
  MFA_RETRY_BEAM = 40 # Retry beam width (used when initial alignment fails)
183
+ MFA_SHARED_CMVN = True # Compute shared CMVN across batch (kalpy only)
184
 
185
  # =============================================================================
186
  # Usage logging (pushed to HF Hub via ParquetScheduler)
docs/client_api.md CHANGED
@@ -7,6 +7,10 @@ from gradio_client import Client
7
 
8
  client = Client("https://your-space.hf.space")
9
 
 
 
 
 
10
  # Full pipeline
11
  result = client.predict(
12
  "recitation.mp3", # audio file path
@@ -68,6 +72,54 @@ If `audio_id` is missing, expired, or invalid:
68
 
69
  ## Endpoints
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ### `POST /process_audio_session`
72
 
73
  Full pipeline: preprocess → VAD → ASR → alignment. Creates a server-side session.
 
7
 
8
  client = Client("https://your-space.hf.space")
9
 
10
+ # Estimate processing time before starting
11
+ est = client.predict("process_audio_session", 60.0, None, "Base", "GPU", api_name="/estimate_duration")
12
+ print(f"Estimated time: {est['estimated_duration_s']}s")
13
+
14
  # Full pipeline
15
  result = client.predict(
16
  "recitation.mp3", # audio file path
 
72
 
73
  ## Endpoints
74
 
75
+ ### `POST /estimate_duration`
76
+
77
+ Estimate how long a processing endpoint will take before calling it.
78
+
79
+ | Parameter | Type | Default | Description |
80
+ |---|---|---|---|
81
+ | `endpoint` | str | required | Target endpoint name (e.g. `"process_audio_session"`) |
82
+ | `audio_duration_s` | float | `None` | Audio length in seconds. Required if no `audio_id` |
83
+ | `audio_id` | str | `None` | Session ID — infers audio duration from session metadata |
84
+ | `model_name` | str | `"Base"` | `"Base"` or `"Large"` |
85
+ | `device` | str | `"GPU"` | `"GPU"` or `"CPU"` |
86
+
87
+ **Example — before first processing call:**
88
+ ```python
89
+ est = client.predict(
90
+ "process_audio_session", # endpoint
91
+ 60.0, # audio_duration_s (seconds)
92
+ None, # audio_id (not yet available)
93
+ "Base", # model_name
94
+ "GPU", # device
95
+ api_name="/estimate_duration",
96
+ )
97
+ print(f"Estimated time: {est['estimated_duration_s']}s")
98
+ ```
99
+
100
+ **Example — with existing session (e.g. before MFA):**
101
+ ```python
102
+ est = client.predict(
103
+ "mfa_timestamps_session", # endpoint
104
+ None, # audio_duration_s (inferred from session)
105
+ audio_id, # audio_id
106
+ "Base", # model_name
107
+ "GPU", # device
108
+ api_name="/estimate_duration",
109
+ )
110
+ ```
111
+
112
+ **Response:**
113
+ ```json
114
+ {
115
+ "endpoint": "process_audio_session",
116
+ "estimated_duration_s": 28.0,
117
+ "device": "GPU",
118
+ "model_name": "Base"
119
+ }
120
+ ```
121
+ ---
122
+
123
  ### `POST /process_audio_session`
124
 
125
  Full pipeline: preprocess → VAD → ASR → alignment. Creates a server-side session.
src/api/session_api.py CHANGED
@@ -7,6 +7,7 @@ re-uploads and re-inference.
7
 
8
  import hashlib
9
  import json
 
10
  import os
11
  import pickle
12
  import re
@@ -88,6 +89,7 @@ def create_session(audio, speech_intervals, is_complete, intervals, model_name):
88
  "intervals": intervals,
89
  "model_name": model_name,
90
  "intervals_hash": _intervals_hash(intervals),
 
91
  }
92
  with open(path / "metadata.json", "w") as f:
93
  json.dump(meta, f)
@@ -180,6 +182,106 @@ def _load_segments(audio_id):
180
  _SESSION_ERROR = {"error": "Session not found or expired", "segments": []}
181
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def _format_response(audio_id, json_output, warning=None):
184
  """Convert pipeline json_output to the documented API response schema."""
185
  segments = []
 
7
 
8
  import hashlib
9
  import json
10
+ import math
11
  import os
12
  import pickle
13
  import re
 
89
  "intervals": intervals,
90
  "model_name": model_name,
91
  "intervals_hash": _intervals_hash(intervals),
92
+ "audio_duration_s": round(len(audio) / 16000, 2),
93
  }
94
  with open(path / "metadata.json", "w") as f:
95
  json.dump(meta, f)
 
182
  _SESSION_ERROR = {"error": "Session not found or expired", "segments": []}
183
 
184
 
185
+ # ---------------------------------------------------------------------------
186
+ # Duration estimation
187
+ # ---------------------------------------------------------------------------
188
+
189
+ _ESTIMABLE_ENDPOINTS = {
190
+ "process_audio_session",
191
+ "resegment_session",
192
+ "retranscribe_session",
193
+ "realign_from_timestamps",
194
+ "mfa_timestamps_session",
195
+ "mfa_timestamps_direct",
196
+ }
197
+
198
+ _MFA_ENDPOINTS = {"mfa_timestamps_session", "mfa_timestamps_direct"}
199
+ _VAD_ENDPOINTS = {"process_audio_session"}
200
+
201
+
202
def _load_session_metadata(audio_id):
    """Load only metadata.json (no audio/VAD). Returns dict or None."""
    if not _validate_id(audio_id):
        return None
    session_path = _session_dir(audio_id)
    meta_file = session_path / "metadata.json"
    ts_file = session_path / "created_at"
    # Missing metadata or a missing/expired creation timestamp both mean
    # the session is unusable.
    if not meta_file.exists():
        return None
    if not ts_file.exists():
        return None
    if _is_expired(float(ts_file.read_text())):
        return None
    with open(meta_file) as fh:
        return json.load(fh)
215
+
216
+
217
def estimate_duration(endpoint, audio_duration_s=None, audio_id=None,
                      model_name="Base", device="GPU"):
    """Estimate processing duration for a given endpoint.

    Parameters:
        endpoint: Target endpoint name; must be in _ESTIMABLE_ENDPOINTS.
        audio_duration_s: Audio length in seconds. Required for pipeline
            endpoints when no session metadata provides the duration.
        audio_id: Existing session id. Used to infer the audio duration
            and, for MFA endpoints, the stored segment count.
        model_name: "Base" or "Large".
        device: "GPU" or "CPU"; CPU estimates are scaled by
            ESTIMATE_CPU_MULTIPLIER.

    Returns a dict with "estimated_duration_s" rounded up to a multiple of
    5 seconds, or a dict with an "error" key on invalid input.
    """
    from config import (
        get_vad_duration, get_asr_duration,
        ESTIMATE_ALIGNMENT_OVERHEAD_S, ESTIMATE_CPU_MULTIPLIER,
        MFA_PROGRESS_SEGMENT_RATE,
    )

    _error = {"estimated_duration_s": None}

    if endpoint not in _ESTIMABLE_ENDPOINTS:
        _error["error"] = (
            f"Unknown endpoint '{endpoint}'. "
            f"Valid: {', '.join(sorted(_ESTIMABLE_ENDPOINTS))}"
        )
        return _error

    if endpoint in _MFA_ENDPOINTS:
        # --- MFA endpoints: estimate depends only on stored segment count.
        # Fix: audio duration is NOT required here (it was previously
        # resolved — and could error — before this branch even though it
        # was never used by it).
        if not audio_id:
            _error["error"] = "MFA estimation requires audio_id with existing segments"
            return _error
        segments = _load_segments(audio_id)
        if not segments:
            _error["error"] = "No segments found in session — run an alignment endpoint first"
            return _error
        estimate = MFA_PROGRESS_SEGMENT_RATE * len(segments)
    else:
        # --- Pipeline endpoints: VAD + ASR + alignment overhead ---
        # Resolve audio duration: explicit argument wins, otherwise fall
        # back to the session metadata.
        meta = _load_session_metadata(audio_id) if audio_id else None
        if audio_duration_s is not None and audio_duration_s > 0:
            duration_s = float(audio_duration_s)
        elif meta and meta.get("audio_duration_s"):
            duration_s = meta["audio_duration_s"]
        else:
            _error["error"] = (
                "audio_duration_s is required (or provide audio_id with an existing session)"
            )
            return _error

        minutes = duration_s / 60.0
        estimate = 0.0
        if endpoint in _VAD_ENDPOINTS:
            estimate += get_vad_duration(minutes)
        estimate += get_asr_duration(minutes, model_name)
        estimate += ESTIMATE_ALIGNMENT_OVERHEAD_S

    # CPU runs are dramatically slower than GPU leases.
    if device == "CPU":
        estimate *= ESTIMATE_CPU_MULTIPLIER

    # Round up to the nearest 5 seconds so clients get a stable, coarse figure.
    rounded = math.ceil(estimate / 5) * 5

    return {
        "endpoint": endpoint,
        "estimated_duration_s": rounded,
        "device": device,
        "model_name": model_name,
    }
283
+
284
+
285
  def _format_response(audio_id, json_output, warning=None):
286
  """Convert pipeline json_output to the documented API response schema."""
287
  segments = []
src/ui/event_wiring.py CHANGED
@@ -8,6 +8,7 @@ from src.pipeline import (
8
  _retranscribe_wrapper, save_json_export,
9
  )
10
  from src.api.session_api import (
 
11
  process_audio_session, resegment_session,
12
  retranscribe_session, realign_from_timestamps,
13
  mfa_timestamps_session, mfa_timestamps_direct,
@@ -461,6 +462,13 @@ def _wire_settings_restoration(app, c):
461
 
462
  def _wire_api_endpoint(c):
463
  """Hidden API-only endpoints for session-based programmatic access."""
 
 
 
 
 
 
 
464
  gr.Button(visible=False).click(
465
  fn=process_audio_session,
466
  inputs=[c.api_audio, c.api_silence, c.api_speech, c.api_pad,
 
8
  _retranscribe_wrapper, save_json_export,
9
  )
10
  from src.api.session_api import (
11
+ estimate_duration,
12
  process_audio_session, resegment_session,
13
  retranscribe_session, realign_from_timestamps,
14
  mfa_timestamps_session, mfa_timestamps_direct,
 
462
 
463
  def _wire_api_endpoint(c):
464
  """Hidden API-only endpoints for session-based programmatic access."""
465
+ gr.Button(visible=False).click(
466
+ fn=estimate_duration,
467
+ inputs=[c.api_estimate_endpoint, c.api_estimate_audio_duration,
468
+ c.api_audio_id, c.api_model, c.api_device],
469
+ outputs=[c.api_result],
470
+ api_name="estimate_duration",
471
+ )
472
  gr.Button(visible=False).click(
473
  fn=process_audio_session,
474
  inputs=[c.api_audio, c.api_silence, c.api_speech, c.api_pad,
src/ui/interface.py CHANGED
@@ -89,6 +89,8 @@ def build_interface():
89
  c.api_timestamps = gr.JSON(visible=False)
90
  c.api_mfa_segments = gr.JSON(visible=False)
91
  c.api_mfa_granularity = gr.Textbox(visible=False)
 
 
92
  c.api_result = gr.JSON(visible=False)
93
 
94
  wire_events(app, c)
 
89
  c.api_timestamps = gr.JSON(visible=False)
90
  c.api_mfa_segments = gr.JSON(visible=False)
91
  c.api_mfa_granularity = gr.Textbox(visible=False)
92
+ c.api_estimate_endpoint = gr.Textbox(visible=False)
93
+ c.api_estimate_audio_duration = gr.Number(visible=False)
94
  c.api_result = gr.JSON(visible=False)
95
 
96
  wire_events(app, c)