bichnhan2701 commited on
Commit
c816b75
·
1 Parent(s): 7be097f

Update api transcribe

Browse files
Files changed (1) hide show
  1. app/api/transcribe.py +157 -217
app/api/transcribe.py CHANGED
@@ -1,289 +1,229 @@
1
  import os
2
- import logging
3
  import uuid
 
4
  import asyncio
5
- from fastapi import APIRouter, UploadFile, File, HTTPException, status
6
- from fastapi.responses import JSONResponse
7
  from pathlib import Path
8
- from typing import Optional
9
- import time
10
- from app.core.audio_utils import (
11
- save_upload_file,
12
- get_audio_info,
13
- ensure_wav_16k_mono,
14
- make_temp_path,
15
- download_file_from_url
16
- )
17
- from app.core.asr_engine import (
18
- load_model,
19
- transcribe_file,
20
- transcribe_file_chunks
21
- )
22
- from app.config import settings
23
- from app.services.note_client import NoteServiceClient
24
  from rq import Queue
 
 
25
  from app.infra.redis_client import redis_client
26
- from app.jobs.transcribe_job import transcribe_job
27
  from app.schemas.transcribe import TranscribeResponse
28
- from app.infra.metrics import (
29
- REQUEST_COUNT,
30
- REQUEST_LATENCY,
31
- ASR_DURATION,
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
 
34
  router = APIRouter()
35
  ASR_MODEL = None
 
36
 
 
 
 
 
37
  @router.on_event("startup")
38
- async def _startup():
39
  global ASR_MODEL
40
- # load model in thread to avoid blocking event loop
41
  ASR_MODEL = await asyncio.to_thread(load_model, 30)
42
 
 
 
 
 
43
  def _ensure_file_limits(path: str):
44
  if os.path.getsize(path) > settings.MAX_UPLOAD_BYTES:
45
- raise HTTPException(
46
- status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
47
- detail="File size exceeds limit",
48
- )
49
  info = get_audio_info(path)
50
  if info and info.get("duration", 0) > settings.MAX_DURATION_SECS:
51
- raise HTTPException(
52
- status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
53
- detail="Audio duration exceeds limit",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  )
55
-
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  @router.post("/transcribe", response_model=TranscribeResponse)
58
  async def transcribe(file: UploadFile = File(...)):
59
- tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".wav")
60
- tmp_wav = None
61
- note_service = NoteServiceClient()
62
  note_id = str(uuid.uuid4())
63
 
64
- start_time = time.perf_counter()
65
- endpoint = "/transcribe"
66
- status_label = "success"
67
 
68
  with REQUEST_LATENCY.labels(endpoint).time():
69
  try:
70
- # write upload to tmp (blocking) -> run in thread
71
  await asyncio.to_thread(save_upload_file, file, tmp_in)
72
-
73
  _ensure_file_limits(tmp_in)
74
 
 
75
  tmp_wav = make_temp_path(suffix=".wav")
76
- # ffmpeg convert is blocking -> run in thread
77
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
78
 
79
- # Kiểm tra duration để quyết định xử lý sync hay async
80
  info = get_audio_info(tmp_wav) or {}
81
- duration_sec = info.get("duration", 0)
82
- ASYNC_THRESHOLD = 120 # 2 phút, có thể chỉnh
83
- # ---------- ASYNC JOB ----------
84
- if duration_sec > ASYNC_THRESHOLD:
85
- # Enqueue background job bằng RQ
86
- q = Queue("asr", connection=redis_client)
87
- job = q.enqueue(
88
- transcribe_job,
89
- tmp_wav,
90
- note_id,
91
- job_timeout=1800
92
- )
93
- logging.info(f"Enqueued background transcribe job: note_id={note_id} job_id={job.id} duration={duration_sec:.1f}s")
94
  REQUEST_COUNT.labels(endpoint, "queued").inc()
95
- return JSONResponse(status_code=202, content={
96
- "note_id": note_id,
97
- "job_id": job.id,
98
- "status": "queued",
99
- "duration": duration_sec
100
- })
101
- # ---------- SYNC PIPELINE ----------
102
- # Nếu audio ngắn, xử lý sync như cũ
103
- model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
104
- with ASR_DURATION.labels(endpoint).time():
105
- text = await asyncio.to_thread(transcribe_file, model, tmp_wav, 30.0, 5.0)
106
- chunks = await asyncio.to_thread(transcribe_file_chunks, model, tmp_wav, 30.0, 5.0)
107
- # 🔥 DROP invalid chunks
108
- chunks = [
109
- c for c in chunks
110
- if c.get("text", "").strip() and c.get("end", 0) > c.get("start", 0)
111
- ]
112
- note_status = "transcribed" if chunks and any(c.get("text", "").strip() for c in chunks) else "error"
113
-
114
- info2 = get_audio_info(tmp_wav) or {}
115
- # persist to Note Service (async HTTP)
116
- payload = {
117
- "note_id": note_id,
118
- "type": "audio",
119
- "status": note_status,
120
- "raw_text": text,
121
- "metadata": {
122
- "audio": {
123
- "duration": info2.get("duration"),
124
- "sample_rate": info2.get("samplerate"),
125
- "chunks": chunks,
126
- "asr_model": "PhoWhisper-base"
127
- }
128
- },
129
- "generate": ["normalize", "keywords", "summary", "mindmap"]
130
- }
131
- logging.info(
132
- "Create audio note note_id=%s status=%s chunks=%d text_len=%d",
133
- note_id,
134
- note_status,
135
- len(chunks) if chunks else 0,
136
- len(text or ""),
137
- )
138
- await note_service.create_audio_note(payload)
139
-
140
- duration = time.perf_counter() - start_time
141
- logging.info(f"/transcribe success note_id={note_id} duration={duration:.2f}s audio_dur={info2.get('duration')}")
142
- REQUEST_COUNT.labels(endpoint, status_label).inc()
143
- return JSONResponse(
144
- status_code=200,
145
- content={
146
- "note_id": note_id,
147
- "status": note_status,
148
- "duration": info2.get("duration"),
149
- },
150
- )
151
-
152
  finally:
153
- # cleanup
154
- for p in [tmp_in, tmp_wav]:
155
- try:
156
- if p and os.path.exists(p):
157
- os.remove(p)
158
- except Exception:
159
- pass
160
 
 
 
 
 
161
  @router.post("/transcribe-url", response_model=TranscribeResponse)
162
  async def transcribe_url(payload: dict):
 
 
 
163
  audio_url = payload.get("audio_url")
164
  user_id = payload.get("user_id")
165
 
166
  if not audio_url:
167
- raise HTTPException(status_code=400, detail="audio_url required")
168
- if not user_id:
169
- raise HTTPException(status_code=400, detail="user_id required")
170
 
171
- tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
172
- tmp_wav = None
173
  note_id = str(uuid.uuid4())
174
- note_service = NoteServiceClient()
175
 
176
- endpoint = "/transcribe-url"
177
- start_time = time.perf_counter()
178
- status_label = "success"
179
 
180
  with REQUEST_LATENCY.labels(endpoint).time():
181
  try:
182
- # 1. Download from Cloudinary (blocking)
183
  await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
184
-
185
- # 2. File & duration limits
186
  _ensure_file_limits(tmp_in)
187
 
188
- # 3. Convert to wav 16k mono
189
  tmp_wav = make_temp_path(suffix=".wav")
190
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
191
 
192
- # 4. Check duration for sync / async
193
  info = get_audio_info(tmp_wav) or {}
194
- duration_sec = info.get("duration", 0)
195
- ASYNC_THRESHOLD = 120 # seconds
196
-
197
- # ---------- ASYNC JOB ----------
198
- if duration_sec > ASYNC_THRESHOLD:
199
- q = Queue("asr", connection=redis_client)
200
- job = q.enqueue(
201
- transcribe_job,
202
- tmp_wav,
203
- note_id,
204
- job_timeout=1800,
205
- )
206
 
207
- logging.info(
208
- f"/transcribe-url queued note_id={note_id} "
209
- f"job_id={job.id} duration={duration_sec:.1f}s"
210
- )
211
- REQUEST_COUNT.labels(endpoint, "queued").inc()
212
 
 
213
  return JSONResponse(
214
  status_code=202,
215
  content={
216
  "note_id": note_id,
217
  "job_id": job.id,
218
  "status": "queued",
219
- "duration": duration_sec,
220
  },
221
  )
222
 
223
- # ---------- SYNC PIPELINE ----------
224
- model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
225
-
226
- with ASR_DURATION.labels(endpoint).time():
227
- text = await asyncio.to_thread(
228
- transcribe_file, model, tmp_wav, 30.0, 5.0
229
- )
230
- chunks = await asyncio.to_thread(
231
- transcribe_file_chunks, model, tmp_wav, 30.0, 5.0
232
- )
233
- # 🔥 DROP invalid chunks
234
- chunks = [
235
- c for c in chunks
236
- if c.get("text", "").strip() and c.get("end", 0) > c.get("start", 0)
237
- ]
238
-
239
- note_status = "transcribed" if chunks and any(c.get("text", "").strip() for c in chunks) else "error"
240
-
241
- # 5. Persist to Note Service
242
- payload = {
243
- "note_id": note_id,
244
- "type": "audio",
245
- "status": note_status,
246
- "raw_text": text,
247
- "metadata": {
248
- "audio": {
249
- "duration": info.get("duration"),
250
- "sample_rate": info.get("samplerate"),
251
- "chunks": chunks,
252
- "asr_model": "PhoWhisper-base"
253
- }
254
- },
255
- "generate": ["normalize", "keywords", "summary", "mindmap"]
256
- }
257
 
258
- logging.info(
259
- "Create audio note note_id=%s status=%s chunks=%d text_len=%d",
260
- note_id,
261
- note_status,
262
- len(chunks) if chunks else 0,
263
- len(text or ""),
264
- )
265
- await note_service.create_audio_note(payload)
266
-
267
- duration = time.perf_counter() - start_time
268
- logging.info(
269
- f"/transcribe-url success note_id={note_id} "
270
- f"duration={duration:.2f}s audio_dur={info.get('duration')}"
271
- )
272
-
273
- REQUEST_COUNT.labels(endpoint, status_label).inc()
274
- return JSONResponse(
275
- status_code=200,
276
- content={
277
- "note_id": note_id,
278
- "status": note_status,
279
- "duration": info.get("duration"),
280
- },
281
- )
282
 
283
  finally:
284
- for p in [tmp_in, tmp_wav]:
285
- try:
286
- if p and os.path.exists(p):
287
- os.remove(p)
288
- except Exception:
289
- pass
 
1
  import os
 
2
  import uuid
3
+ import time
4
  import asyncio
5
+ import logging
 
6
  from pathlib import Path
7
+
8
+ from fastapi import APIRouter, UploadFile, File, HTTPException
9
+ from fastapi.responses import JSONResponse
10
+
 
 
 
 
 
 
 
 
 
 
 
 
11
  from rq import Queue
12
+
13
+ from app.config import settings
14
  from app.infra.redis_client import redis_client
15
+ from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
16
  from app.schemas.transcribe import TranscribeResponse
17
+ from app.services.note_client import NoteServiceClient
18
+ from app.jobs.transcribe_job import transcribe_job
19
+
20
+ from app.core.audio_utils import (
21
+ save_upload_file,
22
+ download_file_from_url,
23
+ ensure_wav_16k_mono,
24
+ make_temp_path,
25
+ get_audio_info,
26
+ upload_temp_audio,
27
+ )
28
+ from app.core.asr_engine import (
29
+ load_model,
30
+ transcribe_file,
31
+ transcribe_file_chunks,
32
  )
33
 
34
router = APIRouter()

# Loaded once by the startup hook; stays None until that hook has run.
ASR_MODEL = None

# Audio longer than this is handed to a background RQ worker instead of
# being transcribed synchronously inside the request handler.
ASYNC_THRESHOLD = 120  # seconds
37
 
38
+
39
+ # ============================================================
40
+ # Startup: load ASR model once
41
+ # ============================================================
42
# NOTE(review): @router.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers — confirm the pinned FastAPI version before
# migrating.
@router.on_event("startup")
async def startup():
    """Load the ASR model once when the application starts."""
    global ASR_MODEL
    # load_model does blocking disk I/O, so run it in a worker thread to keep
    # the event loop responsive. The argument 30 is presumably the chunk
    # length in seconds — TODO confirm against load_model's signature.
    ASR_MODEL = await asyncio.to_thread(load_model, 30)
46
 
47
+
48
+ # ============================================================
49
+ # Utils
50
+ # ============================================================
51
def _ensure_file_limits(path: str) -> None:
    """Validate an audio file against the configured upload limits.

    Args:
        path: Filesystem path of the file to check.

    Raises:
        HTTPException: 413 when the file's byte size exceeds
            settings.MAX_UPLOAD_BYTES, or when its reported audio duration
            exceeds settings.MAX_DURATION_SECS.
    """
    size_bytes = os.path.getsize(path)
    if size_bytes > settings.MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File size exceeds limit")

    # Duration is only enforced when probing succeeds; a missing/unreadable
    # duration does not reject the file.
    probe = get_audio_info(path)
    duration = probe.get("duration", 0) if probe else 0
    if duration > settings.MAX_DURATION_SECS:
        raise HTTPException(413, "Audio duration exceeds limit")
58
+
59
+
60
async def _run_sync_pipeline(
    tmp_wav: str,
    note_id: str,
    endpoint: str = "/transcribe",
):
    """
    Transcribe *tmp_wav* synchronously and persist the note to the Note Service.

    Args:
        tmp_wav: Path to a 16 kHz mono WAV file.
        note_id: Identifier of the note being created.
        endpoint: Label for the ASR_DURATION metric. Defaults to "/transcribe"
            for backward compatibility; other routes should pass their own
            path so per-route ASR timings are not conflated.

    Returns:
        dict with "note_id", "status" ("transcribed" or "error") and the
        audio "duration" in seconds (None if probing failed).
    """
    note_service = NoteServiceClient()
    info = get_audio_info(tmp_wav) or {}

    # Fall back to loading on demand if the startup hook has not populated
    # the module-level model yet (e.g. router mounted without startup events).
    model = ASR_MODEL or await asyncio.to_thread(load_model, 30)

    with ASR_DURATION.labels(endpoint).time():
        text = await asyncio.to_thread(
            transcribe_file, model, tmp_wav, 30.0, 5.0
        )
        chunks = await asyncio.to_thread(
            transcribe_file_chunks, model, tmp_wav, 30.0, 5.0
        )

    # Drop chunks with no text or a non-positive time span (end <= start);
    # the previous revision applied the end > start filter and downstream
    # alignment relies on it.
    chunks = [
        c for c in chunks
        if c.get("text", "").strip() and c.get("end", 0) > c.get("start", 0)
    ]
    status = "transcribed" if chunks else "error"

    payload = {
        "note_id": note_id,
        "type": "audio",
        "status": status,
        "raw_text": text,
        "metadata": {
            "audio": {
                "duration": info.get("duration"),
                "sample_rate": info.get("samplerate"),
                "chunks": chunks,
                "asr_model": "PhoWhisper-base",
            }
        },
        "generate": ["normalize", "keywords", "summary", "mindmap"],
    }

    await note_service.create_audio_note(payload)

    return {
        "note_id": note_id,
        "status": status,
        "duration": info.get("duration"),
    }
105
+
106
+
107
def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
    """Queue a background transcription job on the 'asr' RQ queue.

    Args:
        audio_url: URL the worker downloads the audio from.
        note_id: Note identifier the transcript will be attached to.
        user_id: Optional owner of the note.

    Returns:
        The enqueued RQ job.
    """
    queue = Queue("asr", connection=redis_client)
    return queue.enqueue(
        transcribe_job,
        audio_url,
        note_id,
        user_id,
        job_timeout=1800,
    )
117
+
118
+
119
+ # ============================================================
120
+ # POST /transcribe (UPLOAD FILE)
121
+ # ============================================================
122
@router.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file.

    Audio at or below ASYNC_THRESHOLD seconds is transcribed synchronously
    and the note is persisted before returning (200). Longer audio is
    uploaded to temporary storage and handed to a background RQ worker; a
    202 response with the job id is returned immediately.

    Raises:
        HTTPException: 413 when the upload exceeds the configured
            size/duration limits.
    """
    endpoint = "/transcribe"
    start = time.perf_counter()
    note_id = str(uuid.uuid4())

    # file.filename may be None (multipart clients are not required to send
    # one) — guard before handing it to Path.
    suffix = Path(file.filename).suffix if file.filename else ""
    tmp_in = make_temp_path(suffix=suffix or ".tmp")
    tmp_wav = None

    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1. Persist the upload (blocking I/O -> worker thread).
            await asyncio.to_thread(save_upload_file, file, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2. Normalize to 16 kHz mono WAV (ffmpeg, blocking).
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)

            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)

            # ---------- ASYNC: long audio goes to a background worker ------
            if duration > ASYNC_THRESHOLD:
                # The local temp file is deleted in `finally`, so hand the
                # worker a durable URL instead of a local path.
                audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
                job = _enqueue_async_job(audio_url, note_id)

                logging.info(
                    "/transcribe queued note_id=%s job_id=%s duration=%.1fs",
                    note_id, job.id, duration,
                )
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                    },
                )

            # ---------- SYNC: transcribe inline ----------
            result = await _run_sync_pipeline(tmp_wav, note_id)

            logging.info(
                "/transcribe success note_id=%s elapsed=%.2fs",
                note_id, time.perf_counter() - start,
            )
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result

        except Exception:
            # Count failures too, so queued/success/error metrics add up.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise

        finally:
            # Best-effort cleanup; never let a removal race mask the response.
            for p in (tmp_in, tmp_wav):
                try:
                    if p and os.path.exists(p):
                        os.remove(p)
                except OSError:
                    pass
 
 
 
 
170
 
171
+
172
+ # ============================================================
173
+ # POST /transcribe-url (FULL LOGIC, same as /transcribe)
174
+ # ============================================================
175
@router.post("/transcribe-url", response_model=TranscribeResponse)
async def transcribe_url(payload: dict):
    """
    Transcribe audio fetched from a URL.

    Payload: {"audio_url": str (required), "user_id": str (optional)}.
    Same sync/async split as /transcribe; for long audio the ORIGINAL URL is
    handed to the background job so the worker re-downloads it itself.

    Raises:
        HTTPException: 400 when audio_url is missing; 413 when the file
            exceeds the configured size/duration limits.
    """
    endpoint = "/transcribe-url"
    start = time.perf_counter()

    audio_url = payload.get("audio_url")
    user_id = payload.get("user_id")

    if not audio_url:
        raise HTTPException(400, "audio_url required")
    # NOTE(review): user_id is optional here but was required in an earlier
    # revision — confirm downstream consumers tolerate a missing user_id.

    note_id = str(uuid.uuid4())

    # Derive the temp-file suffix from the URL *path* so query strings
    # (e.g. ".mp3?token=...") don't leak into the filename.
    from urllib.parse import urlparse
    url_path = urlparse(audio_url).path
    tmp_in = make_temp_path(suffix=Path(url_path).suffix or ".tmp")
    tmp_wav = None

    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1. Download the source audio (blocking -> worker thread).
            await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2. Normalize to 16 kHz mono WAV (ffmpeg, blocking).
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)

            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)

            # ---------- ASYNC: long audio goes to a background worker ------
            if duration > ASYNC_THRESHOLD:
                # Use the original URL; the local temp file is removed below.
                job = _enqueue_async_job(audio_url, note_id, user_id)

                logging.info(
                    "/transcribe-url queued note_id=%s job_id=%s duration=%.1fs",
                    note_id, job.id, duration,
                )
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                    },
                )

            # ---------- SYNC: transcribe inline ----------
            result = await _run_sync_pipeline(tmp_wav, note_id)

            logging.info(
                "/transcribe-url success note_id=%s elapsed=%.2fs",
                note_id, time.perf_counter() - start,
            )
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result

        except Exception:
            # Count failures too, so queued/success/error metrics add up.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise

        finally:
            # Best-effort cleanup; never let a removal race mask the response.
            for p in (tmp_in, tmp_wav):
                try:
                    if p and os.path.exists(p):
                        os.remove(p)
                except OSError:
                    pass