bichnhan2701 committed on
Commit
bbf158e
·
1 Parent(s): 75ad4a0

Debug timeout

Browse files
app/api/transcribe.py CHANGED
@@ -29,6 +29,7 @@ from app.core.asr_engine import (
29
  load_model,
30
  transcribe_file,
31
  transcribe_file_chunks,
 
32
  )
33
 
34
  router = APIRouter()
@@ -58,14 +59,26 @@ def _ensure_file_limits(path: str):
58
  raise HTTPException(413, "Audio duration exceeds limit")
59
 
60
 
61
- def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
 
 
 
 
 
 
 
 
 
 
62
  q = Queue("asr", connection=redis_client)
 
 
63
  return q.enqueue(
64
  transcribe_job,
65
  audio_url,
66
  note_id,
67
  user_id,
68
- job_timeout=1800,
69
  retry=Retry(max=3, interval=[2, 5, 10]),
70
  )
71
 
@@ -74,16 +87,15 @@ def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None)
74
  async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None = None):
75
  """
76
  Sync ASR → update existing note
 
77
  """
78
  note_service = NoteServiceClient()
79
  info = get_audio_info(tmp_wav) or {}
80
 
81
  with ASR_DURATION.labels("/transcribe").time():
82
- text = await asyncio.to_thread(
83
- transcribe_file, ASR_MODEL, tmp_wav, 30.0, 5.0
84
- )
85
- chunks = await asyncio.to_thread(
86
- transcribe_file_chunks, ASR_MODEL, tmp_wav, 30.0, 5.0
87
  )
88
 
89
  chunks = [
@@ -190,7 +202,7 @@ async def transcribe(file: UploadFile = File(...)):
190
  audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
191
 
192
  await _create_placeholder_note(note_id, duration, audio_url)
193
- job = _enqueue_async_job(audio_url, note_id)
194
 
195
  REQUEST_COUNT.labels(endpoint, "queued").inc()
196
  return JSONResponse(
@@ -252,7 +264,7 @@ async def transcribe_url(payload: dict):
252
  # ---------- ASYNC ----------
253
  if duration > ASYNC_THRESHOLD:
254
  await _create_placeholder_note(note_id, duration, audio_url)
255
- job = _enqueue_async_job(audio_url, note_id, user_id)
256
 
257
  REQUEST_COUNT.labels(endpoint, "queued").inc()
258
  return JSONResponse(
 
29
  load_model,
30
  transcribe_file,
31
  transcribe_file_chunks,
32
+ transcribe_file_unified,
33
  )
34
 
35
  router = APIRouter()
 
59
  raise HTTPException(413, "Audio duration exceeds limit")
60
 
61
 
62
+ def _calculate_job_timeout(duration: float) -> int:
63
+ """
64
+ Calculate dynamic job timeout based on audio duration.
65
+ Formula: max(1800, duration * 3 + 300)
66
+ - Minimum 30 minutes
67
+ - ~3x realtime + 5 min buffer for long audio
68
+ """
69
+ return max(1800, int(duration * 3) + 300)
70
+
71
+
72
+ def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None, duration: float = 0, duration: float = 0):
73
  q = Queue("asr", connection=redis_client)
74
+ job_timeout = _calculate_job_timeout(duration)
75
+ logger.info("[ASR] Enqueuing job for note=%s, duration=%.2fs, timeout=%ds", note_id, duration, job_timeout)
76
  return q.enqueue(
77
  transcribe_job,
78
  audio_url,
79
  note_id,
80
  user_id,
81
+ job_timeout=job_timeout,
82
  retry=Retry(max=3, interval=[2, 5, 10]),
83
  )
84
 
 
87
  async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None = None):
88
  """
89
  Sync ASR → update existing note
90
+ 🔥 FIX: Use unified function to avoid double inference
91
  """
92
  note_service = NoteServiceClient()
93
  info = get_audio_info(tmp_wav) or {}
94
 
95
  with ASR_DURATION.labels("/transcribe").time():
96
+ # 🔥 SINGLE INFERENCE - returns both text and chunks
97
+ text, chunks = await asyncio.to_thread(
98
+ transcribe_file_unified, ASR_MODEL, tmp_wav, 30.0, 5.0
 
 
99
  )
100
 
101
  chunks = [
 
202
  audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
203
 
204
  await _create_placeholder_note(note_id, duration, audio_url)
205
+ job = _enqueue_async_job(audio_url, note_id, duration=duration)
206
 
207
  REQUEST_COUNT.labels(endpoint, "queued").inc()
208
  return JSONResponse(
 
264
  # ---------- ASYNC ----------
265
  if duration > ASYNC_THRESHOLD:
266
  await _create_placeholder_note(note_id, duration, audio_url)
267
+ job = _enqueue_async_job(audio_url, note_id, user_id, duration=duration)
268
 
269
  REQUEST_COUNT.labels(endpoint, "queued").inc()
270
  return JSONResponse(
app/core/asr_engine.py CHANGED
@@ -1,12 +1,12 @@
1
  import logging
2
- from typing import List, Dict
 
3
 
4
  import torch
5
  from transformers import pipeline
6
  from transformers import logging as transformers_logging
7
  import warnings
8
  import os
9
- from typing import Tuple
10
 
11
  from app.core.chunking import split_audio_to_chunks
12
  from app.core.audio_utils import get_audio_info
@@ -60,50 +60,105 @@ def load_model(chunk_length_s: float = 30.0):
60
  # ===============================
61
  # Transcribe full text
62
  # ===============================
63
- def transcribe_file(
64
  model,
65
  wav_path: str,
66
  chunk_length_s: float = 30.0,
67
  stride_s: float = 5.0,
68
- ) -> str:
69
  """
70
- Return full transcript text.
 
 
 
 
71
  """
72
  if not wav_path:
73
- return ""
 
 
 
74
 
75
  # If audio is long, prefer chunked inference to avoid memory/time issues
76
  info = get_audio_info(wav_path) or {}
77
  duration = info.get("duration", 0)
 
 
78
  if duration and duration > chunk_length_s:
79
  try:
80
- text, _chunks = transcribe_long_audio(
81
  model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
82
  )
83
- return text
 
 
84
  except Exception:
85
  logger.exception("transcribe_long_audio failed, falling back to pipeline")
86
 
 
87
  out = model(
88
  wav_path,
89
  chunk_length_s=chunk_length_s,
90
  stride_length_s=stride_s,
91
- # return_timestamps may be ignored for full-text outputs but safe to pass
92
  )
93
 
94
- # Primary: pipeline may return 'text'
95
  text = (out.get("text") or "").strip()
96
- if text:
97
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Fallback: some pipeline versions return detailed segments/chunks
100
- segs = out.get("chunks") or out.get("segments") or []
101
- if segs:
102
- parts = [ (s.get("text") or "").strip() for s in segs ]
103
- joined = " ".join([p for p in parts if p])
104
- return joined.strip()
105
 
106
- return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
  def transcribe_long_audio(
@@ -119,14 +174,17 @@ def transcribe_long_audio(
119
  if not wav_path:
120
  return "", []
121
 
 
 
122
  # prefer VAD-based splitting if available
123
  try:
124
  from app.core.chunking import split_audio_with_vad
125
  chunk_paths = split_audio_with_vad(wav_path)
126
- except Exception:
 
 
127
  chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
128
-
129
- logger.debug("transcribe_long_audio: split into %d chunk_paths", len(chunk_paths))
130
  combined_text_parts = []
131
  combined_chunks: List[Dict] = []
132
 
@@ -230,58 +288,7 @@ def transcribe_file_chunks(
230
  """
231
  Return list of chunks:
232
  [{ start, end, text }]
 
233
  """
234
- if not wav_path:
235
- return []
236
- # For long audio prefer explicit chunked inference (split + per-chunk inference)
237
- info = get_audio_info(wav_path) or {}
238
- duration = info.get("duration", 0)
239
- if duration and duration > chunk_length_s:
240
- try:
241
- _, combined = transcribe_long_audio(
242
- model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
243
- )
244
- return combined
245
- except Exception:
246
- logger.exception("transcribe_long_audio failed in transcribe_file_chunks, falling back to pipeline")
247
-
248
- out = model(
249
- wav_path,
250
- chunk_length_s=chunk_length_s,
251
- stride_length_s=stride_s,
252
- return_timestamps=True,
253
- )
254
-
255
- # Pipeline output can vary across transformers versions/models:
256
- # - some return `chunks` (with `timestamp` list),
257
- # - others return `segments` (with `start`/end),
258
- # so be permissive and handle both shapes.
259
- raw_segments = out.get("chunks") or out.get("segments") or []
260
-
261
- chunks = []
262
- for c in raw_segments:
263
- # try multiple timestamp shapes
264
- start = None
265
- end = None
266
-
267
- if isinstance(c.get("timestamp"), (list, tuple)) and len(c.get("timestamp")) >= 2:
268
- ts = c.get("timestamp")
269
- start, end = ts[0], ts[1]
270
- elif c.get("start") is not None and c.get("end") is not None:
271
- start, end = c.get("start"), c.get("end")
272
-
273
- text = (c.get("text") or "").strip()
274
- if not text:
275
- continue
276
-
277
- # If timestamps are missing, skip (we don't want chunks without timing)
278
- if start is None or end is None:
279
- continue
280
-
281
- try:
282
- chunks.append({"start": float(start), "end": float(end), "text": text})
283
- except Exception:
284
- # be robust against unexpected types
285
- continue
286
-
287
  return chunks
 
1
  import logging
2
+ import time
3
+ from typing import List, Dict, Tuple
4
 
5
  import torch
6
  from transformers import pipeline
7
  from transformers import logging as transformers_logging
8
  import warnings
9
  import os
 
10
 
11
  from app.core.chunking import split_audio_to_chunks
12
  from app.core.audio_utils import get_audio_info
 
60
  # ===============================
61
  # Transcribe full text
62
  # ===============================
63
def transcribe_file_unified(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    Run a single inference pass and return both the full transcript and the
    timestamped chunks, avoiding the double inference previously caused by
    calling transcribe_file() and transcribe_file_chunks() separately.

    Returns:
        (text, chunks) where chunks = [{"start": float, "end": float, "text": str}, ...]
    """
    if not wav_path:
        return "", []

    t0 = time.time()
    logger.info("[ASR] Starting unified transcription for %s", wav_path)

    # Long audio is routed through explicit chunked inference to keep
    # memory use and per-call latency bounded.
    audio_meta = get_audio_info(wav_path) or {}
    duration = audio_meta.get("duration", 0)
    logger.info("[ASR] Audio duration: %.2fs", duration)

    if duration and duration > chunk_length_s:
        try:
            full_text, timed_chunks = transcribe_long_audio(
                model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
            )
            took = time.time() - t0
            logger.info("[ASR] Long audio transcription completed in %.2fs (%.2fx realtime)", took, took / duration if duration else 0)
            return full_text, timed_chunks
        except Exception:
            logger.exception("transcribe_long_audio failed, falling back to pipeline")

    # Short audio (or fallback path): one pipeline call with timestamps on.
    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
        return_timestamps=True,
    )

    full_text = (out.get("text") or "").strip()
    if not full_text:
        # Some pipeline versions return only segments; stitch them together.
        segments = out.get("chunks") or out.get("segments") or []
        if segments:
            pieces = [(seg.get("text") or "").strip() for seg in segments]
            full_text = " ".join([piece for piece in pieces if piece]).strip()

    timed_chunks = _extract_chunks_from_output(out)

    took = time.time() - t0
    logger.info("[ASR] Short audio transcription completed in %.2fs", took)

    return full_text, timed_chunks
121
+
122
+
123
+ def _extract_chunks_from_output(out: dict) -> List[Dict]:
124
+ """Extract timestamped chunks from model output."""
125
+ raw_segments = out.get("chunks") or out.get("segments") or []
126
+ chunks = []
127
+
128
+ for c in raw_segments:
129
+ start = None
130
+ end = None
131
+
132
+ if isinstance(c.get("timestamp"), (list, tuple)) and len(c.get("timestamp")) >= 2:
133
+ ts = c.get("timestamp")
134
+ start, end = ts[0], ts[1]
135
+ elif c.get("start") is not None and c.get("end") is not None:
136
+ start, end = c.get("start"), c.get("end")
137
 
138
+ text = (c.get("text") or "").strip()
139
+ if not text or start is None or end is None:
140
+ continue
 
 
 
141
 
142
+ try:
143
+ chunks.append({"start": float(start), "end": float(end), "text": text})
144
+ except Exception:
145
+ continue
146
+
147
+ return chunks
148
+
149
+
150
def transcribe_file(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> str:
    """
    Return only the full transcript text.

    ⚠️ DEPRECATED: Use transcribe_file_unified() to get both text and chunks in one pass.
    """
    full_text, _unused_chunks = transcribe_file_unified(model, wav_path, chunk_length_s, stride_s)
    return full_text
162
 
163
 
164
  def transcribe_long_audio(
 
174
  if not wav_path:
175
  return "", []
176
 
177
+ split_start = time.time()
178
+
179
  # prefer VAD-based splitting if available
180
  try:
181
  from app.core.chunking import split_audio_with_vad
182
  chunk_paths = split_audio_with_vad(wav_path)
183
+ logger.info("[ASR] VAD split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)
184
+ except Exception as e:
185
+ logger.warning("[ASR] VAD split failed (%s), using fixed windows", e)
186
  chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
187
+ logger.info("[ASR] Fixed-window split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)
 
188
  combined_text_parts = []
189
  combined_chunks: List[Dict] = []
190
 
 
288
  """
289
  Return list of chunks:
290
  [{ start, end, text }]
291
+ ⚠️ DEPRECATED: Use transcribe_file_unified() to get both text and chunks in one pass.
292
  """
293
+ _, chunks = transcribe_file_unified(model, wav_path, chunk_length_s, stride_s)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  return chunks
app/jobs/transcribe_job.py CHANGED
@@ -6,11 +6,13 @@ import requests
6
  import httpx
7
  import time
8
 
9
- from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
10
  from app.services.note_client import NoteServiceClient
11
  from app.core.audio_utils import get_audio_info
12
  from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
13
 
 
 
14
  def run_async(coro):
15
  try:
16
  loop = asyncio.get_running_loop()
@@ -35,19 +37,28 @@ def download_audio(audio_url: str) -> str:
35
 
36
 
37
  def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
 
 
 
38
  model = load_model()
39
  wav_path = None
40
 
41
  try:
42
  # 1️⃣ Download audio
 
43
  wav_path = download_audio(audio_url)
 
44
 
45
  # Ensure WAV is 16k mono for consistent chunking and ASR behavior
46
  try:
47
  info = get_audio_info(wav_path) or {}
 
 
48
  if info.get("samplerate") != 16000 or info.get("channels") != 1:
 
49
  tmp_wav = make_temp_path(suffix=".wav")
50
  ensure_wav_16k_mono(wav_path, tmp_wav)
 
51
  # replace wav_path with converted file and remove original
52
  try:
53
  os.remove(wav_path)
@@ -55,11 +66,12 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
55
  pass
56
  wav_path = tmp_wav
57
  except Exception:
58
- logging.exception("Failed to ensure wav format for %s", wav_path)
59
 
60
- # 2️⃣ ASR
61
- text = transcribe_file(model, wav_path, 30.0, 5.0)
62
- chunks = transcribe_file_chunks(model, wav_path, 30.0, 5.0)
 
63
 
64
  # normalize chunks list
65
  chunks = [
@@ -75,7 +87,7 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
75
  duration = info.get("duration") or 0.0
76
  chunks = [{"text": text.strip(), "start": 0.0, "end": float(duration)}]
77
  except Exception:
78
- logging.exception("failed to create fallback chunk for note %s", note_id)
79
 
80
  # Consider transcribed if we have either timestamped chunks or non-empty text
81
  note_status = "transcribed" if (chunks or (text and text.strip())) else "error"
@@ -105,7 +117,7 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
105
  try:
106
  payload["metadata"]["audio"]["url"] = audio_url
107
  except Exception:
108
- logging.exception("Failed to attach audio_url to payload for note %s", note_id)
109
 
110
  generate_tasks = (
111
  ["normalize", "keywords", "summary", "mindmap"]
@@ -128,9 +140,10 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
128
  },
129
  )
130
  )
 
131
  except httpx.HTTPStatusError as e:
132
  if e.response.status_code == 404:
133
- logging.warning(
134
  "Note not found on update, will retry later note_id=%s",
135
  note_id,
136
  )
 
6
  import httpx
7
  import time
8
 
9
+ from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks, transcribe_file_unified
10
  from app.services.note_client import NoteServiceClient
11
  from app.core.audio_utils import get_audio_info
12
  from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
  def run_async(coro):
17
  try:
18
  loop = asyncio.get_running_loop()
 
37
 
38
 
39
  def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
40
+ job_start = time.time()
41
+ logger.info("[JOB] Starting transcribe_job for note=%s, url=%s", note_id, audio_url)
42
+
43
  model = load_model()
44
  wav_path = None
45
 
46
  try:
47
  # 1️⃣ Download audio
48
+ download_start = time.time()
49
  wav_path = download_audio(audio_url)
50
+ logger.info("[JOB] Downloaded audio in %.2fs", time.time() - download_start)
51
 
52
  # Ensure WAV is 16k mono for consistent chunking and ASR behavior
53
  try:
54
  info = get_audio_info(wav_path) or {}
55
+ logger.info("[JOB] Audio info: duration=%.2fs, samplerate=%s, channels=%s",
56
+ info.get("duration", 0), info.get("samplerate"), info.get("channels"))
57
  if info.get("samplerate") != 16000 or info.get("channels") != 1:
58
+ convert_start = time.time()
59
  tmp_wav = make_temp_path(suffix=".wav")
60
  ensure_wav_16k_mono(wav_path, tmp_wav)
61
+ logger.info("[JOB] Converted to 16k mono in %.2fs", time.time() - convert_start)
62
  # replace wav_path with converted file and remove original
63
  try:
64
  os.remove(wav_path)
 
66
  pass
67
  wav_path = tmp_wav
68
  except Exception:
69
+ logger.exception("Failed to ensure wav format for %s", wav_path)
70
 
71
+ # 2️⃣ ASR - 🔥 SINGLE INFERENCE using unified function
72
+ asr_start = time.time()
73
+ text, chunks = transcribe_file_unified(model, wav_path, 30.0, 5.0)
74
+ logger.info("[JOB] ASR completed in %.2fs", time.time() - asr_start)
75
 
76
  # normalize chunks list
77
  chunks = [
 
87
  duration = info.get("duration") or 0.0
88
  chunks = [{"text": text.strip(), "start": 0.0, "end": float(duration)}]
89
  except Exception:
90
+ logger.exception("failed to create fallback chunk for note %s", note_id)
91
 
92
  # Consider transcribed if we have either timestamped chunks or non-empty text
93
  note_status = "transcribed" if (chunks or (text and text.strip())) else "error"
 
117
  try:
118
  payload["metadata"]["audio"]["url"] = audio_url
119
  except Exception:
120
+ logger.exception("Failed to attach audio_url to payload for note %s", note_id)
121
 
122
  generate_tasks = (
123
  ["normalize", "keywords", "summary", "mindmap"]
 
140
  },
141
  )
142
  )
143
+ logger.info("[JOB] Completed note=%s in %.2fs, status=%s", note_id, time.time() - job_start, note_status)
144
  except httpx.HTTPStatusError as e:
145
  if e.response.status_code == 404:
146
+ logger.warning(
147
  "Note not found on update, will retry later note_id=%s",
148
  note_id,
149
  )