bichnhan2701 committed on
Commit
ca4e471
·
1 Parent(s): be6b69c

Update transcribe

Browse files
Files changed (2) hide show
  1. app/api/transcribe.py +73 -44
  2. app/jobs/transcribe_job.py +7 -2
app/api/transcribe.py CHANGED
@@ -7,7 +7,6 @@ from pathlib import Path
7
 
8
  from fastapi import APIRouter, UploadFile, File, HTTPException
9
  from fastapi.responses import JSONResponse
10
-
11
  from rq import Queue
12
 
13
  from app.config import settings
@@ -25,6 +24,7 @@ from app.core.audio_utils import (
25
  get_audio_info,
26
  upload_temp_audio,
27
  )
 
28
  from app.core.asr_engine import (
29
  load_model,
30
  transcribe_file,
@@ -35,6 +35,7 @@ router = APIRouter()
35
  ASR_MODEL = None
36
  ASYNC_THRESHOLD = 120 # seconds
37
 
 
38
 
39
  # ============================================================
40
  # Startup: load ASR model once
@@ -57,45 +58,60 @@ def _ensure_file_limits(path: str):
57
  raise HTTPException(413, "Audio duration exceeds limit")
58
 
59
 
60
- async def _run_sync_pipeline(
61
- tmp_wav: str,
62
- note_id: str,
63
- ):
 
 
 
 
 
 
 
 
64
  """
65
- Run sync ASR + persist to Note Service
66
  """
67
  note_service = NoteServiceClient()
68
  info = get_audio_info(tmp_wav) or {}
69
 
70
- model = ASR_MODEL
71
  with ASR_DURATION.labels("/transcribe").time():
72
  text = await asyncio.to_thread(
73
- transcribe_file, model, tmp_wav, 30.0, 5.0
74
  )
75
  chunks = await asyncio.to_thread(
76
- transcribe_file_chunks, model, tmp_wav, 30.0, 5.0
77
  )
78
 
79
- chunks = [c for c in chunks if c.get("text", "").strip()]
 
 
 
 
 
 
 
 
 
80
  status = "transcribed" if chunks else "error"
81
 
82
- payload = {
83
- "note_id": note_id,
84
- "type": "audio",
85
- "status": status,
86
- "raw_text": text,
87
- "metadata": {
88
- "audio": {
89
- "duration": info.get("duration"),
90
- "sample_rate": info.get("samplerate"),
91
- "chunks": chunks,
92
- "asr_model": "PhoWhisper-base",
93
- }
 
 
94
  },
95
- "generate": ["normalize", "keywords", "summary", "mindmap"],
96
- }
97
-
98
- await note_service.create_audio_note(payload)
99
 
100
  return {
101
  "note_id": note_id,
@@ -104,16 +120,28 @@ async def _run_sync_pipeline(
104
  }
105
 
106
 
107
- def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
108
- q = Queue("asr", connection=redis_client)
109
- job = q.enqueue(
110
- transcribe_job,
111
- audio_url,
112
- note_id,
113
- user_id,
114
- job_timeout=1800,
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  )
116
- return job
117
 
118
 
119
  # ============================================================
@@ -122,7 +150,6 @@ def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None)
122
  @router.post("/transcribe", response_model=TranscribeResponse)
123
  async def transcribe(file: UploadFile = File(...)):
124
  endpoint = "/transcribe"
125
- start = time.perf_counter()
126
  note_id = str(uuid.uuid4())
127
 
128
  tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".tmp")
@@ -130,11 +157,11 @@ async def transcribe(file: UploadFile = File(...)):
130
 
131
  with REQUEST_LATENCY.labels(endpoint).time():
132
  try:
133
- # 1. Save upload
134
  await asyncio.to_thread(save_upload_file, file, tmp_in)
135
  _ensure_file_limits(tmp_in)
136
 
137
- # 2. Convert
138
  tmp_wav = make_temp_path(suffix=".wav")
139
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
140
 
@@ -144,6 +171,8 @@ async def transcribe(file: UploadFile = File(...)):
144
  # ---------- ASYNC ----------
145
  if duration > ASYNC_THRESHOLD:
146
  audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
 
 
147
  job = _enqueue_async_job(audio_url, note_id)
148
 
149
  REQUEST_COUNT.labels(endpoint, "queued").inc()
@@ -158,6 +187,7 @@ async def transcribe(file: UploadFile = File(...)):
158
  )
159
 
160
  # ---------- SYNC ----------
 
161
  result = await _run_sync_pipeline(tmp_wav, note_id)
162
 
163
  REQUEST_COUNT.labels(endpoint, "success").inc()
@@ -170,12 +200,11 @@ async def transcribe(file: UploadFile = File(...)):
170
 
171
 
172
  # ============================================================
173
- # POST /transcribe-url (FULL LOGIC, same as /transcribe)
174
  # ============================================================
175
  @router.post("/transcribe-url", response_model=TranscribeResponse)
176
  async def transcribe_url(payload: dict):
177
  endpoint = "/transcribe-url"
178
- start = time.perf_counter()
179
 
180
  audio_url = payload.get("audio_url")
181
  user_id = payload.get("user_id")
@@ -184,17 +213,16 @@ async def transcribe_url(payload: dict):
184
  raise HTTPException(400, "audio_url required")
185
 
186
  note_id = str(uuid.uuid4())
187
-
188
  tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
189
  tmp_wav = None
190
 
191
  with REQUEST_LATENCY.labels(endpoint).time():
192
  try:
193
- # 1. Download audio
194
  await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
195
  _ensure_file_limits(tmp_in)
196
 
197
- # 2. Convert
198
  tmp_wav = make_temp_path(suffix=".wav")
199
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
200
 
@@ -203,7 +231,7 @@ async def transcribe_url(payload: dict):
203
 
204
  # ---------- ASYNC ----------
205
  if duration > ASYNC_THRESHOLD:
206
- # use ORIGINAL url for async job
207
  job = _enqueue_async_job(audio_url, note_id, user_id)
208
 
209
  REQUEST_COUNT.labels(endpoint, "queued").inc()
@@ -218,6 +246,7 @@ async def transcribe_url(payload: dict):
218
  )
219
 
220
  # ---------- SYNC ----------
 
221
  result = await _run_sync_pipeline(tmp_wav, note_id)
222
 
223
  REQUEST_COUNT.labels(endpoint, "success").inc()
 
7
 
8
  from fastapi import APIRouter, UploadFile, File, HTTPException
9
  from fastapi.responses import JSONResponse
 
10
  from rq import Queue
11
 
12
  from app.config import settings
 
24
  get_audio_info,
25
  upload_temp_audio,
26
  )
27
+
28
  from app.core.asr_engine import (
29
  load_model,
30
  transcribe_file,
 
35
  ASR_MODEL = None
36
  ASYNC_THRESHOLD = 120 # seconds
37
 
38
+ logger = logging.getLogger(__name__)
39
 
40
  # ============================================================
41
  # Startup: load ASR model once
 
58
  raise HTTPException(413, "Audio duration exceeds limit")
59
 
60
 
61
def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
    """Enqueue a background transcription job on the "asr" RQ queue.

    Args:
        audio_url: URL of the audio the worker should transcribe.
        note_id: Identifier of the note the transcription belongs to.
        user_id: Optional note owner, forwarded unchanged to the job.

    Returns:
        The enqueued ``rq`` job instance.
    """
    queue = Queue("asr", connection=redis_client)
    # 1800 s (30 min) timeout: long recordings can take a while to transcribe.
    return queue.enqueue(
        transcribe_job,
        audio_url,
        note_id,
        user_id,
        job_timeout=1800,
    )
70
+
71
+
72
+ async def _run_sync_pipeline(tmp_wav: str, note_id: str):
73
  """
74
+ Sync ASR update existing note
75
  """
76
  note_service = NoteServiceClient()
77
  info = get_audio_info(tmp_wav) or {}
78
 
 
79
  with ASR_DURATION.labels("/transcribe").time():
80
  text = await asyncio.to_thread(
81
+ transcribe_file, ASR_MODEL, tmp_wav, 30.0, 5.0
82
  )
83
  chunks = await asyncio.to_thread(
84
+ transcribe_file_chunks, ASR_MODEL, tmp_wav, 30.0, 5.0
85
  )
86
 
87
+ chunks = [
88
+ {
89
+ "text": c["text"],
90
+ "start": c.get("start"),
91
+ "end": c.get("end"),
92
+ }
93
+ for c in chunks
94
+ if c.get("text", "").strip()
95
+ ]
96
+
97
  status = "transcribed" if chunks else "error"
98
 
99
+ # 🔥 UPDATE the existing note — do NOT create a new one
100
+ await note_service.update_note(
101
+ note_id,
102
+ {
103
+ "status": status,
104
+ "raw_text": text,
105
+ "metadata": {
106
+ "audio": {
107
+ "duration": info.get("duration"),
108
+ "sample_rate": info.get("samplerate"),
109
+ "chunks": chunks,
110
+ "asr_model": "PhoWhisper-base",
111
+ }
112
+ },
113
  },
114
+ )
 
 
 
115
 
116
  return {
117
  "note_id": note_id,
 
120
  }
121
 
122
 
123
async def _create_placeholder_note(note_id: str, duration: float):
    """Create the note record IMMEDIATELY (before transcription) so that:

    - SSE clients never get a not_found for this note
    - the enrichment pipeline has an object to update
    """
    placeholder = {
        "note_id": note_id,
        "type": "audio",
        "status": "processing",
        "raw_text": "",
        "metadata": {
            "audio": {
                "duration": duration,
                "chunks": [],
                "asr_model": "PhoWhisper-base",
            }
        },
        "generate": ["normalize", "keywords", "summary", "mindmap"],
    }
    await NoteServiceClient().create_audio_note(placeholder)
 
145
 
146
 
147
  # ============================================================
 
150
  @router.post("/transcribe", response_model=TranscribeResponse)
151
  async def transcribe(file: UploadFile = File(...)):
152
  endpoint = "/transcribe"
 
153
  note_id = str(uuid.uuid4())
154
 
155
  tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".tmp")
 
157
 
158
  with REQUEST_LATENCY.labels(endpoint).time():
159
  try:
160
+ # 1️⃣ Save upload
161
  await asyncio.to_thread(save_upload_file, file, tmp_in)
162
  _ensure_file_limits(tmp_in)
163
 
164
+ # 2️⃣ Convert
165
  tmp_wav = make_temp_path(suffix=".wav")
166
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
167
 
 
171
  # ---------- ASYNC ----------
172
  if duration > ASYNC_THRESHOLD:
173
  audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
174
+
175
+ await _create_placeholder_note(note_id, duration)
176
  job = _enqueue_async_job(audio_url, note_id)
177
 
178
  REQUEST_COUNT.labels(endpoint, "queued").inc()
 
187
  )
188
 
189
  # ---------- SYNC ----------
190
+ await _create_placeholder_note(note_id, duration)
191
  result = await _run_sync_pipeline(tmp_wav, note_id)
192
 
193
  REQUEST_COUNT.labels(endpoint, "success").inc()
 
200
 
201
 
202
  # ============================================================
203
+ # POST /transcribe-url (FULL LOGIC)
204
  # ============================================================
205
  @router.post("/transcribe-url", response_model=TranscribeResponse)
206
  async def transcribe_url(payload: dict):
207
  endpoint = "/transcribe-url"
 
208
 
209
  audio_url = payload.get("audio_url")
210
  user_id = payload.get("user_id")
 
213
  raise HTTPException(400, "audio_url required")
214
 
215
  note_id = str(uuid.uuid4())
 
216
  tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
217
  tmp_wav = None
218
 
219
  with REQUEST_LATENCY.labels(endpoint).time():
220
  try:
221
+ # 1️⃣ Download
222
  await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
223
  _ensure_file_limits(tmp_in)
224
 
225
+ # 2️⃣ Convert
226
  tmp_wav = make_temp_path(suffix=".wav")
227
  await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
228
 
 
231
 
232
  # ---------- ASYNC ----------
233
  if duration > ASYNC_THRESHOLD:
234
+ await _create_placeholder_note(note_id, duration)
235
  job = _enqueue_async_job(audio_url, note_id, user_id)
236
 
237
  REQUEST_COUNT.labels(endpoint, "queued").inc()
 
246
  )
247
 
248
  # ---------- SYNC ----------
249
+ await _create_placeholder_note(note_id, duration)
250
  result = await _run_sync_pipeline(tmp_wav, note_id)
251
 
252
  REQUEST_COUNT.labels(endpoint, "success").inc()
app/jobs/transcribe_job.py CHANGED
@@ -1,6 +1,7 @@
1
  import asyncio
2
  import tempfile
3
  import os
 
4
  import requests
5
 
6
  from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
@@ -62,8 +63,12 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
62
  }
63
 
64
  client = NoteServiceClient()
65
- asyncio.run(client.create_audio_note(payload))
66
-
 
 
 
 
67
  finally:
68
  # 3️⃣ Cleanup
69
  if wav_path and os.path.exists(wav_path):
 
1
import asyncio
import os
import tempfile

# NOTE(review): `from xmlrpc import client` looks like an accidental IDE
# auto-import — the visible code only ever rebinds `client` locally
# (client = NoteServiceClient()). Confirm nothing else in this module uses
# xmlrpc and remove it.
from xmlrpc import client

import requests

from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
 
63
  }
64
 
65
  client = NoteServiceClient()
66
+ asyncio.run(client.update_note(note_id, {
67
+ "status": note_status,
68
+ "raw_text": text,
69
+ "metadata": payload["metadata"],
70
+ }))
71
+
72
  finally:
73
  # 3️⃣ Cleanup
74
  if wav_path and os.path.exists(wav_path):