import os
import uuid
import time
import asyncio
import logging
from pathlib import Path

from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from rq import Queue, Retry

from app.config import settings
from app.infra.redis_client import redis_client
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
from app.schemas.transcribe import TranscribeResponse
from app.services.note_client import NoteServiceClient
from app.jobs.transcribe_job import transcribe_job
from app.core.audio_utils import (
    save_upload_file,
    download_file_from_url,
    ensure_wav_16k_mono,
    make_temp_path,
    get_audio_info,
    upload_temp_audio,
)
from app.core.asr_engine import (
    load_model,
    transcribe_file,
    transcribe_file_chunks,
)

router = APIRouter()

# Global ASR model handle; populated once by the startup hook below.
ASR_MODEL = None
# Audio longer than this many seconds is transcribed asynchronously via RQ.
ASYNC_THRESHOLD = 120  # seconds

logger = logging.getLogger(__name__)


# ============================================================
# Startup: load ASR model once
# ============================================================
@router.on_event("startup")
async def startup():
    """Load the ASR model once, off the event loop thread."""
    global ASR_MODEL
    # NOTE(review): `30` is presumably a chunk length / beam setting passed to
    # load_model — confirm against app.core.asr_engine.
    ASR_MODEL = await asyncio.to_thread(load_model, 30)


# ============================================================
# Utils
# ============================================================
def _ensure_file_limits(path: str):
    """Reject files over the configured byte-size or duration limits.

    Raises:
        HTTPException: 413 when either limit is exceeded.
    """
    if os.path.getsize(path) > settings.MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File size exceeds limit")

    info = get_audio_info(path)
    if info and info.get("duration", 0) > settings.MAX_DURATION_SECS:
        raise HTTPException(413, "Audio duration exceeds limit")


def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
    """Enqueue a background transcription job on the 'asr' RQ queue.

    Returns the RQ job object (its `.id` is surfaced to the client).
    """
    q = Queue("asr", connection=redis_client)
    return q.enqueue(
        transcribe_job,
        audio_url,
        note_id,
        user_id,
        job_timeout=1800,
        # Retry up to 3 times with increasing back-off (seconds).
        retry=Retry(max=3, interval=[2, 5, 10]),
    )


async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None = None):
    """
    Sync ASR → update existing note.

    Runs full-text and chunked transcription on `tmp_wav`, then UPDATES the
    pre-created note `note_id` via the internal note-service endpoint.

    Returns a small summary dict (note_id / status / duration / audio_url).
    """
    note_service = NoteServiceClient()
    info = get_audio_info(tmp_wav) or {}

    with ASR_DURATION.labels("/transcribe").time():
        text = await asyncio.to_thread(
            transcribe_file, ASR_MODEL, tmp_wav, 30.0, 5.0
        )
        chunks = await asyncio.to_thread(
            transcribe_file_chunks, ASR_MODEL, tmp_wav, 30.0, 5.0
        )

    # Keep only non-empty chunks, normalized to {text, start, end}.
    chunks = [
        {
            "text": c["text"],
            "start": c.get("start"),
            "end": c.get("end"),
        }
        for c in chunks
        if c.get("text", "").strip()
    ]

    # If model returned text but pipeline didn't provide timestamped chunks,
    # treat as transcribed and create a single fallback chunk covering full duration.
    text = text or ""
    if not chunks and text.strip():
        chunks = [{"text": text.strip(), "start": 0.0, "end": info.get("duration") or 0.0}]

    status = "transcribed" if text.strip() else "error"

    # UPDATE — do NOT create here: the placeholder note already exists.
    # Use internal update endpoint to match create_audio_note (internal API).
    payload = {
        "status": status,
        "raw_text": text,
        "metadata": {
            "audio": {
                "duration": info.get("duration"),
                "sample_rate": info.get("samplerate"),
                "chunks": chunks,
                "asr_model": "PhoWhisper-base",
            }
        },
        # Downstream enrichment steps triggered by the note service.
        "generate": ["normalize", "keywords", "summary", "mindmap"],
    }
    if audio_url:
        payload["metadata"]["audio"]["url"] = audio_url

    await note_service.update_note_internal(note_id, payload)

    return {
        "note_id": note_id,
        "status": status,
        "duration": info.get("duration"),
        "audio_url": audio_url,
    }


async def _create_placeholder_note(note_id: str, duration: float, audio_url: str | None = None):
    """Create a 'processing' audio note so clients can poll while ASR runs.

    Raises:
        RuntimeError: if the note service did not create the note — failing
        fast here avoids a later update against a nonexistent note.
    """
    client = NoteServiceClient()
    payload = {
        "note_id": note_id,
        "type": "audio",
        "status": "processing",
        "raw_text": "",
        "metadata": {
            "audio": {
                "duration": duration,
                "chunks": [],
                "asr_model": "PhoWhisper-base",
            }
        },
        # Deliberately NO "generate" here — enrichment runs after transcription.
    }
    if audio_url:
        payload["metadata"]["audio"]["url"] = audio_url

    res = await client.create_audio_note(payload)
    if res is None:
        # Fail fast: the sync/async pipelines both assume this note exists.
        raise RuntimeError("Failed to create placeholder note")


# ============================================================
# POST /transcribe (UPLOAD FILE)
# ============================================================
@router.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Audio up to ASYNC_THRESHOLD seconds is transcribed synchronously and the
    resulting note summary returned; longer audio is uploaded, queued on the
    'asr' RQ worker, and a 202 with the job id is returned instead.
    """
    endpoint = "/transcribe"
    note_id = str(uuid.uuid4())
    # file.filename may be None — multipart clients are not required to send it.
    tmp_in = make_temp_path(suffix=Path(file.filename or "").suffix or ".tmp")
    tmp_wav = None

    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1) Persist the upload, then enforce size/duration limits.
            await asyncio.to_thread(save_upload_file, file, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2) Normalize to 16 kHz mono WAV for the ASR model.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)

            info = get_audio_info(tmp_wav) or {}
            # `or 0` also guards an explicit {"duration": None} entry, which
            # would otherwise make the > comparison raise TypeError.
            duration = info.get("duration", 0) or 0

            # ---------- ASYNC ----------
            if duration > ASYNC_THRESHOLD:
                audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id)

                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )

            # ---------- SYNC ----------
            # Upload small audio so we can return a stable URL and persist it.
            audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
            await _create_placeholder_note(note_id, duration, audio_url)

            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        finally:
            # Always remove temp artifacts, on success and error paths alike.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)


# ============================================================
# POST /transcribe-url (FULL LOGIC)
# ============================================================
@router.post("/transcribe-url", response_model=TranscribeResponse)
async def transcribe_url(payload: dict):
    """Download the audio at `payload["audio_url"]` and transcribe it.

    Same sync/async split as POST /transcribe, but the remote URL is reused
    as the note's audio URL instead of re-uploading, and an optional
    `payload["user_id"]` is forwarded to the async job.
    """
    endpoint = "/transcribe-url"
    audio_url = payload.get("audio_url")
    user_id = payload.get("user_id")
    if not audio_url:
        raise HTTPException(400, "audio_url required")

    note_id = str(uuid.uuid4())
    # Strip any query string before deriving the suffix — presigned URLs look
    # like "...audio.wav?X-Amz-Signature=...", and Path(...).suffix on the raw
    # URL would swallow the query into the extension.
    tmp_in = make_temp_path(suffix=Path(audio_url.split("?", 1)[0]).suffix or ".tmp")
    tmp_wav = None

    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1) Download the remote file, then enforce size/duration limits.
            await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
            _ensure_file_limits(tmp_in)

            # 2) Normalize to 16 kHz mono WAV for the ASR model.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)

            info = get_audio_info(tmp_wav) or {}
            # Guard against {"duration": None} as well as a missing key.
            duration = info.get("duration", 0) or 0

            # ---------- ASYNC ----------
            if duration > ASYNC_THRESHOLD:
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id, user_id)

                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )

            # ---------- SYNC ----------
            await _create_placeholder_note(note_id, duration, audio_url)
            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        finally:
            # Always remove temp artifacts, on success and error paths alike.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)