Spaces:
Running
Running
| import os | |
| import uuid | |
| import time | |
| import asyncio | |
| import logging | |
| from pathlib import Path | |
| from fastapi import APIRouter, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from rq import Queue, Retry | |
| from app.config import settings | |
| from app.infra.redis_client import redis_client | |
| from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION | |
| from app.schemas.transcribe import TranscribeResponse | |
| from app.services.note_client import NoteServiceClient | |
| from app.jobs.transcribe_job import transcribe_job | |
| from app.core.audio_utils import ( | |
| save_upload_file, | |
| download_file_from_url, | |
| ensure_wav_16k_mono, | |
| make_temp_path, | |
| get_audio_info, | |
| upload_temp_audio, | |
| ) | |
| from app.core.asr_engine import ( | |
| load_model, | |
| transcribe_file, | |
| transcribe_file_chunks, | |
| ) | |
router = APIRouter()

# Shared PhoWhisper model instance; populated once by startup().
ASR_MODEL = None

# Recordings longer than this many seconds are transcribed asynchronously
# via an RQ worker instead of inline in the request handler.
ASYNC_THRESHOLD = 120  # seconds

logger = logging.getLogger(__name__)
# ============================================================
# Startup: load ASR model once
# ============================================================
async def startup():
    """Load the ASR model once at application startup.

    Runs the blocking ``load_model`` call in a worker thread via
    ``asyncio.to_thread`` so the event loop is not stalled while model
    weights are loaded. The result is stored in the module-global
    ``ASR_MODEL`` used by the sync transcription pipeline.
    """
    global ASR_MODEL
    # 30 is forwarded verbatim to load_model — presumably a chunk length
    # in seconds, matching the 30.0 passed to transcribe_file below;
    # TODO confirm against load_model's signature.
    ASR_MODEL = await asyncio.to_thread(load_model, 30)
# ============================================================
# Utils
# ============================================================
def _ensure_file_limits(path: str) -> None:
    """Validate an audio file against configured size/duration caps.

    Raises:
        HTTPException: 413 when the file is larger than
            ``settings.MAX_UPLOAD_BYTES`` or longer than
            ``settings.MAX_DURATION_SECS``.
    """
    size_bytes = os.path.getsize(path)
    if size_bytes > settings.MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File size exceeds limit")

    meta = get_audio_info(path)
    duration = meta.get("duration", 0) if meta else 0
    if duration > settings.MAX_DURATION_SECS:
        raise HTTPException(413, "Audio duration exceeds limit")
def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
    """Enqueue a background transcription job on the "asr" RQ queue.

    The job retries up to 3 times with increasing backoff and may run
    for at most 30 minutes.

    Returns:
        The enqueued RQ ``Job``.
    """
    asr_queue = Queue("asr", connection=redis_client)
    job = asr_queue.enqueue(
        transcribe_job,
        audio_url,
        note_id,
        user_id,
        job_timeout=1800,  # 30 minutes
        retry=Retry(max=3, interval=[2, 5, 10]),
    )
    return job
async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None = None):
    """Run synchronous ASR on *tmp_wav* and update the existing note.

    Transcribes the audio (full text plus timestamped chunks), then
    pushes the result to the note service via its internal update
    endpoint — it does NOT create a new note.

    Returns:
        A dict with ``note_id``, ``status``, ``duration`` and
        ``audio_url`` suitable as the endpoint's response body.
    """
    note_service = NoteServiceClient()
    info = get_audio_info(tmp_wav) or {}

    # Time both ASR passes (full-text and chunked) under the endpoint label.
    with ASR_DURATION.labels("/transcribe").time():
        text = await asyncio.to_thread(
            transcribe_file, ASR_MODEL, tmp_wav, 30.0, 5.0
        )
        raw_chunks = await asyncio.to_thread(
            transcribe_file_chunks, ASR_MODEL, tmp_wav, 30.0, 5.0
        )

    # Keep only chunks whose text is non-blank.
    chunks = []
    for chunk in raw_chunks:
        if chunk.get("text", "").strip():
            chunks.append(
                {
                    "text": chunk["text"],
                    "start": chunk.get("start"),
                    "end": chunk.get("end"),
                }
            )

    # If the model returned text but the pipeline produced no timestamped
    # chunks, still treat the audio as transcribed and emit one fallback
    # chunk covering the full duration.
    text = text or ""
    if text.strip() and not chunks:
        chunks = [
            {"text": text.strip(), "start": 0.0, "end": info.get("duration") or 0.0}
        ]

    status = "transcribed" if text.strip() else "error"

    # UPDATE the placeholder note — never create here. The internal update
    # endpoint mirrors create_audio_note (internal API).
    payload = {
        "status": status,
        "raw_text": text,
        "metadata": {
            "audio": {
                "duration": info.get("duration"),
                "sample_rate": info.get("samplerate"),
                "chunks": chunks,
                "asr_model": "PhoWhisper-base",
            }
        },
        "generate": ["normalize", "keywords", "summary", "mindmap"],
    }
    if audio_url:
        payload["metadata"]["audio"]["url"] = audio_url

    await note_service.update_note_internal(note_id, payload)

    return {
        "note_id": note_id,
        "status": status,
        "duration": info.get("duration"),
        "audio_url": audio_url,
    }
async def _create_placeholder_note(note_id: str, duration: float, audio_url: str | None = None):
    """Create an empty "processing" audio note to be filled in later.

    Both the sync pipeline and the queued worker subsequently update this
    note via the internal update endpoint.

    Raises:
        RuntimeError: if the note service returns ``None`` (fail fast —
            there would be nothing to attach the transcript to).
    """
    client = NoteServiceClient()

    audio_meta = {
        "duration": duration,
        "chunks": [],
        "asr_model": "PhoWhisper-base",
    }
    if audio_url:
        audio_meta["url"] = audio_url

    payload = {
        "note_id": note_id,
        "type": "audio",
        "status": "processing",
        "raw_text": "",
        "metadata": {"audio": audio_meta},
        # Deliberately no "generate" key here — post-processing is
        # requested only after transcription completes.
    }

    res = await client.create_audio_note(payload)
    if res is None:
        raise RuntimeError("Failed to create placeholder note")
# ============================================================
# POST /transcribe (UPLOAD FILE)
# ============================================================
@router.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Short recordings (<= ASYNC_THRESHOLD seconds) are transcribed inline
    and the finished note payload is returned. Longer recordings get a
    placeholder note and an RQ job, and a 202 "queued" response.

    Raises:
        HTTPException: 413 when the upload exceeds size/duration limits.
    """
    endpoint = "/transcribe"
    note_id = str(uuid.uuid4())
    # UploadFile.filename may be None — guard before Path() to avoid a
    # TypeError; fall back to a generic suffix.
    suffix = Path(file.filename or "").suffix or ".tmp"
    tmp_in = make_temp_path(suffix=suffix)
    tmp_wav = None
    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1. Persist the upload and enforce size/duration caps.
            await asyncio.to_thread(save_upload_file, file, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2. Normalize to 16 kHz mono WAV for the ASR model.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)

            # ---------- ASYNC: long audio goes to a worker ----------
            if duration > ASYNC_THRESHOLD:
                audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id)
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )

            # ---------- SYNC ----------
            # Upload small audio too, so we can return a stable URL and
            # persist it on the note.
            audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
            await _create_placeholder_note(note_id, duration, audio_url)
            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        except Exception:
            # Count failures (limits, conversion, ASR, note service) so
            # errors show up in metrics; re-raise for FastAPI to handle.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise
        finally:
            # Always remove temp files, success or failure.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)
# ============================================================
# POST /transcribe-url (FULL LOGIC)
# ============================================================
@router.post("/transcribe-url")
async def transcribe_url(payload: dict):
    """Transcribe audio fetched from a URL.

    Expects ``payload["audio_url"]`` (required) and optionally
    ``payload["user_id"]``. Mirrors the upload endpoint: short audio is
    transcribed inline, long audio is queued (202 response).

    Raises:
        HTTPException: 400 when ``audio_url`` is missing, 413 when the
            download exceeds size/duration limits.
    """
    from urllib.parse import urlsplit

    endpoint = "/transcribe-url"
    audio_url = payload.get("audio_url")
    user_id = payload.get("user_id")
    if not audio_url:
        raise HTTPException(400, "audio_url required")

    note_id = str(uuid.uuid4())
    # Derive the suffix from the URL *path* only — Path(audio_url).suffix
    # would leak query strings (e.g. "?token=abc") into the temp filename.
    suffix = Path(urlsplit(audio_url).path).suffix or ".tmp"
    tmp_in = make_temp_path(suffix=suffix)
    tmp_wav = None
    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1. Download the remote audio and enforce size/duration caps.
            await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2. Normalize to 16 kHz mono WAV for the ASR model.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)

            # ---------- ASYNC: long audio goes to a worker ----------
            if duration > ASYNC_THRESHOLD:
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id, user_id)
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )

            # ---------- SYNC ----------
            await _create_placeholder_note(note_id, duration, audio_url)
            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        except Exception:
            # Count failures so errors are visible in metrics; re-raise
            # for FastAPI's normal error handling.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise
        finally:
            # Always remove temp files, success or failure.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)