Spaces:
Sleeping
Sleeping
Commit
Β·
f84ed9c
1
Parent(s):
efea087
Change logic pipeline
Browse files- .dockerignore +10 -2
- .gitignore +3 -2
- Dockerfile +22 -10
- app/api/__init__.py +0 -0
- app/api/transcribe.py +191 -0
- app/config/__init__.py +0 -0
- app/{config.py β config/settings.py} +13 -0
- app/core/__init__.py +0 -0
- app/{model.py β core/asr_engine.py} +18 -92
- app/{audio_utils.py β core/audio_utils.py} +25 -14
- app/core/chunking.py +36 -0
- app/infra/metrics.py +32 -0
- app/infra/redis_client.py +8 -0
- app/jobs/transcribe_job.py +27 -0
- app/main.py +29 -98
- app/schemas/__init__.py +0 -0
- app/schemas/transcribe.py +14 -0
- app/services/__init__.py +0 -0
- app/services/note_client.py +48 -0
- app/services/text_normalizer.py +49 -0
- app/utils/hashing.py +7 -0
- requirements.txt +6 -0
- test/conftest.py +9 -55
- test/test_long_performance.py +0 -21
- test/test_short_and_chunk.py +0 -46
- test/test_silence_and_overlap.py +0 -12
.dockerignore
CHANGED
|
@@ -1,2 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
*.md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
test/
|
| 2 |
+
*.md
|
| 3 |
+
.myvenv
|
| 4 |
+
__pycache__
|
| 5 |
+
*.pyc
|
| 6 |
+
.DS_Store
|
| 7 |
+
.git
|
| 8 |
+
.vscode
|
| 9 |
+
.idea
|
| 10 |
+
docker-compose.yml
|
.gitignore
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
|
| 2 |
.myvenv/
|
| 3 |
__pycache__/
|
| 4 |
*.pyc
|
| 5 |
.env
|
| 6 |
-
*.md
|
|
|
|
|
|
| 1 |
+
test/
|
| 2 |
.myvenv/
|
| 3 |
__pycache__/
|
| 4 |
*.pyc
|
| 5 |
.env
|
| 6 |
+
*.md
|
| 7 |
+
docker-compose.yml
|
Dockerfile
CHANGED
|
@@ -1,22 +1,34 @@
|
|
| 1 |
-
FROM python:3.
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 5 |
-
ffmpeg libsndfile1 git build-essential wget && \
|
| 6 |
rm -rf /var/lib/apt/lists/*
|
| 7 |
|
| 8 |
WORKDIR /app
|
| 9 |
|
| 10 |
-
#
|
| 11 |
COPY requirements.txt /app/requirements.txt
|
| 12 |
-
RUN pip install --upgrade pip
|
| 13 |
-
|
| 14 |
|
| 15 |
# copy app code
|
| 16 |
COPY . /app
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
TMP_DIR=/tmp/uploads \
|
| 6 |
+
PORT=7860
|
| 7 |
+
|
| 8 |
+
# system deps (single RUN to minimize layers)
|
| 9 |
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
| 10 |
+
ffmpeg libsndfile1 git build-essential wget curl && \
|
| 11 |
rm -rf /var/lib/apt/lists/*
|
| 12 |
|
| 13 |
WORKDIR /app
|
| 14 |
|
| 15 |
+
# install python deps using cached layer
|
| 16 |
COPY requirements.txt /app/requirements.txt
|
| 17 |
+
RUN pip install --upgrade pip && \
|
| 18 |
+
pip install --no-cache-dir -r /app/requirements.txt
|
| 19 |
|
| 20 |
# copy app code
|
| 21 |
COPY . /app
|
| 22 |
|
| 23 |
+
# create tmp dir and non-root user
|
| 24 |
+
RUN mkdir -p ${TMP_DIR} && groupadd -r app && useradd -r -g app app && \
|
| 25 |
+
chown -R app:app /app ${TMP_DIR}
|
| 26 |
+
|
| 27 |
+
USER app
|
| 28 |
+
|
| 29 |
+
EXPOSE ${PORT}
|
| 30 |
+
|
| 31 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s \
|
| 32 |
+
CMD curl -f http://localhost:${PORT}/health || exit 1
|
| 33 |
|
| 34 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
|
|
app/api/__init__.py
ADDED
|
File without changes
|
app/api/transcribe.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import uuid
|
| 4 |
+
import asyncio
|
| 5 |
+
from fastapi import APIRouter, UploadFile, File, HTTPException, status
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import time
|
| 10 |
+
from app.core.audio_utils import save_upload_file, get_audio_info, ensure_wav_16k_mono, make_temp_path, download_file_from_url
|
| 11 |
+
from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
|
| 12 |
+
from app.config import settings
|
| 13 |
+
from app.services.text_normalizer import normalize_text
|
| 14 |
+
from app.services.note_client import NoteServiceClient
|
| 15 |
+
from rq import Queue
|
| 16 |
+
from app.infra.redis_client import redis_client
|
| 17 |
+
from app.jobs.transcribe_job import transcribe_job
|
| 18 |
+
from app.schemas.transcribe import TranscribeResponse
|
| 19 |
+
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION, NORMALIZE_DURATION, ERROR_COUNT
|
| 20 |
+
|
| 21 |
+
router = APIRouter()
|
| 22 |
+
|
| 23 |
+
# load model on import/startup to avoid repeated initialization
|
| 24 |
+
# you may prefer to call load_model in FastAPI startup event
|
| 25 |
+
ASR_MODEL = None
|
| 26 |
+
|
| 27 |
+
@router.on_event("startup")
|
| 28 |
+
async def _startup():
|
| 29 |
+
global ASR_MODEL
|
| 30 |
+
# load model in thread to avoid blocking event loop
|
| 31 |
+
ASR_MODEL = await asyncio.to_thread(load_model, 30)
|
| 32 |
+
|
| 33 |
+
def _ensure_file_limits(path: str):
|
| 34 |
+
if os.path.getsize(path) > settings.MAX_UPLOAD_BYTES:
|
| 35 |
+
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File size exceeds limit")
|
| 36 |
+
info = get_audio_info(path)
|
| 37 |
+
if info and info.get("duration", 0) > settings.MAX_DURATION_SECS:
|
| 38 |
+
raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Audio duration exceeds limit")
|
| 39 |
+
|
| 40 |
+
@router.post("/transcribe", response_model=TranscribeResponse)
|
| 41 |
+
async def transcribe(file: UploadFile = File(...)):
|
| 42 |
+
tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".wav")
|
| 43 |
+
tmp_wav = None
|
| 44 |
+
note_service = NoteServiceClient()
|
| 45 |
+
note_id = str(uuid.uuid4())
|
| 46 |
+
start_time = time.perf_counter()
|
| 47 |
+
endpoint = "/transcribe"
|
| 48 |
+
status_label = "success"
|
| 49 |
+
with REQUEST_LATENCY.labels(endpoint).time():
|
| 50 |
+
try:
|
| 51 |
+
# write upload to tmp (blocking) -> run in thread
|
| 52 |
+
await asyncio.to_thread(save_upload_file, file, tmp_in)
|
| 53 |
+
|
| 54 |
+
_ensure_file_limits(tmp_in)
|
| 55 |
+
|
| 56 |
+
tmp_wav = make_temp_path(suffix=".wav")
|
| 57 |
+
# ffmpeg convert is blocking -> run in thread
|
| 58 |
+
await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
|
| 59 |
+
|
| 60 |
+
# Kiα»m tra duration Δα» quyαΊΏt Δα»nh xα» lΓ½ sync hay async
|
| 61 |
+
info = get_audio_info(tmp_wav) or {}
|
| 62 |
+
duration_sec = info.get("duration", 0)
|
| 63 |
+
ASYNC_THRESHOLD = 120 # 2 phΓΊt, cΓ³ thα» chα»nh
|
| 64 |
+
if duration_sec > ASYNC_THRESHOLD:
|
| 65 |
+
# Enqueue background job bαΊ±ng RQ
|
| 66 |
+
q = Queue("asr", connection=redis_client)
|
| 67 |
+
job = q.enqueue(
|
| 68 |
+
transcribe_job,
|
| 69 |
+
tmp_wav,
|
| 70 |
+
note_id,
|
| 71 |
+
job_timeout=1800
|
| 72 |
+
)
|
| 73 |
+
logging.info(f"Enqueued background transcribe job: note_id={note_id} job_id={job.id} duration={duration_sec:.1f}s")
|
| 74 |
+
REQUEST_COUNT.labels(endpoint, "queued").inc()
|
| 75 |
+
return JSONResponse(status_code=202, content={
|
| 76 |
+
"note_id": note_id,
|
| 77 |
+
"job_id": job.id,
|
| 78 |
+
"status": "queued",
|
| 79 |
+
"duration": duration_sec
|
| 80 |
+
})
|
| 81 |
+
|
| 82 |
+
# NαΊΏu audio ngαΊ―n, xα» lΓ½ sync nhΖ° cΕ©
|
| 83 |
+
model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
|
| 84 |
+
with ASR_DURATION.labels(endpoint).time():
|
| 85 |
+
text = await asyncio.to_thread(transcribe_file, model, tmp_wav, 30.0, 5.0)
|
| 86 |
+
chunks = await asyncio.to_thread(transcribe_file_chunks, model, tmp_wav, 30.0, 5.0)
|
| 87 |
+
|
| 88 |
+
# normalize via Gemini (already async safe in your service)
|
| 89 |
+
with NORMALIZE_DURATION.labels(endpoint).time():
|
| 90 |
+
normalized_text = await normalize_text(text)
|
| 91 |
+
|
| 92 |
+
info2 = get_audio_info(tmp_wav) or {}
|
| 93 |
+
# persist to Note Service (async HTTP)
|
| 94 |
+
await note_service.save_transcript(
|
| 95 |
+
note_id=note_id,
|
| 96 |
+
raw_text=text,
|
| 97 |
+
normalized_text=normalized_text,
|
| 98 |
+
duration=info2.get("duration"),
|
| 99 |
+
sample_rate=info2.get("samplerate"),
|
| 100 |
+
chunks=chunks,
|
| 101 |
+
asr_model="PhoWhisper-base",
|
| 102 |
+
normalization_model="gemini-1.5"
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
duration = time.perf_counter() - start_time
|
| 106 |
+
logging.info(f"/transcribe success note_id={note_id} duration={duration:.2f}s audio_dur={info2.get('duration')}")
|
| 107 |
+
REQUEST_COUNT.labels(endpoint, status_label).inc()
|
| 108 |
+
return JSONResponse(status_code=200, content={
|
| 109 |
+
"note_id": note_id,
|
| 110 |
+
"status": "transcribed",
|
| 111 |
+
"duration": info2.get("duration")
|
| 112 |
+
})
|
| 113 |
+
except HTTPException:
|
| 114 |
+
status_label = "http_error"
|
| 115 |
+
ERROR_COUNT.labels(endpoint, status_label).inc()
|
| 116 |
+
raise
|
| 117 |
+
except Exception as e:
|
| 118 |
+
status_label = "error"
|
| 119 |
+
ERROR_COUNT.labels(endpoint, status_label).inc()
|
| 120 |
+
logging.exception(f"/transcribe failed note_id={note_id}")
|
| 121 |
+
raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
|
| 122 |
+
finally:
|
| 123 |
+
# cleanup
|
| 124 |
+
for p in [tmp_in, tmp_wav]:
|
| 125 |
+
try:
|
| 126 |
+
if p and os.path.exists(p):
|
| 127 |
+
os.remove(p)
|
| 128 |
+
except Exception:
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@router.post("/transcribe-url", response_model=TranscribeResponse)
|
| 133 |
+
async def transcribe_url(payload: dict):
|
| 134 |
+
audio_url = payload.get("audio_url")
|
| 135 |
+
user_id = payload.get("user_id")
|
| 136 |
+
if not audio_url:
|
| 137 |
+
raise HTTPException(status_code=400, detail="audio_url required")
|
| 138 |
+
if not user_id:
|
| 139 |
+
raise HTTPException(status_code=400, detail="user_id required")
|
| 140 |
+
|
| 141 |
+
tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
|
| 142 |
+
tmp_wav = None
|
| 143 |
+
note_service = NoteServiceClient()
|
| 144 |
+
note_id = str(uuid.uuid4())
|
| 145 |
+
|
| 146 |
+
start_time = time.perf_counter()
|
| 147 |
+
try:
|
| 148 |
+
# download blocking -> thread
|
| 149 |
+
await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
|
| 150 |
+
|
| 151 |
+
_ensure_file_limits(tmp_in)
|
| 152 |
+
|
| 153 |
+
tmp_wav = make_temp_path(suffix=".wav")
|
| 154 |
+
await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
|
| 155 |
+
|
| 156 |
+
model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
|
| 157 |
+
text = await asyncio.to_thread(transcribe_file, model, tmp_wav, 30.0, 5.0)
|
| 158 |
+
chunks = await asyncio.to_thread(transcribe_file_chunks, model, tmp_wav, 30.0, 5.0)
|
| 159 |
+
normalized_text = await normalize_text(text)
|
| 160 |
+
info2 = get_audio_info(tmp_wav) or {}
|
| 161 |
+
|
| 162 |
+
await note_service.save_transcript(
|
| 163 |
+
note_id=note_id,
|
| 164 |
+
raw_text=text,
|
| 165 |
+
normalized_text=normalized_text,
|
| 166 |
+
duration=info2.get("duration"),
|
| 167 |
+
sample_rate=info2.get("samplerate"),
|
| 168 |
+
chunks=chunks,
|
| 169 |
+
asr_model="PhoWhisper-base",
|
| 170 |
+
normalization_model="gemini-1.5"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
duration = time.perf_counter() - start_time
|
| 174 |
+
logging.info(f"/transcribe-url success note_id={note_id} duration={duration:.2f}s audio_dur={info2.get('duration')}")
|
| 175 |
+
return JSONResponse(status_code=200, content={
|
| 176 |
+
"note_id": note_id,
|
| 177 |
+
"status": "transcribed",
|
| 178 |
+
"duration": info2.get("duration")
|
| 179 |
+
})
|
| 180 |
+
except HTTPException:
|
| 181 |
+
raise
|
| 182 |
+
except Exception as e:
|
| 183 |
+
logging.exception(f"/transcribe-url failed note_id={note_id}")
|
| 184 |
+
raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
|
| 185 |
+
finally:
|
| 186 |
+
for p in [tmp_in, tmp_wav]:
|
| 187 |
+
try:
|
| 188 |
+
if p and os.path.exists(p):
|
| 189 |
+
os.remove(p)
|
| 190 |
+
except Exception:
|
| 191 |
+
pass
|
app/config/__init__.py
ADDED
|
File without changes
|
app/{config.py β config/settings.py}
RENAMED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
# Limits & model setting
|
|
@@ -12,3 +14,14 @@ os.makedirs(TMP_DIR, exist_ok=True)
|
|
| 12 |
# Cloud credentials (set as HF Spaces secrets or env)
|
| 13 |
# FIREBASE_SERVICE_ACCOUNT = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON") # optional
|
| 14 |
# CLOUDINARY_URL = os.getenv("CLOUDINARY_URL") # optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# App settings and configuration
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
|
| 5 |
# Limits & model setting
|
|
|
|
| 14 |
# Cloud credentials (set as HF Spaces secrets or env)
|
| 15 |
# FIREBASE_SERVICE_ACCOUNT = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON") # optional
|
| 16 |
# CLOUDINARY_URL = os.getenv("CLOUDINARY_URL") # optional
|
| 17 |
+
|
| 18 |
+
# Gemini API Key (for text normalization)
|
| 19 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
|
| 20 |
+
# External services
|
| 21 |
+
NOTE_SERVICE_URL = os.getenv("NOTE_SERVICE_URL", "http://localhost:9000")
|
| 22 |
+
|
| 23 |
+
# HTTP timeouts
|
| 24 |
+
HTTPX_TIMEOUT = float(os.getenv("HTTPX_TIMEOUT", "10.0"))
|
| 25 |
+
|
| 26 |
+
# Redis URL
|
| 27 |
+
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
app/core/__init__.py
ADDED
|
File without changes
|
app/{model.py β core/asr_engine.py}
RENAMED
|
@@ -1,21 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
import subprocess
|
| 4 |
-
import os
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import List
|
| 7 |
-
from .audio_utils import make_temp_path, ensure_wav_16k_mono, get_audio_info
|
| 8 |
-
from transformers import pipeline
|
| 9 |
import logging
|
| 10 |
-
from
|
|
|
|
| 11 |
|
| 12 |
_model = None
|
| 13 |
|
| 14 |
def load_model(chunk_length_s: int = None):
|
| 15 |
global _model
|
| 16 |
if _model is None:
|
| 17 |
-
# This will download weights at runtime (from Hugging Face Hub).
|
| 18 |
-
# If you are on HF Spaces, the image will download on first run.
|
| 19 |
logging.info(f"Loading ASR model {MODEL_NAME} ...")
|
| 20 |
kwargs = {}
|
| 21 |
if chunk_length_s is not None:
|
|
@@ -24,57 +17,13 @@ def load_model(chunk_length_s: int = None):
|
|
| 24 |
logging.info("Model loaded")
|
| 25 |
return _model
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
def
|
| 29 |
-
"""
|
| 30 |
-
Extract segment [start, start+duration) using ffmpeg into dst (wav 16k mono pcm16).
|
| 31 |
-
We call ffmpeg with -ss + -t on input for safety.
|
| 32 |
-
"""
|
| 33 |
-
# Note: -ss before -i is fast seek but less accurate for some formats. We use -ss after -i for accuracy.
|
| 34 |
-
cmd = f'ffmpeg -v error -y -ss {start:.3f} -i "{src}" -t {duration:.3f} -ar 16000 -ac 1 -acodec pcm_s16le "{dst}"'
|
| 35 |
-
proc = subprocess.run(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 36 |
-
if proc.returncode != 0:
|
| 37 |
-
raise RuntimeError(f"ffmpeg extract failed: {proc.stderr.decode(errors='ignore')}")
|
| 38 |
-
return dst
|
| 39 |
-
|
| 40 |
-
#_split_audio_to_chunks tαΊ‘o list chunk files.
|
| 41 |
-
def _split_audio_to_chunks(src_wav: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0) -> List[str]:
|
| 42 |
-
"""
|
| 43 |
-
Split src_wav into chunk files (wav 16k mono). Returns list of chunk file paths in order.
|
| 44 |
-
Overlap is seconds overlapping between consecutive chunks.
|
| 45 |
-
"""
|
| 46 |
-
info = get_audio_info(src_wav)
|
| 47 |
-
if not info:
|
| 48 |
-
raise RuntimeError("Cannot read audio info")
|
| 49 |
-
duration = info["duration"]
|
| 50 |
-
step = chunk_length_s - overlap_s
|
| 51 |
-
if step <= 0:
|
| 52 |
-
raise ValueError("chunk_length_s must be > overlap_s")
|
| 53 |
-
starts = []
|
| 54 |
-
t = 0.0
|
| 55 |
-
while t < duration:
|
| 56 |
-
starts.append(t)
|
| 57 |
-
t += step
|
| 58 |
-
# ensure last chunk covers end - if last start + chunk_length < duration, we still create chunk that may be shorter
|
| 59 |
-
chunks = []
|
| 60 |
-
for i, s in enumerate(starts):
|
| 61 |
-
dst = make_temp_path(suffix=f".chunk{i}.wav")
|
| 62 |
-
_ffmpeg_extract_segment(src_wav, s, chunk_length_s, dst)
|
| 63 |
-
chunks.append(dst)
|
| 64 |
-
return chunks
|
| 65 |
-
|
| 66 |
-
#_merge_transcripts lΓ heuristic so khα»p bαΊ±ng tα»« (exact n-gram match) vΓ fallback; khΓ΄ng hoΓ n hαΊ£o nhΖ°ng giαΊ£m phαΊ§n lα»n lαΊ·p do overlap.
|
| 67 |
-
def _merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8) -> str:
|
| 68 |
-
"""
|
| 69 |
-
Heuristic merge: if end of prev_text and start of new_text share overlapping words, remove overlap.
|
| 70 |
-
We find the longest overlap up to max_overlap_words.
|
| 71 |
-
"""
|
| 72 |
if not prev_text:
|
| 73 |
return new_text
|
| 74 |
p_words = prev_text.strip().split()
|
| 75 |
n_words = new_text.strip().split()
|
| 76 |
max_ol = min(max_overlap_words, len(p_words), len(n_words))
|
| 77 |
-
# search for largest k where last k words of prev == first k words of new
|
| 78 |
best_k = 0
|
| 79 |
for k in range(max_ol, 0, -1):
|
| 80 |
if p_words[-k:] == n_words[:k]:
|
|
@@ -83,35 +32,22 @@ def _merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8
|
|
| 83 |
if best_k > 0:
|
| 84 |
merged = " ".join(p_words + n_words[best_k:])
|
| 85 |
return merged
|
| 86 |
-
# If no exact overlap, try fuzzy overlap by matching word sequences (less strict)
|
| 87 |
-
# simple heuristic: if last N words of prev appear anywhere at start of new, trim
|
| 88 |
for k in range(max_ol, 1, -1):
|
| 89 |
seq = " ".join(p_words[-k:])
|
| 90 |
if seq in new_text:
|
| 91 |
idx = new_text.find(seq)
|
| 92 |
-
# remove through the sequence
|
| 93 |
merged = " ".join(p_words + new_text[idx + len(seq):].strip().split())
|
| 94 |
return merged
|
| 95 |
-
# fallback: just concatenate with space
|
| 96 |
return prev_text.rstrip() + " " + new_text.lstrip()
|
| 97 |
|
| 98 |
-
#transcribe_long_audio chαΊ‘y model trΓͺn tα»«ng chunk (tuαΊ§n tα»± theo mαΊ·c Δα»nh); cΓ³ tuα»³ chα»n parallel=True Δα» xα» lΓ½ Δα»ng thα»i (cαΊ§n thαΊn trα»ng vα»i GPU memory).
|
| 99 |
def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, parallel: bool = False) -> str:
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
- wav_path: already normalized wav (16k, mono)
|
| 105 |
-
Returns stitched transcript string.
|
| 106 |
-
"""
|
| 107 |
-
# 1) split into chunks
|
| 108 |
-
chunks = _split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
|
| 109 |
logging.info(f"Split into {len(chunks)} chunks")
|
| 110 |
texts = []
|
| 111 |
-
|
| 112 |
-
# 2) process chunks (sequential for safety). Optionally implement limited parallelism.
|
| 113 |
if parallel:
|
| 114 |
-
# optional: implement ThreadPool/ProcessPool with care for GPU memory; by default we do sequential
|
| 115 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 116 |
def process_chunk(path):
|
| 117 |
try:
|
|
@@ -122,7 +58,7 @@ def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, ov
|
|
| 122 |
except Exception as e:
|
| 123 |
logging.exception("Chunk inference failed")
|
| 124 |
return ""
|
| 125 |
-
with ThreadPoolExecutor(max_workers=2) as ex:
|
| 126 |
futures = {ex.submit(process_chunk, c): c for c in chunks}
|
| 127 |
for fut in as_completed(futures):
|
| 128 |
texts.append(fut.result() or "")
|
|
@@ -133,12 +69,9 @@ def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, ov
|
|
| 133 |
texts.append(out.get("text", "") or "")
|
| 134 |
else:
|
| 135 |
texts.append(str(out) or "")
|
| 136 |
-
|
| 137 |
-
# 3) merge texts to single transcript, removing overlap duplicates
|
| 138 |
merged = ""
|
| 139 |
for t in texts:
|
| 140 |
-
merged =
|
| 141 |
-
# 4) cleanup chunk files
|
| 142 |
for c in chunks:
|
| 143 |
try:
|
| 144 |
os.remove(c)
|
|
@@ -146,31 +79,24 @@ def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, ov
|
|
| 146 |
pass
|
| 147 |
return merged
|
| 148 |
|
| 149 |
-
# transcribe_file quyαΊΏt Δα»nh cΓ³ chunk hay khΓ΄ng.
|
| 150 |
def transcribe_file(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
|
| 151 |
-
|
| 152 |
-
Main entry: if audio short, run single pass; if long, call chunked transcribe.
|
| 153 |
-
"""
|
| 154 |
info = get_audio_info(wav_path) or {}
|
| 155 |
duration = info.get("duration", 0.0)
|
| 156 |
-
|
| 157 |
-
if duration and duration > max_chunk_length * 1.1: # slightly bigger than chunk length
|
| 158 |
logging.info(f"Long audio detected ({duration}s) -> chunking")
|
| 159 |
return transcribe_long_audio(model, wav_path, chunk_length_s=max_chunk_length, overlap_s=overlap_s)
|
| 160 |
-
# short audio -> direct
|
| 161 |
out = model(wav_path)
|
| 162 |
if isinstance(out, dict):
|
| 163 |
return out.get("text") or ""
|
| 164 |
return str(out)
|
| 165 |
|
| 166 |
-
# HΓ m trαΊ£ vα» danh sΓ‘ch dict chα»©a start, end, text cho tα»«ng chunk
|
| 167 |
def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
info = get_audio_info(wav_path) or {}
|
| 172 |
duration = info.get("duration", 0.0)
|
| 173 |
-
# TΓnh toΓ‘n cΓ‘c mα»c thα»i gian bαΊ―t ΔαΊ§u cho tα»«ng chunk
|
| 174 |
step = max_chunk_length - overlap_s
|
| 175 |
if step <= 0:
|
| 176 |
raise ValueError("max_chunk_length must be > overlap_s")
|
|
@@ -183,7 +109,7 @@ def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0,
|
|
| 183 |
for i, s in enumerate(starts):
|
| 184 |
chunk_end = min(s + max_chunk_length, duration)
|
| 185 |
dst = make_temp_path(suffix=f".chunk{i}.wav")
|
| 186 |
-
|
| 187 |
out = model(dst)
|
| 188 |
if isinstance(out, dict):
|
| 189 |
text = out.get("text", "")
|
|
|
|
| 1 |
+
# PhoWhisper inference engine
|
| 2 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import logging
|
| 4 |
+
from transformers import pipeline
|
| 5 |
+
from app.config.settings import MODEL_NAME
|
| 6 |
|
| 7 |
_model = None
|
| 8 |
|
| 9 |
def load_model(chunk_length_s: int = None):
|
| 10 |
global _model
|
| 11 |
if _model is None:
|
|
|
|
|
|
|
| 12 |
logging.info(f"Loading ASR model {MODEL_NAME} ...")
|
| 13 |
kwargs = {}
|
| 14 |
if chunk_length_s is not None:
|
|
|
|
| 17 |
logging.info("Model loaded")
|
| 18 |
return _model
|
| 19 |
|
| 20 |
+
# Heuristic merge for chunked transcripts
|
| 21 |
+
def merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
if not prev_text:
|
| 23 |
return new_text
|
| 24 |
p_words = prev_text.strip().split()
|
| 25 |
n_words = new_text.strip().split()
|
| 26 |
max_ol = min(max_overlap_words, len(p_words), len(n_words))
|
|
|
|
| 27 |
best_k = 0
|
| 28 |
for k in range(max_ol, 0, -1):
|
| 29 |
if p_words[-k:] == n_words[:k]:
|
|
|
|
| 32 |
if best_k > 0:
|
| 33 |
merged = " ".join(p_words + n_words[best_k:])
|
| 34 |
return merged
|
|
|
|
|
|
|
| 35 |
for k in range(max_ol, 1, -1):
|
| 36 |
seq = " ".join(p_words[-k:])
|
| 37 |
if seq in new_text:
|
| 38 |
idx = new_text.find(seq)
|
|
|
|
| 39 |
merged = " ".join(p_words + new_text[idx + len(seq):].strip().split())
|
| 40 |
return merged
|
|
|
|
| 41 |
return prev_text.rstrip() + " " + new_text.lstrip()
|
| 42 |
|
|
|
|
| 43 |
def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, parallel: bool = False) -> str:
|
| 44 |
+
from app.core.chunking import split_audio_to_chunks
|
| 45 |
+
from app.core.audio_utils import make_temp_path
|
| 46 |
+
import os
|
| 47 |
+
chunks = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
logging.info(f"Split into {len(chunks)} chunks")
|
| 49 |
texts = []
|
|
|
|
|
|
|
| 50 |
if parallel:
|
|
|
|
| 51 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 52 |
def process_chunk(path):
|
| 53 |
try:
|
|
|
|
| 58 |
except Exception as e:
|
| 59 |
logging.exception("Chunk inference failed")
|
| 60 |
return ""
|
| 61 |
+
with ThreadPoolExecutor(max_workers=2) as ex:
|
| 62 |
futures = {ex.submit(process_chunk, c): c for c in chunks}
|
| 63 |
for fut in as_completed(futures):
|
| 64 |
texts.append(fut.result() or "")
|
|
|
|
| 69 |
texts.append(out.get("text", "") or "")
|
| 70 |
else:
|
| 71 |
texts.append(str(out) or "")
|
|
|
|
|
|
|
| 72 |
merged = ""
|
| 73 |
for t in texts:
|
| 74 |
+
merged = merge_transcripts(merged, t, max_overlap_words=12)
|
|
|
|
| 75 |
for c in chunks:
|
| 76 |
try:
|
| 77 |
os.remove(c)
|
|
|
|
| 79 |
pass
|
| 80 |
return merged
|
| 81 |
|
|
|
|
| 82 |
def transcribe_file(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
|
| 83 |
+
from app.core.audio_utils import get_audio_info
|
|
|
|
|
|
|
| 84 |
info = get_audio_info(wav_path) or {}
|
| 85 |
duration = info.get("duration", 0.0)
|
| 86 |
+
if duration and duration > max_chunk_length * 1.1:
|
|
|
|
| 87 |
logging.info(f"Long audio detected ({duration}s) -> chunking")
|
| 88 |
return transcribe_long_audio(model, wav_path, chunk_length_s=max_chunk_length, overlap_s=overlap_s)
|
|
|
|
| 89 |
out = model(wav_path)
|
| 90 |
if isinstance(out, dict):
|
| 91 |
return out.get("text") or ""
|
| 92 |
return str(out)
|
| 93 |
|
|
|
|
| 94 |
def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
|
| 95 |
+
from app.core.audio_utils import get_audio_info, make_temp_path
|
| 96 |
+
from app.core.chunking import ffmpeg_extract_segment
|
| 97 |
+
import os
|
| 98 |
info = get_audio_info(wav_path) or {}
|
| 99 |
duration = info.get("duration", 0.0)
|
|
|
|
| 100 |
step = max_chunk_length - overlap_s
|
| 101 |
if step <= 0:
|
| 102 |
raise ValueError("max_chunk_length must be > overlap_s")
|
|
|
|
| 109 |
for i, s in enumerate(starts):
|
| 110 |
chunk_end = min(s + max_chunk_length, duration)
|
| 111 |
dst = make_temp_path(suffix=f".chunk{i}.wav")
|
| 112 |
+
ffmpeg_extract_segment(wav_path, s, chunk_end - s, dst)
|
| 113 |
out = model(dst)
|
| 114 |
if isinstance(out, dict):
|
| 115 |
text = out.get("text", "")
|
app/{audio_utils.py β core/audio_utils.py}
RENAMED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
|
| 2 |
import subprocess
|
| 3 |
import shlex
|
| 4 |
import uuid
|
| 5 |
import requests
|
| 6 |
from pathlib import Path
|
| 7 |
import soundfile as sf
|
| 8 |
-
from .config import TMP_DIR, MAX_UPLOAD_BYTES
|
| 9 |
|
| 10 |
def save_upload_file(upload_file, dest_path: str):
|
| 11 |
-
"""Save
|
| 12 |
with open(dest_path, "wb") as f:
|
| 13 |
while True:
|
| 14 |
-
chunk = upload_file.file.read(1024*1024)
|
| 15 |
if not chunk:
|
| 16 |
break
|
| 17 |
f.write(chunk)
|
| 18 |
|
| 19 |
def download_file_from_url(url: str, dest_path: str, timeout=30):
|
| 20 |
-
"""Download remote file to dest_path."""
|
| 21 |
r = requests.get(url, stream=True, timeout=timeout)
|
| 22 |
r.raise_for_status()
|
| 23 |
total = 0
|
|
@@ -30,26 +30,37 @@ def download_file_from_url(url: str, dest_path: str, timeout=30):
|
|
| 30 |
f.write(chunk)
|
| 31 |
|
| 32 |
def get_audio_info(path: str):
|
| 33 |
-
"""Return duration (s)
|
| 34 |
try:
|
| 35 |
info = sf.info(path)
|
| 36 |
duration = info.frames / info.samplerate
|
| 37 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
except Exception:
|
| 39 |
return None
|
| 40 |
|
| 41 |
def ensure_wav_16k_mono(src_path: str, dest_path: str):
|
| 42 |
"""
|
| 43 |
-
|
| 44 |
-
Returns dest_path if ok, raises exception on error.
|
| 45 |
"""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
if proc.returncode != 0:
|
| 51 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 52 |
return dest_path
|
| 53 |
|
| 54 |
def make_temp_path(suffix=".wav"):
|
|
|
|
| 55 |
return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")
|
|
|
|
| 1 |
+
# Audio utilities: ffmpeg, normalization, etc.
|
| 2 |
import subprocess
|
| 3 |
import shlex
|
| 4 |
import uuid
|
| 5 |
import requests
|
| 6 |
from pathlib import Path
|
| 7 |
import soundfile as sf
|
| 8 |
+
from app.config.settings import TMP_DIR, MAX_UPLOAD_BYTES
|
| 9 |
|
| 10 |
def save_upload_file(upload_file, dest_path: str):
|
| 11 |
+
"""Save FastAPI UploadFile to dest_path (streaming)."""
|
| 12 |
with open(dest_path, "wb") as f:
|
| 13 |
while True:
|
| 14 |
+
chunk = upload_file.file.read(1024 * 1024)
|
| 15 |
if not chunk:
|
| 16 |
break
|
| 17 |
f.write(chunk)
|
| 18 |
|
| 19 |
def download_file_from_url(url: str, dest_path: str, timeout=30):
|
| 20 |
+
"""Download remote file to dest_path with size limit."""
|
| 21 |
r = requests.get(url, stream=True, timeout=timeout)
|
| 22 |
r.raise_for_status()
|
| 23 |
total = 0
|
|
|
|
| 30 |
f.write(chunk)
|
| 31 |
|
| 32 |
def get_audio_info(path: str):
|
| 33 |
+
"""Return duration (s), sample_rate, channels using soundfile."""
|
| 34 |
try:
|
| 35 |
info = sf.info(path)
|
| 36 |
duration = info.frames / info.samplerate
|
| 37 |
+
return {
|
| 38 |
+
"duration": duration,
|
| 39 |
+
"samplerate": info.samplerate,
|
| 40 |
+
"channels": info.channels,
|
| 41 |
+
}
|
| 42 |
except Exception:
|
| 43 |
return None
|
| 44 |
|
| 45 |
def ensure_wav_16k_mono(src_path: str, dest_path: str):
|
| 46 |
"""
|
| 47 |
+
Convert any audio to WAV PCM16, 16kHz, mono using ffmpeg.
|
|
|
|
| 48 |
"""
|
| 49 |
+
cmd = (
|
| 50 |
+
f'ffmpeg -v error -y -i "{src_path}" '
|
| 51 |
+
f'-ar 16000 -ac 1 -acodec pcm_s16le "{dest_path}"'
|
| 52 |
+
)
|
| 53 |
+
proc = subprocess.run(
|
| 54 |
+
shlex.split(cmd),
|
| 55 |
+
stdout=subprocess.PIPE,
|
| 56 |
+
stderr=subprocess.PIPE,
|
| 57 |
+
)
|
| 58 |
if proc.returncode != 0:
|
| 59 |
+
raise RuntimeError(
|
| 60 |
+
f"ffmpeg convert failed: {proc.stderr.decode(errors='ignore')}"
|
| 61 |
+
)
|
| 62 |
return dest_path
|
| 63 |
|
| 64 |
def make_temp_path(suffix=".wav"):
|
| 65 |
+
"""Generate unique temp file path under TMP_DIR."""
|
| 66 |
return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")
|
app/core/chunking.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audio chunking/splitting/merging logic
|
| 2 |
+
|
| 3 |
+
import shlex
|
| 4 |
+
import subprocess
|
| 5 |
+
from typing import List
|
| 6 |
+
from app.core.audio_utils import get_audio_info, make_temp_path
|
| 7 |
+
|
| 8 |
+
def ffmpeg_extract_segment(src: str, start: float, duration: float, dst: str):
    """Extract segment [start, start+duration) from src into dst.

    The output is WAV PCM16, 16 kHz, mono.

    Args:
        src: Source audio path.
        start: Segment start time in seconds.
        duration: Segment length in seconds.
        dst: Destination WAV path (overwritten if present).

    Returns:
        dst, for convenient chaining.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # Argument list instead of a quoted shell string + shlex.split: paths
    # containing quotes or shell metacharacters are passed through untouched.
    cmd = [
        "ffmpeg", "-v", "error", "-y",
        "-ss", f"{start:.3f}",
        "-i", src,
        "-t", f"{duration:.3f}",
        "-ar", "16000", "-ac", "1", "-acodec", "pcm_s16le",
        dst,
    ]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg extract failed: {proc.stderr.decode(errors='ignore')}")
    return dst
|
| 17 |
+
|
| 18 |
+
def split_audio_to_chunks(src_wav: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0) -> List[str]:
    """Split src_wav into overlapping WAV chunks; return their paths in order.

    Consecutive chunks overlap by ``overlap_s`` seconds so transcripts can be
    merged without losing words at chunk boundaries.

    Raises:
        RuntimeError: If the source file's metadata cannot be read.
        ValueError: If the overlap is not strictly smaller than the chunk length.
    """
    info = get_audio_info(src_wav)
    if not info:
        raise RuntimeError("Cannot read audio info")
    duration = info["duration"]
    step = chunk_length_s - overlap_s
    if step <= 0:
        raise ValueError("chunk_length_s must be > overlap_s")
    chunk_paths = []
    start = 0.0
    index = 0
    # Walk the timeline in `step` increments; the final chunk is shortened so
    # it never reads past the end of the file.
    while start < duration:
        dst = make_temp_path(suffix=f"_chunk{index}.wav")
        ffmpeg_extract_segment(src_wav, start, min(chunk_length_s, duration - start), dst)
        chunk_paths.append(dst)
        start += step
        index += 1
    return chunk_paths
|
app/infra/metrics.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from prometheus_client import Counter, Histogram

# Total ASR requests, labelled by endpoint and response status.
REQUEST_COUNT = Counter(
    "asr_requests_total",
    "Total ASR requests",
    ["endpoint", "status"]
)


# End-to-end request latency per endpoint, in seconds.
REQUEST_LATENCY = Histogram(
    "asr_request_latency_seconds",
    "ASR request latency",
    ["endpoint"]
)

# Time spent inside the ASR model itself (inference only).
ASR_DURATION = Histogram(
    "asr_model_duration_seconds",
    "ASR model inference duration",
    ["endpoint"]
)

# Time spent in the text-normalization step.
NORMALIZE_DURATION = Histogram(
    "normalize_duration_seconds",
    "Text normalization duration",
    ["endpoint"]
)

# Total errors, labelled by endpoint and error type.
ERROR_COUNT = Counter(
    "asr_error_total",
    "Total ASR errors",
    ["endpoint", "error_type"]
)
|
app/infra/redis_client.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import redis
from app.config.settings import REDIS_URL

# Shared module-level Redis client. decode_responses=True makes get()/setex()
# exchange str values instead of bytes, so cached text needs no decoding.
redis_client = redis.Redis.from_url(
    REDIS_URL,
    decode_responses=True
)
|
app/jobs/transcribe_job.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.core.asr_engine import load_model, transcribe_file
|
| 2 |
+
from app.services.text_normalizer import normalize_text
|
| 3 |
+
from app.services.note_client import NoteServiceClient
|
| 4 |
+
|
| 5 |
+
# This function is executed by an RQ worker (synchronous context).
def transcribe_job(tmp_wav: str, note_id: str):
    """Transcribe tmp_wav, normalize the text, and push it to the Note Service.

    Args:
        tmp_wav: Path to a prepared WAV file.
        note_id: Identifier of the note the transcript belongs to.

    Returns:
        True on success; exceptions propagate to RQ for retry/failure handling.
    """
    import asyncio

    model = load_model()
    text = transcribe_file(model, tmp_wav, 30.0, 5.0)

    # normalize_text may be async, but RQ runs jobs synchronously, so drive
    # the coroutine to completion with asyncio.run() when needed.
    if asyncio.iscoroutinefunction(normalize_text):
        normalized = asyncio.run(normalize_text(text))
    else:
        normalized = normalize_text(text)

    note_service = NoteServiceClient()
    # Send the transcript to the Note Service. save_transcript is declared
    # `async def` in NoteServiceClient: calling it without running the
    # returned coroutine would silently do nothing, so run it here.
    result = note_service.save_transcript(
        note_id=note_id,
        raw_text=text,
        normalized_text=normalized,
        duration=None,
        sample_rate=None,
        chunks=None,
        asr_model="PhoWhisper-base",
        normalization_model="gemini-1.5"
    )
    if asyncio.iscoroutine(result):
        asyncio.run(result)
    return True
|
app/main.py
CHANGED
|
@@ -1,16 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
from
|
| 4 |
-
|
| 5 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
-
from pathlib import Path
|
| 7 |
import logging
|
| 8 |
-
from .
|
| 9 |
-
from .
|
| 10 |
-
from .
|
|
|
|
| 11 |
|
| 12 |
app = FastAPI(title="PhoWhisper ASR API")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# CORS β tighten in prod
|
| 15 |
app.add_middleware(
|
| 16 |
CORSMiddleware,
|
|
@@ -19,99 +27,22 @@ app.add_middleware(
|
|
| 19 |
allow_headers=["*"],
|
| 20 |
)
|
| 21 |
|
| 22 |
-
# load model lazily on first request to avoid startup heavy cost
|
| 23 |
-
MODEL = None
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# optionally pre-load model (comment if you want lazy)
|
| 30 |
-
# global MODEL
|
| 31 |
-
# MODEL = load_model()
|
| 32 |
-
logging.info("API startup complete")
|
| 33 |
|
|
|
|
| 34 |
@app.get("/health")
|
| 35 |
def health():
|
| 36 |
return {"status": "ok"}
|
| 37 |
|
| 38 |
-
@app.post("/transcribe")
|
| 39 |
-
async def transcribe(file: UploadFile = File(...)):
|
| 40 |
-
# 1. basic size check: UploadFile does not expose size easily, so stream-write and check
|
| 41 |
-
tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".wav")
|
| 42 |
-
try:
|
| 43 |
-
save_upload_file(file, tmp_in)
|
| 44 |
-
if os.path.getsize(tmp_in) > MAX_UPLOAD_BYTES:
|
| 45 |
-
raise HTTPException(status_code=413, detail="File size exceeds limit")
|
| 46 |
-
info = get_audio_info(tmp_in)
|
| 47 |
-
if info and info.get("duration") and info["duration"] > MAX_DURATION_SECS:
|
| 48 |
-
raise HTTPException(status_code=413, detail="Audio duration exceeds limit")
|
| 49 |
-
# 2. normalize
|
| 50 |
-
tmp_wav = make_temp_path(suffix=".wav")
|
| 51 |
-
ensure_wav_16k_mono(tmp_in, tmp_wav)
|
| 52 |
-
# 3. load model if needed
|
| 53 |
-
global MODEL
|
| 54 |
-
if MODEL is None:
|
| 55 |
-
MODEL = load_model(chunk_length_s=30)
|
| 56 |
-
text = transcribe_file(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
|
| 57 |
-
chunks = transcribe_file_chunks(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
|
| 58 |
-
info2 = get_audio_info(tmp_wav) or {}
|
| 59 |
-
return JSONResponse({
|
| 60 |
-
"text": text,
|
| 61 |
-
"duration": info2.get("duration"),
|
| 62 |
-
"sample_rate": info2.get("samplerate"),
|
| 63 |
-
"chunks": chunks
|
| 64 |
-
})
|
| 65 |
-
except HTTPException:
|
| 66 |
-
raise
|
| 67 |
-
except Exception as e:
|
| 68 |
-
logging.exception("Transcribe failed")
|
| 69 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 70 |
-
finally:
|
| 71 |
-
# cleanup temp files
|
| 72 |
-
for p in [tmp_in, locals().get("tmp_wav")]:
|
| 73 |
-
try:
|
| 74 |
-
if p and os.path.exists(p):
|
| 75 |
-
os.remove(p)
|
| 76 |
-
except Exception:
|
| 77 |
-
pass
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
download_file_from_url(audio_url, tmp_in)
|
| 87 |
-
if os.path.getsize(tmp_in) > MAX_UPLOAD_BYTES:
|
| 88 |
-
raise HTTPException(status_code=413, detail="File size exceeds limit")
|
| 89 |
-
info = get_audio_info(tmp_in)
|
| 90 |
-
if info and info.get("duration") and info["duration"] > MAX_DURATION_SECS:
|
| 91 |
-
raise HTTPException(status_code=413, detail="Audio duration exceeds limit")
|
| 92 |
-
tmp_wav = make_temp_path(suffix=".wav")
|
| 93 |
-
ensure_wav_16k_mono(tmp_in, tmp_wav)
|
| 94 |
-
global MODEL
|
| 95 |
-
if MODEL is None:
|
| 96 |
-
MODEL = load_model(chunk_length_s=30)
|
| 97 |
-
text = transcribe_file(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
|
| 98 |
-
chunks = transcribe_file_chunks(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
|
| 99 |
-
info2 = get_audio_info(tmp_wav) or {}
|
| 100 |
-
return JSONResponse({
|
| 101 |
-
"text": text,
|
| 102 |
-
"duration": info2.get("duration"),
|
| 103 |
-
"sample_rate": info2.get("samplerate"),
|
| 104 |
-
"chunks": chunks
|
| 105 |
-
})
|
| 106 |
-
except HTTPException:
|
| 107 |
-
raise
|
| 108 |
-
except Exception as e:
|
| 109 |
-
logging.exception("Transcribe-url failed")
|
| 110 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 111 |
-
finally:
|
| 112 |
-
for p in [tmp_in, locals().get("tmp_wav")]:
|
| 113 |
-
try:
|
| 114 |
-
if p and os.path.exists(p):
|
| 115 |
-
os.remove(p)
|
| 116 |
-
except Exception:
|
| 117 |
-
pass
|
|
|
|
| 1 |
+
|
| 2 |
+
from fastapi import FastAPI, Response
|
| 3 |
+
from prometheus_client import generate_latest
|
| 4 |
+
import asyncio
|
|
|
|
|
|
|
| 5 |
import logging
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from app.api.transcribe import router as transcribe_router
|
| 8 |
+
from app.core.asr_engine import load_model
|
| 9 |
+
|
| 10 |
|
| 11 |
app = FastAPI(title="PhoWhisper ASR API")
|
| 12 |
|
| 13 |
+
# Preload ASR model at startup
|
| 14 |
+
@app.on_event("startup")
async def preload_asr_model():
    """FastAPI startup hook: load the ASR model before serving traffic."""
    # load_model is blocking (reads/loads model weights); run it in a worker
    # thread so the event loop stays responsive during startup.
    logging.info("Preloading ASR model at startup...")
    await asyncio.to_thread(load_model, 30)
    logging.info("ASR model preloaded.")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
# CORS β tighten in prod
|
| 23 |
app.add_middleware(
|
| 24 |
CORSMiddleware,
|
|
|
|
| 27 |
allow_headers=["*"],
|
| 28 |
)
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# --- OLD LOGIC: ΔΓ£ chuyα»n sang app/api/transcribe.py ---
|
| 32 |
+
# - Δα»nh nghΔ©a endpoint trα»±c tiαΊΏp
|
| 33 |
+
# - Chα»©a toΓ n bα» logic xα» lΓ½
|
| 34 |
+
# - ΔΓ£ refactor thΓ nh router riΓͺng vΓ tΓ‘ch core/service
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# Health check (cΓ³ thα» giα»― lαΊ‘i nαΊΏu muα»n)
|
| 37 |
@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    status_payload = {"status": "ok"}
    return status_payload
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# Expose /metrics endpoint for Prometheus
@app.get("/metrics")
def metrics():
    """Serve the process metrics in the Prometheus text exposition format."""
    return Response(generate_latest(), media_type="text/plain")
|
| 46 |
+
|
| 47 |
+
# Include API routers
|
| 48 |
+
app.include_router(transcribe_router)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/schemas/__init__.py
ADDED
|
File without changes
|
app/schemas/transcribe.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Request/Response models for transcription
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
class Chunk(BaseModel):
    """A transcribed audio segment with its time span (seconds)."""
    start: float  # segment start time, seconds
    end: float    # segment end time, seconds
    text: str     # transcript of this segment
|
| 10 |
+
|
| 11 |
+
class TranscribeResponse(BaseModel):
    """Response body returned when a transcription request is accepted."""
    note_id: str  # note the transcript will be attached to
    status: str   # processing status indicator
    duration: Optional[float] = None  # audio duration in seconds, if known
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/note_client.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import httpx
|
| 2 |
+
from app.config.settings import NOTE_SERVICE_URL, HTTPX_TIMEOUT
|
| 3 |
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
|
| 4 |
+
|
| 5 |
+
class NoteServiceClient:
    """Thin async HTTP client for the Note Service transcript endpoint."""

    def __init__(self, base_url: str = None):
        # Strip a trailing slash so URL joining below stays unambiguous.
        self.base_url = (base_url or NOTE_SERVICE_URL).rstrip("/")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=8),
        reraise=True,
        # Retry only transient failures: network errors and 5xx responses.
        # 4xx responses fail immediately (the predicate returns False).
        retry=retry_if_exception(
            lambda e: (
                isinstance(e, httpx.RequestError) or
                (isinstance(e, httpx.HTTPStatusError) and 500 <= e.response.status_code < 600)
            )
        )
    )
    async def save_transcript(self, note_id: str, raw_text: str, normalized_text: str,
                              duration: float, sample_rate: int, chunks: list,
                              asr_model: str = "PhoWhisper-base",
                              normalization_model: str = "gemini-1.5"):
        """POST a transcript to the Note Service and return the parsed JSON body.

        Raises:
            httpx.HTTPStatusError: On any non-2xx response (5xx are retried
                by the decorator above before the final raise).
            httpx.RequestError: On network failures (also retried).
        """
        url = f"{self.base_url}/notes/{note_id}/transcript"
        payload = {
            "raw_text": raw_text,
            "normalized_text": normalized_text,
            "duration": duration,
            "sample_rate": sample_rate,
            "chunks": chunks,
            "asr_model": asr_model,
            "normalization_model": normalization_model
        }
        timeout = httpx.Timeout(HTTPX_TIMEOUT)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, json=payload)
            # raise_for_status() raises HTTPStatusError; the retry predicate
            # decides whether the attempt is repeated. The previous
            # try/except here re-raised unconditionally in every branch, so
            # it added no behavior and has been removed.
            resp.raise_for_status()
            return resp.json()
|
app/services/text_normalizer.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.infra.redis_client import redis_client
|
| 2 |
+
from app.utils.hashing import sha256
|
| 3 |
+
|
| 4 |
+
CACHE_TTL = 60 * 60 * 24 * 3 # 3 days
|
| 5 |
+
|
| 6 |
+
# Simple in-memory cache (cΓ³ thα» thay bαΊ±ng Redis, v.v. sau nΓ y)
|
| 7 |
+
# _normalize_cache = {}
|
| 8 |
+
|
| 9 |
+
# --- Gemini client thα»±c tαΊΏ (Google GenerativeAI) ---
|
| 10 |
+
|
| 11 |
+
import google.generativeai as genai
|
| 12 |
+
from app.config.settings import GEMINI_API_KEY
|
| 13 |
+
|
| 14 |
+
if GEMINI_API_KEY:
|
| 15 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
| 16 |
+
_gemini_model = genai.GenerativeModel("gemini-pro")
|
| 17 |
+
else:
|
| 18 |
+
_gemini_model = None
|
| 19 |
+
|
| 20 |
+
async def normalize_text(raw_text: str) -> str:
    """Normalize a Vietnamese transcript with Gemini, caching results in Redis.

    Looks up a SHA-256-keyed cache entry first; on a miss, asks Gemini to fix
    spelling/punctuation without altering content, stores the result for
    CACHE_TTL seconds, and returns it. Falls back to the stripped raw text
    when no Gemini API key is configured.

    Args:
        raw_text: The raw ASR transcript.

    Returns:
        The normalized (or cached, or pass-through) transcript text.
    """
    import asyncio

    cache_key = f"normalize:{sha256(raw_text)}"
    cached = redis_client.get(cache_key)
    if cached:
        return cached

    prompt = f"""
    BαΊ‘n lΓ  hệ thα»‘ng chuαΊ©n hΓ³a transcript tiαΊΏng Việt.
    - KHΓ”NG thΓͺm Γ½ mα»›i
    - Giữ nguyΓͺn nα»™i dung
    - Chỉ sα»­a chΓ­nh tαΊ£, dαΊ₯u cΓ’u, xuα»‘ng dΓ²ng hợp lΓ½

    VΔƒn bαΊ£n:
    {raw_text}
    """
    result = raw_text
    if _gemini_model:
        # The google-generativeai client is synchronous; run it in the default
        # executor so the event loop is not blocked during the API call.
        # get_running_loop() replaces the deprecated get_event_loop() inside
        # coroutines (deprecated since Python 3.10).
        loop = asyncio.get_running_loop()

        def call_gemini():
            response = _gemini_model.generate_content(prompt)
            return response.text.strip() if hasattr(response, 'text') else str(response)

        result = await loop.run_in_executor(None, call_gemini)
    # Without a configured Gemini model, `result` is still the raw text.
    result = result.strip()
    redis_client.setex(cache_key, CACHE_TTL, result)
    return result
|
app/utils/hashing.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Hashing utilities for cache keys, helpers
|
| 3 |
+
|
| 4 |
+
import hashlib
|
| 5 |
+
|
| 6 |
+
def sha256(text: str) -> str:
    """Return the hex SHA-256 digest of *text* (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()
|
requirements.txt
CHANGED
|
@@ -5,5 +5,11 @@ torch
|
|
| 5 |
soundfile
|
| 6 |
python-multipart
|
| 7 |
requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
numpy
|
| 9 |
pytest
|
|
|
|
| 5 |
soundfile
|
| 6 |
python-multipart
|
| 7 |
requests
|
| 8 |
+
httpx
|
| 9 |
+
redis
|
| 10 |
+
rq
|
| 11 |
+
tenacity
|
| 12 |
+
prometheus-client
|
| 13 |
+
google-generativeai
|
| 14 |
numpy
|
| 15 |
pytest
|
test/conftest.py
CHANGED
|
@@ -1,57 +1,11 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import soundfile as sf
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import os
|
| 5 |
-
import tempfile
|
| 6 |
-
import re
|
| 7 |
import pytest
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
def
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def generate_silence_wav(path: str, duration_s: float, sr: int = 16000):
|
| 17 |
-
import numpy as np
|
| 18 |
-
data = np.zeros(int(sr * duration_s), dtype=float)
|
| 19 |
-
sf.write(path, data, sr, subtype='PCM_16')
|
| 20 |
-
|
| 21 |
-
@pytest.fixture
|
| 22 |
-
def tmp_audio_dir(tmp_path):
|
| 23 |
-
d = tmp_path / "audio"
|
| 24 |
-
d.mkdir()
|
| 25 |
-
return d
|
| 26 |
-
|
| 27 |
-
# Fake model: examine file path, if path contains ".chunk{N}." then return predictable text.
|
| 28 |
-
# Example chunk file name "...chunk3.wav" -> "chunk3_text"
|
| 29 |
-
def fake_model_from_path(path: str):
|
| 30 |
-
name = str(path)
|
| 31 |
-
m = re.search(r"chunk(\d+)", name)
|
| 32 |
-
if m:
|
| 33 |
-
i = int(m.group(1))
|
| 34 |
-
# produce a text designed to create overlaps:
|
| 35 |
-
# chunk0 -> "alpha beta gamma"
|
| 36 |
-
# chunk1 -> "beta gamma delta"
|
| 37 |
-
# chunk2 -> "gamma delta epsilon"
|
| 38 |
-
# general pattern derived from i
|
| 39 |
-
words = [
|
| 40 |
-
["alpha", "beta", "gamma"],
|
| 41 |
-
["beta", "gamma", "delta"],
|
| 42 |
-
["gamma", "delta", "epsilon"],
|
| 43 |
-
["delta", "epsilon", "zeta"],
|
| 44 |
-
]
|
| 45 |
-
w = words[i % len(words)]
|
| 46 |
-
return {"text": " ".join(w)}
|
| 47 |
-
# fallback: return a simple marker based on filename
|
| 48 |
-
return {"text": Path(name).stem}
|
| 49 |
-
|
| 50 |
-
# A fake pipeline object: callable that takes a file path and returns dict {"text": ...}
|
| 51 |
-
class FakePipeline:
|
| 52 |
-
def __call__(self, path):
|
| 53 |
-
return fake_model_from_path(path)
|
| 54 |
-
|
| 55 |
-
@pytest.fixture
|
| 56 |
-
def fake_pipeline():
|
| 57 |
-
return FakePipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pytest
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
+
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
    """Autouse fixture: give every test safe default environment variables."""
    env_defaults = {
        "TMP_DIR": tempfile.gettempdir(),
        "MAX_UPLOAD_BYTES": "1048576",
        "MAX_DURATION_SECS": "3600",
        "NOTE_SERVICE_URL": "http://note",
        "REDIS_URL": "redis://localhost:6379/0",
    }
    for key, value in env_defaults.items():
        monkeypatch.setenv(key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_long_performance.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import time
|
| 3 |
-
import pytest
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from conftest import generate_sine_wav
|
| 6 |
-
from app.model import transcribe_file
|
| 7 |
-
|
| 8 |
-
@pytest.mark.skipif(os.getenv("RUN_LONG_TESTS", "0") != "1", reason="long tests disabled by default")
|
| 9 |
-
def test_very_long_audio_runtime(tmp_path, fake_pipeline):
|
| 10 |
-
# create 10min (600s) wav β heavy; enabled only if RUN_LONG_TESTS=1
|
| 11 |
-
p = tmp_path / "very_long.wav"
|
| 12 |
-
generate_sine_wav(str(p), duration_s=600.0) # 10 minutes
|
| 13 |
-
# measure time per minute
|
| 14 |
-
start = time.time()
|
| 15 |
-
text = transcribe_file(fake_pipeline, str(p), max_chunk_length=30.0, overlap_s=5.0)
|
| 16 |
-
elapsed = time.time() - start
|
| 17 |
-
# compute approx seconds per minute of audio processed
|
| 18 |
-
avg_sec_per_min = elapsed / (600.0 / 60.0)
|
| 19 |
-
print(f"Elapsed {elapsed:.2f}s; avg seconds per audio-minute: {avg_sec_per_min:.2f}")
|
| 20 |
-
# Basic assert: completed and returned a string
|
| 21 |
-
assert isinstance(text, str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_short_and_chunk.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import pytest
|
| 4 |
-
from app.model import transcribe_file, _split_audio_to_chunks, _merge_transcripts
|
| 5 |
-
from conftest import generate_sine_wav
|
| 6 |
-
|
| 7 |
-
def test_short_audio_direct(tmp_path, fake_pipeline):
|
| 8 |
-
# create short wav: 5s
|
| 9 |
-
p = tmp_path / "short.wav"
|
| 10 |
-
generate_sine_wav(str(p), duration_s=5.0)
|
| 11 |
-
# call transcribe_file with chunk threshold 30s -> should not chunk
|
| 12 |
-
text = transcribe_file(fake_pipeline, str(p), max_chunk_length=30.0, overlap_s=5.0)
|
| 13 |
-
# fake pipeline returns filename-stem as text for non-chunk files
|
| 14 |
-
assert "short" in text or len(text) > 0
|
| 15 |
-
|
| 16 |
-
def test_chunk_split_and_merge(tmp_path):
|
| 17 |
-
# Create audio ~75s to force chunking into 3 chunks with (L=30, O=5) -> starts 0,25,50
|
| 18 |
-
p = tmp_path / "long75.wav"
|
| 19 |
-
# 75s sine; note: generating >60s may be heavy; for CI shorten if needed
|
| 20 |
-
from conftest import generate_sine_wav
|
| 21 |
-
generate_sine_wav(str(p), duration_s=75.0)
|
| 22 |
-
# Use internal split function to inspect chunking
|
| 23 |
-
chunks = _split_audio_to_chunks(str(p), chunk_length_s=30.0, overlap_s=5.0)
|
| 24 |
-
# Expect at least 3 chunks
|
| 25 |
-
assert len(chunks) >= 3
|
| 26 |
-
# Simulate pipeline outputs for each chunk like in fake_model_from_path
|
| 27 |
-
simulated_texts = []
|
| 28 |
-
for idx, c in enumerate(chunks):
|
| 29 |
-
# derive same pattern as fake_model_from_path: chunk{i} -> list words
|
| 30 |
-
# We just test merging behavior: create overlapping words
|
| 31 |
-
if idx == 0:
|
| 32 |
-
simulated_texts.append("alpha beta gamma")
|
| 33 |
-
elif idx == 1:
|
| 34 |
-
simulated_texts.append("beta gamma delta")
|
| 35 |
-
elif idx == 2:
|
| 36 |
-
simulated_texts.append("gamma delta epsilon")
|
| 37 |
-
else:
|
| 38 |
-
simulated_texts.append(f"chunk{idx}")
|
| 39 |
-
# Merge one by one
|
| 40 |
-
merged = ""
|
| 41 |
-
for t in simulated_texts:
|
| 42 |
-
merged = _merge_transcripts(merged, t, max_overlap_words=5)
|
| 43 |
-
# After merging, no duplicate immediate sequences like "beta gamma beta gamma"
|
| 44 |
-
assert "beta gamma beta" not in merged
|
| 45 |
-
# Ensure merged contains parts from all chunks
|
| 46 |
-
assert "alpha" in merged and "epsilon" in merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_silence_and_overlap.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
from app.model import _merge_transcripts
|
| 2 |
-
import pytest
|
| 3 |
-
|
| 4 |
-
def test_silence_edge_overlap():
|
| 5 |
-
# Simulate chunk outputs where chunk1 ends with filler repeated and chunk2 starts with filler
|
| 6 |
-
a = "hello um um"
|
| 7 |
-
b = "um um good morning"
|
| 8 |
-
merged = _merge_transcripts(a, b, max_overlap_words=4)
|
| 9 |
-
# Should not create triple 'um' (heuristic will remove overlap)
|
| 10 |
-
assert "um um um" not in merged
|
| 11 |
-
# Should still contain core words
|
| 12 |
-
assert "hello" in merged and "good" in merged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|