# PhoWhisperBaseAPI — app/api/transcribe.py
# (history: bichnhan2701, "add url audio", commit 2916663)
import asyncio
import logging
import os
import time
import uuid
from pathlib import Path
from urllib.parse import urlparse

from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from rq import Queue, Retry

from app.config import settings
from app.core.asr_engine import (
    load_model,
    transcribe_file,
    transcribe_file_chunks,
)
from app.core.audio_utils import (
    save_upload_file,
    download_file_from_url,
    ensure_wav_16k_mono,
    make_temp_path,
    get_audio_info,
    upload_temp_audio,
)
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
from app.infra.redis_client import redis_client
from app.jobs.transcribe_job import transcribe_job
from app.schemas.transcribe import TranscribeResponse
from app.services.note_client import NoteServiceClient
# Router holding the transcription endpoints; included by the FastAPI app.
router = APIRouter()
# ASR model handle; populated once by the startup hook, None until then.
ASR_MODEL = None
# Audio longer than this many seconds is queued for background processing
# instead of being transcribed inline in the request.
ASYNC_THRESHOLD = 120  # seconds
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ============================================================
# Startup: load ASR model once
# ============================================================
@router.on_event("startup")
async def startup():
    """Load the ASR model exactly once when the application starts.

    The model load is blocking, so it is pushed to a worker thread to keep
    the event loop responsive during startup.
    """
    global ASR_MODEL
    model = await asyncio.to_thread(load_model, 30)
    ASR_MODEL = model
# ============================================================
# Utils
# ============================================================
def _ensure_file_limits(path: str) -> None:
    """Validate size and duration limits for a local audio file.

    Args:
        path: Filesystem path of the audio file to check.

    Raises:
        HTTPException: 413 when the file exceeds ``settings.MAX_UPLOAD_BYTES``
            or its duration exceeds ``settings.MAX_DURATION_SECS``.
    """
    if os.path.getsize(path) > settings.MAX_UPLOAD_BYTES:
        raise HTTPException(413, "File size exceeds limit")
    info = get_audio_info(path)
    # The probe may report "duration" as None (key present, value missing);
    # coalesce to 0 so the comparison cannot raise a TypeError.
    if info and (info.get("duration") or 0) > settings.MAX_DURATION_SECS:
        raise HTTPException(413, "Audio duration exceeds limit")
def _enqueue_async_job(audio_url: str, note_id: str, user_id: str | None = None):
    """Enqueue a background transcription job on the ``asr`` RQ queue.

    The job is retried up to 3 times with increasing back-off and is killed
    after 30 minutes. Returns the enqueued RQ job object.
    """
    queue = Queue("asr", connection=redis_client)
    job = queue.enqueue(
        transcribe_job,
        audio_url,
        note_id,
        user_id,
        job_timeout=1800,
        retry=Retry(max=3, interval=[2, 5, 10]),
    )
    return job
async def _run_sync_pipeline(
    tmp_wav: str,
    note_id: str,
    audio_url: str | None = None,
    endpoint: str = "/transcribe",
):
    """Run synchronous ASR on ``tmp_wav`` and UPDATE the existing note.

    Args:
        tmp_wav: Path to a 16 kHz mono WAV file to transcribe.
        note_id: ID of the already-created placeholder note to update.
        audio_url: Optional stable URL of the uploaded audio; persisted in
            the note metadata when provided.
        endpoint: Label for the ASR duration metric. Defaults to
            "/transcribe" so existing callers keep their previous label;
            the URL endpoint can pass "/transcribe-url" for accurate metrics.

    Returns:
        dict with ``note_id``, ``status``, ``duration`` and ``audio_url``.
    """
    note_service = NoteServiceClient()
    info = get_audio_info(tmp_wav) or {}
    with ASR_DURATION.labels(endpoint).time():
        text = await asyncio.to_thread(
            transcribe_file, ASR_MODEL, tmp_wav, 30.0, 5.0
        )
        chunks = await asyncio.to_thread(
            transcribe_file_chunks, ASR_MODEL, tmp_wav, 30.0, 5.0
        )
    # Keep only chunks with non-empty text, preserving their timestamps.
    chunks = [
        {
            "text": c["text"],
            "start": c.get("start"),
            "end": c.get("end"),
        }
        for c in chunks
        if c.get("text", "").strip()
    ]
    # If the model returned text but the pipeline produced no timestamped
    # chunks, still treat the audio as transcribed and emit one fallback
    # chunk covering the full duration.
    text = text or ""
    if not chunks and text.strip():
        chunks = [{"text": text.strip(), "start": 0.0, "end": info.get("duration") or 0.0}]
    status = "transcribed" if text.strip() else "error"
    # UPDATE the existing note — never CREATE here.
    # Uses the internal update endpoint to match create_audio_note (internal API).
    payload = {
        "status": status,
        "raw_text": text,
        "metadata": {
            "audio": {
                "duration": info.get("duration"),
                "sample_rate": info.get("samplerate"),
                "chunks": chunks,
                "asr_model": "PhoWhisper-base",
            }
        },
        "generate": ["normalize", "keywords", "summary", "mindmap"],
    }
    if audio_url:
        payload["metadata"]["audio"]["url"] = audio_url
    await note_service.update_note_internal(note_id, payload)
    return {
        "note_id": note_id,
        "status": status,
        "duration": info.get("duration"),
        "audio_url": audio_url,
    }
async def _create_placeholder_note(note_id: str, duration: float, audio_url: str | None = None):
    """Create an empty "processing" audio note before ASR runs.

    Raises:
        RuntimeError: when the note service fails to create the note.
    """
    client = NoteServiceClient()
    audio_meta = {
        "duration": duration,
        "chunks": [],
        "asr_model": "PhoWhisper-base",
    }
    if audio_url:
        audio_meta["url"] = audio_url
    payload = {
        "note_id": note_id,
        "type": "audio",
        "status": "processing",
        "raw_text": "",
        "metadata": {"audio": audio_meta},
        # Deliberately no "generate" key here — enrichment runs after ASR.
    }
    res = await client.create_audio_note(payload)
    if res is None:
        # Fail fast: later updates would target a note that does not exist.
        raise RuntimeError("Failed to create placeholder note")
# ============================================================
# POST /transcribe (UPLOAD FILE)
# ============================================================
@router.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Audio up to ``ASYNC_THRESHOLD`` seconds is transcribed synchronously and
    the result returned directly; longer audio is uploaded, queued on RQ,
    and a 202 response containing the job id is returned.
    """
    endpoint = "/transcribe"
    note_id = str(uuid.uuid4())
    # ``filename`` can be None on multipart uploads; fall back to ".tmp"
    # instead of letting Path(None) raise a TypeError.
    tmp_in = make_temp_path(suffix=Path(file.filename or "").suffix or ".tmp")
    tmp_wav = None
    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1) Persist the upload to a temp file and validate limits.
            await asyncio.to_thread(save_upload_file, file, tmp_in)
            _ensure_file_limits(tmp_in)
            # 2) Normalize to 16 kHz mono WAV for the ASR engine.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)
            # ---------- ASYNC: long audio goes to the background queue ----------
            if duration > ASYNC_THRESHOLD:
                audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id)
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )
            # ---------- SYNC: short audio is transcribed inline ----------
            # Upload small audio too so we can return a stable URL and persist it.
            audio_url = await asyncio.to_thread(upload_temp_audio, tmp_wav)
            await _create_placeholder_note(note_id, duration, audio_url)
            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        except Exception:
            # Make failures observable in metrics (previously only "queued"
            # and "success" were counted), then let FastAPI render the error.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise
        finally:
            # Always remove temp files, on success and failure alike.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)
# ============================================================
# POST /transcribe-url (FULL LOGIC)
# ============================================================
@router.post("/transcribe-url", response_model=TranscribeResponse)
async def transcribe_url(payload: dict):
    """Transcribe audio downloaded from a URL.

    Same sync/async split as ``/transcribe``: audio up to ``ASYNC_THRESHOLD``
    seconds is handled inline, longer audio is queued on RQ with a 202 reply.

    Expects a JSON body with a required ``audio_url`` and optional ``user_id``.
    """
    endpoint = "/transcribe-url"
    audio_url = payload.get("audio_url")
    user_id = payload.get("user_id")
    if not audio_url:
        raise HTTPException(400, "audio_url required")
    note_id = str(uuid.uuid4())
    # Derive the suffix from the URL *path* only — otherwise a query string
    # (e.g. a signed URL) leaks into the temp-file extension.
    tmp_in = make_temp_path(suffix=Path(urlparse(audio_url).path).suffix or ".tmp")
    tmp_wav = None
    with REQUEST_LATENCY.labels(endpoint).time():
        try:
            # 1) Download the remote audio and validate limits.
            await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
            _ensure_file_limits(tmp_in)
            # 2) Normalize to 16 kHz mono WAV for the ASR engine.
            tmp_wav = make_temp_path(suffix=".wav")
            await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
            info = get_audio_info(tmp_wav) or {}
            duration = info.get("duration", 0)
            # ---------- ASYNC: long audio goes to the background queue ----------
            if duration > ASYNC_THRESHOLD:
                await _create_placeholder_note(note_id, duration, audio_url)
                job = _enqueue_async_job(audio_url, note_id, user_id)
                REQUEST_COUNT.labels(endpoint, "queued").inc()
                return JSONResponse(
                    status_code=202,
                    content={
                        "note_id": note_id,
                        "job_id": job.id,
                        "status": "queued",
                        "duration": duration,
                        "audio_url": audio_url,
                    },
                )
            # ---------- SYNC: short audio is transcribed inline ----------
            await _create_placeholder_note(note_id, duration, audio_url)
            result = await _run_sync_pipeline(tmp_wav, note_id, audio_url)
            REQUEST_COUNT.labels(endpoint, "success").inc()
            return result
        except Exception:
            # Mirror /transcribe: count failures so the error rate is
            # observable, then re-raise for FastAPI's error handling.
            REQUEST_COUNT.labels(endpoint, "error").inc()
            raise
        finally:
            # Always remove temp files, on success and failure alike.
            for p in (tmp_in, tmp_wav):
                if p and os.path.exists(p):
                    os.remove(p)