bichnhan2701 committed
Commit: f84ed9c
Parent(s): efea087

Change logic pipeline
.dockerignore CHANGED
@@ -1,2 +1,10 @@
- tests/
- *.md

+ test/
+ *.md
+ .myvenv
+ __pycache__
+ *.pyc
+ .DS_Store
+ .git
+ .vscode
+ .idea
+ docker-compose.yml
.gitignore CHANGED
@@ -1,6 +1,7 @@
- tests/
  .myvenv/
  __pycache__/
  *.pyc
  .env
- *.md

+ test/
  .myvenv/
  __pycache__/
  *.pyc
  .env
+ *.md
+ docker-compose.yml
Dockerfile CHANGED
@@ -1,22 +1,34 @@
- FROM python:3.10-slim
-
- # install system deps
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
-     ffmpeg libsndfile1 git build-essential wget && \
      rm -rf /var/lib/apt/lists/*

  WORKDIR /app

- # copy requirements first for layer caching
  COPY requirements.txt /app/requirements.txt
- RUN pip install --upgrade pip
- RUN pip install --no-cache-dir -r /app/requirements.txt

  # copy app code
  COPY . /app

- ENV PORT=7860
- EXPOSE 7860

- # default command
- CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]

+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     TMP_DIR=/tmp/uploads \
+     PORT=7860
+
+ # system deps (single RUN to minimize layers)
  RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+     ffmpeg libsndfile1 git build-essential wget curl && \
      rm -rf /var/lib/apt/lists/*

  WORKDIR /app

+ # install python deps using cached layer
  COPY requirements.txt /app/requirements.txt
+ RUN pip install --upgrade pip && \
+     pip install --no-cache-dir -r /app/requirements.txt

  # copy app code
  COPY . /app

+ # create tmp dir and non-root user
+ RUN mkdir -p ${TMP_DIR} && groupadd -r app && useradd -r -g app app && \
+     chown -R app:app /app ${TMP_DIR}
+
+ USER app
+
+ EXPOSE ${PORT}
+
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s \
+     CMD curl -f http://localhost:${PORT}/health || exit 1
+
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app/api/__init__.py ADDED
File without changes
app/api/transcribe.py ADDED
@@ -0,0 +1,191 @@
+ import os
+ import logging
+ import uuid
+ import asyncio
+ from fastapi import APIRouter, UploadFile, File, HTTPException, status
+ from fastapi.responses import JSONResponse
+ from pathlib import Path
+ from typing import Optional
+ import time
+ from app.core.audio_utils import save_upload_file, get_audio_info, ensure_wav_16k_mono, make_temp_path, download_file_from_url
+ from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
+ from app.config import settings
+ from app.services.text_normalizer import normalize_text
+ from app.services.note_client import NoteServiceClient
+ from rq import Queue
+ from app.infra.redis_client import redis_client
+ from app.jobs.transcribe_job import transcribe_job
+ from app.schemas.transcribe import TranscribeResponse
+ from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION, NORMALIZE_DURATION, ERROR_COUNT
+
+ router = APIRouter()
+
+ # load model on import/startup to avoid repeated initialization
+ # you may prefer to call load_model in FastAPI startup event
+ ASR_MODEL = None
+
+ @router.on_event("startup")
+ async def _startup():
+     global ASR_MODEL
+     # load model in thread to avoid blocking event loop
+     ASR_MODEL = await asyncio.to_thread(load_model, 30)
+
+ def _ensure_file_limits(path: str):
+     if os.path.getsize(path) > settings.MAX_UPLOAD_BYTES:
+         raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="File size exceeds limit")
+     info = get_audio_info(path)
+     if info and info.get("duration", 0) > settings.MAX_DURATION_SECS:
+         raise HTTPException(status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail="Audio duration exceeds limit")
+
+ @router.post("/transcribe", response_model=TranscribeResponse)
+ async def transcribe(file: UploadFile = File(...)):
+     tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".wav")
+     tmp_wav = None
+     note_service = NoteServiceClient()
+     note_id = str(uuid.uuid4())
+     start_time = time.perf_counter()
+     endpoint = "/transcribe"
+     status_label = "success"
+     with REQUEST_LATENCY.labels(endpoint).time():
+         try:
+             # write upload to tmp (blocking) -> run in thread
+             await asyncio.to_thread(save_upload_file, file, tmp_in)
+
+             _ensure_file_limits(tmp_in)
+
+             tmp_wav = make_temp_path(suffix=".wav")
+             # ffmpeg convert is blocking -> run in thread
+             await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
+
+             # check the duration to decide between sync and async handling
+             info = get_audio_info(tmp_wav) or {}
+             duration_sec = info.get("duration", 0)
+             ASYNC_THRESHOLD = 120  # 2 minutes, adjustable
+             if duration_sec > ASYNC_THRESHOLD:
+                 # enqueue a background job via RQ
+                 q = Queue("asr", connection=redis_client)
+                 job = q.enqueue(
+                     transcribe_job,
+                     tmp_wav,
+                     note_id,
+                     job_timeout=1800
+                 )
+                 logging.info(f"Enqueued background transcribe job: note_id={note_id} job_id={job.id} duration={duration_sec:.1f}s")
+                 REQUEST_COUNT.labels(endpoint, "queued").inc()
+                 return JSONResponse(status_code=202, content={
+                     "note_id": note_id,
+                     "job_id": job.id,
+                     "status": "queued",
+                     "duration": duration_sec
+                 })
+
+             # short audio: process synchronously as before
+             model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
+             with ASR_DURATION.labels(endpoint).time():
+                 text = await asyncio.to_thread(transcribe_file, model, tmp_wav, 30.0, 5.0)
+                 chunks = await asyncio.to_thread(transcribe_file_chunks, model, tmp_wav, 30.0, 5.0)
+
+             # normalize via Gemini (already async safe in your service)
+             with NORMALIZE_DURATION.labels(endpoint).time():
+                 normalized_text = await normalize_text(text)
+
+             info2 = get_audio_info(tmp_wav) or {}
+             # persist to Note Service (async HTTP)
+             await note_service.save_transcript(
+                 note_id=note_id,
+                 raw_text=text,
+                 normalized_text=normalized_text,
+                 duration=info2.get("duration"),
+                 sample_rate=info2.get("samplerate"),
+                 chunks=chunks,
+                 asr_model="PhoWhisper-base",
+                 normalization_model="gemini-1.5"
+             )
+
+             duration = time.perf_counter() - start_time
+             logging.info(f"/transcribe success note_id={note_id} duration={duration:.2f}s audio_dur={info2.get('duration')}")
+             REQUEST_COUNT.labels(endpoint, status_label).inc()
+             return JSONResponse(status_code=200, content={
+                 "note_id": note_id,
+                 "status": "transcribed",
+                 "duration": info2.get("duration")
+             })
+         except HTTPException:
+             status_label = "http_error"
+             ERROR_COUNT.labels(endpoint, status_label).inc()
+             raise
+         except Exception as e:
+             status_label = "error"
+             ERROR_COUNT.labels(endpoint, status_label).inc()
+             logging.exception(f"/transcribe failed note_id={note_id}")
+             raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
+         finally:
+             # cleanup
+             for p in [tmp_in, tmp_wav]:
+                 try:
+                     if p and os.path.exists(p):
+                         os.remove(p)
+                 except Exception:
+                     pass
+
+
+ @router.post("/transcribe-url", response_model=TranscribeResponse)
+ async def transcribe_url(payload: dict):
+     audio_url = payload.get("audio_url")
+     user_id = payload.get("user_id")
+     if not audio_url:
+         raise HTTPException(status_code=400, detail="audio_url required")
+     if not user_id:
+         raise HTTPException(status_code=400, detail="user_id required")
+
+     tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
+     tmp_wav = None
+     note_service = NoteServiceClient()
+     note_id = str(uuid.uuid4())
+
+     start_time = time.perf_counter()
+     try:
+         # download blocking -> thread
+         await asyncio.to_thread(download_file_from_url, audio_url, tmp_in)
+
+         _ensure_file_limits(tmp_in)
+
+         tmp_wav = make_temp_path(suffix=".wav")
+         await asyncio.to_thread(ensure_wav_16k_mono, tmp_in, tmp_wav)
+
+         model = ASR_MODEL or await asyncio.to_thread(load_model, 30)
+         text = await asyncio.to_thread(transcribe_file, model, tmp_wav, 30.0, 5.0)
+         chunks = await asyncio.to_thread(transcribe_file_chunks, model, tmp_wav, 30.0, 5.0)
+         normalized_text = await normalize_text(text)
+         info2 = get_audio_info(tmp_wav) or {}
+
+         await note_service.save_transcript(
+             note_id=note_id,
+             raw_text=text,
+             normalized_text=normalized_text,
+             duration=info2.get("duration"),
+             sample_rate=info2.get("samplerate"),
+             chunks=chunks,
+             asr_model="PhoWhisper-base",
+             normalization_model="gemini-1.5"
+         )
+
+         duration = time.perf_counter() - start_time
+         logging.info(f"/transcribe-url success note_id={note_id} duration={duration:.2f}s audio_dur={info2.get('duration')}")
+         return JSONResponse(status_code=200, content={
+             "note_id": note_id,
+             "status": "transcribed",
+             "duration": info2.get("duration")
+         })
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.exception(f"/transcribe-url failed note_id={note_id}")
+         raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
+     finally:
+         for p in [tmp_in, tmp_wav]:
+             try:
+                 if p and os.path.exists(p):
+                     os.remove(p)
+             except Exception:
+                 pass
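For reference, a minimal client sketch against the new endpoint (an assumption, not part of the commit: it presumes the service is reachable on localhost:7860 and a local sample.wav exists; the 202 branch corresponds to the RQ path above):

```python
# hedged sketch: upload a file to /transcribe
import httpx

with open("sample.wav", "rb") as f:
    resp = httpx.post(
        "http://localhost:7860/transcribe",
        files={"file": ("sample.wav", f, "audio/wav")},
        timeout=120.0,
    )
# 200 -> transcribed synchronously; 202 -> enqueued, response carries job_id
print(resp.status_code, resp.json())
```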
app/config/__init__.py ADDED
File without changes
app/{config.py → config/settings.py} RENAMED
@@ -1,3 +1,5 @@
  import os

  # Limits & model setting
@@ -12,3 +14,14 @@ os.makedirs(TMP_DIR, exist_ok=True)
  # Cloud credentials (set as HF Spaces secrets or env)
  # FIREBASE_SERVICE_ACCOUNT = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON") # optional
  # CLOUDINARY_URL = os.getenv("CLOUDINARY_URL") # optional

+ # App settings and configuration
+
  import os

  # Limits & model setting
  # Cloud credentials (set as HF Spaces secrets or env)
  # FIREBASE_SERVICE_ACCOUNT = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON") # optional
  # CLOUDINARY_URL = os.getenv("CLOUDINARY_URL") # optional
+
+ # Gemini API Key (for text normalization)
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+ # External services
+ NOTE_SERVICE_URL = os.getenv("NOTE_SERVICE_URL", "http://localhost:9000")
+
+ # HTTP timeouts
+ HTTPX_TIMEOUT = float(os.getenv("HTTPX_TIMEOUT", "10.0"))
+
+ # Redis URL
+ REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
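Because these values are read once at import time, the environment must be set before app.config.settings is first imported. A sketch (the URLs and key are placeholders, not values from the commit):

```python
# hypothetical local configuration; real deployments set these as container env vars
import os
os.environ.setdefault("GEMINI_API_KEY", "your-key-here")
os.environ.setdefault("NOTE_SERVICE_URL", "http://note-service:9000")
os.environ.setdefault("REDIS_URL", "redis://redis:6379/0")

from app.config import settings  # values are frozen at import
print(settings.NOTE_SERVICE_URL)
```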
app/core/__init__.py ADDED
File without changes
app/{model.py → core/asr_engine.py} RENAMED
@@ -1,21 +1,14 @@
- import math
- import shlex
- import subprocess
- import os
- from pathlib import Path
- from typing import List
- from .audio_utils import make_temp_path, ensure_wav_16k_mono, get_audio_info
- from transformers import pipeline
  import logging
- from .config import MODEL_NAME

  _model = None

  def load_model(chunk_length_s: int = None):
      global _model
      if _model is None:
-         # This will download weights at runtime (from Hugging Face Hub).
-         # If you are on HF Spaces, the image will download on first run.
          logging.info(f"Loading ASR model {MODEL_NAME} ...")
          kwargs = {}
          if chunk_length_s is not None:
@@ -24,57 +17,13 @@ def load_model(chunk_length_s: int = None):
      logging.info("Model loaded")
      return _model

- # _ffmpeg_extract_segment uses ffmpeg to render each chunk as WAV 16k mono PCM16 (the format the model expects).
- def _ffmpeg_extract_segment(src: str, start: float, duration: float, dst: str):
-     """
-     Extract segment [start, start+duration) using ffmpeg into dst (wav 16k mono pcm16).
-     We call ffmpeg with -ss + -t on input for safety.
-     """
-     # Note: -ss before -i is fast seek but less accurate for some formats. We use -ss after -i for accuracy.
-     cmd = f'ffmpeg -v error -y -ss {start:.3f} -i "{src}" -t {duration:.3f} -ar 16000 -ac 1 -acodec pcm_s16le "{dst}"'
-     proc = subprocess.run(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-     if proc.returncode != 0:
-         raise RuntimeError(f"ffmpeg extract failed: {proc.stderr.decode(errors='ignore')}")
-     return dst
-
- # _split_audio_to_chunks builds the list of chunk files.
- def _split_audio_to_chunks(src_wav: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0) -> List[str]:
-     """
-     Split src_wav into chunk files (wav 16k mono). Returns list of chunk file paths in order.
-     Overlap is seconds overlapping between consecutive chunks.
-     """
-     info = get_audio_info(src_wav)
-     if not info:
-         raise RuntimeError("Cannot read audio info")
-     duration = info["duration"]
-     step = chunk_length_s - overlap_s
-     if step <= 0:
-         raise ValueError("chunk_length_s must be > overlap_s")
-     starts = []
-     t = 0.0
-     while t < duration:
-         starts.append(t)
-         t += step
-     # ensure last chunk covers end - if last start + chunk_length < duration, we still create chunk that may be shorter
-     chunks = []
-     for i, s in enumerate(starts):
-         dst = make_temp_path(suffix=f".chunk{i}.wav")
-         _ffmpeg_extract_segment(src_wav, s, chunk_length_s, dst)
-         chunks.append(dst)
-     return chunks
-
- # _merge_transcripts is a word-level matching heuristic (exact n-gram match) with a fallback; not perfect, but it removes most overlap-induced repetition.
- def _merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8) -> str:
-     """
-     Heuristic merge: if end of prev_text and start of new_text share overlapping words, remove overlap.
-     We find the longest overlap up to max_overlap_words.
-     """
      if not prev_text:
          return new_text
      p_words = prev_text.strip().split()
      n_words = new_text.strip().split()
      max_ol = min(max_overlap_words, len(p_words), len(n_words))
-     # search for largest k where last k words of prev == first k words of new
      best_k = 0
      for k in range(max_ol, 0, -1):
          if p_words[-k:] == n_words[:k]:
@@ -83,35 +32,22 @@ def _merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8
      if best_k > 0:
          merged = " ".join(p_words + n_words[best_k:])
          return merged
-     # If no exact overlap, try fuzzy overlap by matching word sequences (less strict)
-     # simple heuristic: if last N words of prev appear anywhere at start of new, trim
      for k in range(max_ol, 1, -1):
          seq = " ".join(p_words[-k:])
          if seq in new_text:
              idx = new_text.find(seq)
-             # remove through the sequence
              merged = " ".join(p_words + new_text[idx + len(seq):].strip().split())
              return merged
-     # fallback: just concatenate with space
      return prev_text.rstrip() + " " + new_text.lstrip()

- # transcribe_long_audio runs the model on each chunk (sequential by default); parallel=True enables concurrent processing (be careful with GPU memory).
  def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, parallel: bool = False) -> str:
-     """
-     Transcribe a long wav file by splitting into overlapping chunks,
-     transcribing each, then merging transcripts.
-     - model: loaded pipeline
-     - wav_path: already normalized wav (16k, mono)
-     Returns stitched transcript string.
-     """
-     # 1) split into chunks
-     chunks = _split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
      logging.info(f"Split into {len(chunks)} chunks")
      texts = []
-
-     # 2) process chunks (sequential for safety). Optionally implement limited parallelism.
      if parallel:
-         # optional: implement ThreadPool/ProcessPool with care for GPU memory; by default we do sequential
          from concurrent.futures import ThreadPoolExecutor, as_completed
          def process_chunk(path):
              try:
@@ -122,7 +58,7 @@ def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, ov
              except Exception as e:
                  logging.exception("Chunk inference failed")
                  return ""
-         with ThreadPoolExecutor(max_workers=2) as ex:  # limit concurrency
              futures = {ex.submit(process_chunk, c): c for c in chunks}
              for fut in as_completed(futures):
                  texts.append(fut.result() or "")
@@ -133,12 +69,9 @@
              texts.append(out.get("text", "") or "")
          else:
              texts.append(str(out) or "")
-
-     # 3) merge texts to single transcript, removing overlap duplicates
      merged = ""
      for t in texts:
-         merged = _merge_transcripts(merged, t, max_overlap_words=12)
-     # 4) cleanup chunk files
      for c in chunks:
          try:
              os.remove(c)
@@ -146,31 +79,24 @@ def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, ov
          pass
      return merged

- # transcribe_file decides whether to chunk or not.
  def transcribe_file(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
-     """
-     Main entry: if audio short, run single pass; if long, call chunked transcribe.
-     """
      info = get_audio_info(wav_path) or {}
      duration = info.get("duration", 0.0)
-     # threshold to decide chunking: if duration > chunk_length -> chunk
-     if duration and duration > max_chunk_length * 1.1:  # slightly bigger than chunk length
          logging.info(f"Long audio detected ({duration}s) -> chunking")
          return transcribe_long_audio(model, wav_path, chunk_length_s=max_chunk_length, overlap_s=overlap_s)
-     # short audio -> direct
      out = model(wav_path)
      if isinstance(out, dict):
          return out.get("text") or ""
      return str(out)

- # returns a list of dicts with start, end, text for each chunk
  def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
-     """
-     Split the audio into chunks, transcribe each one, and return a list of dicts: {start, end, text}
-     """
      info = get_audio_info(wav_path) or {}
      duration = info.get("duration", 0.0)
-     # compute the start timestamp for each chunk
      step = max_chunk_length - overlap_s
      if step <= 0:
          raise ValueError("max_chunk_length must be > overlap_s")
@@ -183,7 +109,7 @@ def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0,
      for i, s in enumerate(starts):
          chunk_end = min(s + max_chunk_length, duration)
          dst = make_temp_path(suffix=f".chunk{i}.wav")
-         _ffmpeg_extract_segment(wav_path, s, chunk_end - s, dst)
          out = model(dst)
          if isinstance(out, dict):
              text = out.get("text", "")

+ # PhoWhisper inference engine
+
  import logging
+ from transformers import pipeline
+ from app.config.settings import MODEL_NAME

  _model = None

  def load_model(chunk_length_s: int = None):
      global _model
      if _model is None:
          logging.info(f"Loading ASR model {MODEL_NAME} ...")
          kwargs = {}
          if chunk_length_s is not None:
          logging.info("Model loaded")
      return _model

+ # Heuristic merge for chunked transcripts
+ def merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8) -> str:
      if not prev_text:
          return new_text
      p_words = prev_text.strip().split()
      n_words = new_text.strip().split()
      max_ol = min(max_overlap_words, len(p_words), len(n_words))
      best_k = 0
      for k in range(max_ol, 0, -1):
          if p_words[-k:] == n_words[:k]:
      if best_k > 0:
          merged = " ".join(p_words + n_words[best_k:])
          return merged
      for k in range(max_ol, 1, -1):
          seq = " ".join(p_words[-k:])
          if seq in new_text:
              idx = new_text.find(seq)
              merged = " ".join(p_words + new_text[idx + len(seq):].strip().split())
              return merged
      return prev_text.rstrip() + " " + new_text.lstrip()

  def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, parallel: bool = False) -> str:
+     from app.core.chunking import split_audio_to_chunks
+     from app.core.audio_utils import make_temp_path
+     import os
+     chunks = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
      logging.info(f"Split into {len(chunks)} chunks")
      texts = []
      if parallel:
          from concurrent.futures import ThreadPoolExecutor, as_completed
          def process_chunk(path):
              try:
              except Exception as e:
                  logging.exception("Chunk inference failed")
                  return ""
+         with ThreadPoolExecutor(max_workers=2) as ex:
              futures = {ex.submit(process_chunk, c): c for c in chunks}
              for fut in as_completed(futures):
                  texts.append(fut.result() or "")
              texts.append(out.get("text", "") or "")
          else:
              texts.append(str(out) or "")
      merged = ""
      for t in texts:
+         merged = merge_transcripts(merged, t, max_overlap_words=12)
      for c in chunks:
          try:
              os.remove(c)
          pass
      return merged

  def transcribe_file(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
+     from app.core.audio_utils import get_audio_info
      info = get_audio_info(wav_path) or {}
      duration = info.get("duration", 0.0)
+     if duration and duration > max_chunk_length * 1.1:
          logging.info(f"Long audio detected ({duration}s) -> chunking")
          return transcribe_long_audio(model, wav_path, chunk_length_s=max_chunk_length, overlap_s=overlap_s)
      out = model(wav_path)
      if isinstance(out, dict):
          return out.get("text") or ""
      return str(out)

  def transcribe_file_chunks(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
+     from app.core.audio_utils import get_audio_info, make_temp_path
+     from app.core.chunking import ffmpeg_extract_segment
+     import os
      info = get_audio_info(wav_path) or {}
      duration = info.get("duration", 0.0)
      step = max_chunk_length - overlap_s
      if step <= 0:
          raise ValueError("max_chunk_length must be > overlap_s")
      for i, s in enumerate(starts):
          chunk_end = min(s + max_chunk_length, duration)
          dst = make_temp_path(suffix=f".chunk{i}.wav")
+         ffmpeg_extract_segment(wav_path, s, chunk_end - s, dst)
          out = model(dst)
          if isinstance(out, dict):
              text = out.get("text", "")
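The overlap heuristic in merge_transcripts can be sanity-checked in isolation; with a two-word overlap the duplicated words are dropped once (this mirrors the vocabulary of the deleted unit tests):

```python
# expected behavior of the exact n-gram branch
from app.core.asr_engine import merge_transcripts

merged = merge_transcripts("alpha beta gamma", "beta gamma delta", max_overlap_words=8)
print(merged)  # "alpha beta gamma delta" -- the shared "beta gamma" is kept only once
```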
app/{audio_utils.py → core/audio_utils.py} RENAMED
@@ -1,23 +1,23 @@
- import os
  import subprocess
  import shlex
  import uuid
  import requests
  from pathlib import Path
  import soundfile as sf
- from .config import TMP_DIR, MAX_UPLOAD_BYTES, MAX_DURATION_SECS

  def save_upload_file(upload_file, dest_path: str):
-     """Save fastapi UploadFile to dest_path"""
      with open(dest_path, "wb") as f:
          while True:
-             chunk = upload_file.file.read(1024*1024)
              if not chunk:
                  break
              f.write(chunk)

  def download_file_from_url(url: str, dest_path: str, timeout=30):
-     """Download remote file to dest_path."""
      r = requests.get(url, stream=True, timeout=timeout)
      r.raise_for_status()
      total = 0
@@ -30,26 +30,37 @@ def download_file_from_url(url: str, dest_path: str, timeout=30):
          f.write(chunk)

  def get_audio_info(path: str):
-     """Return duration (s) and sample_rate using soundfile."""
      try:
          info = sf.info(path)
          duration = info.frames / info.samplerate
-         return {"duration": duration, "samplerate": info.samplerate, "channels": info.channels}
      except Exception:
          return None

  def ensure_wav_16k_mono(src_path: str, dest_path: str):
      """
-     Use ffmpeg to convert any audio to wav (PCM16), 16kHz, mono.
-     Returns dest_path if ok, raises exception on error.
      """
-     # Using ffmpeg command line
-     # -y overwrite
-     cmd = f'ffmpeg -v error -y -i "{src_path}" -ar 16000 -ac 1 -acodec pcm_s16le "{dest_path}"'
-     proc = subprocess.run(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      if proc.returncode != 0:
-         raise RuntimeError(f"ffmpeg convert failed: {proc.stderr.decode(errors='ignore')}")
      return dest_path

  def make_temp_path(suffix=".wav"):
      return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")

+ # Audio utilities: ffmpeg, normalization, etc.
  import subprocess
  import shlex
  import uuid
  import requests
  from pathlib import Path
  import soundfile as sf
+ from app.config.settings import TMP_DIR, MAX_UPLOAD_BYTES

  def save_upload_file(upload_file, dest_path: str):
+     """Save FastAPI UploadFile to dest_path (streaming)."""
      with open(dest_path, "wb") as f:
          while True:
+             chunk = upload_file.file.read(1024 * 1024)
              if not chunk:
                  break
              f.write(chunk)

  def download_file_from_url(url: str, dest_path: str, timeout=30):
+     """Download remote file to dest_path with size limit."""
      r = requests.get(url, stream=True, timeout=timeout)
      r.raise_for_status()
      total = 0
          f.write(chunk)

  def get_audio_info(path: str):
+     """Return duration (s), sample_rate, channels using soundfile."""
      try:
          info = sf.info(path)
          duration = info.frames / info.samplerate
+         return {
+             "duration": duration,
+             "samplerate": info.samplerate,
+             "channels": info.channels,
+         }
      except Exception:
          return None

  def ensure_wav_16k_mono(src_path: str, dest_path: str):
      """
+     Convert any audio to WAV PCM16, 16kHz, mono using ffmpeg.
      """
+     cmd = (
+         f'ffmpeg -v error -y -i "{src_path}" '
+         f'-ar 16000 -ac 1 -acodec pcm_s16le "{dest_path}"'
+     )
+     proc = subprocess.run(
+         shlex.split(cmd),
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+     )
      if proc.returncode != 0:
+         raise RuntimeError(
+             f"ffmpeg convert failed: {proc.stderr.decode(errors='ignore')}"
+         )
      return dest_path

  def make_temp_path(suffix=".wav"):
+     """Generate unique temp file path under TMP_DIR."""
      return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")
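Typical usage of the conversion helpers, as a hedged sketch (assumes ffmpeg on PATH and a local input file sample.mp3, which is not part of the commit):

```python
# normalize arbitrary audio to the 16 kHz mono WAV the model expects
from app.core.audio_utils import ensure_wav_16k_mono, get_audio_info, make_temp_path

wav = make_temp_path(suffix=".wav")
ensure_wav_16k_mono("sample.mp3", wav)
print(get_audio_info(wav))  # e.g. {'duration': 12.3, 'samplerate': 16000, 'channels': 1}
```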
app/core/chunking.py ADDED
@@ -0,0 +1,36 @@
+ # Audio chunking/splitting/merging logic
+
+ import shlex
+ import subprocess
+ from typing import List
+ from app.core.audio_utils import get_audio_info, make_temp_path
+
+ def ffmpeg_extract_segment(src: str, start: float, duration: float, dst: str):
+     """
+     Extract segment [start, start+duration) using ffmpeg into dst (wav 16k mono pcm16).
+     """
+     cmd = f'ffmpeg -v error -y -ss {start:.3f} -i "{src}" -t {duration:.3f} -ar 16000 -ac 1 -acodec pcm_s16le "{dst}"'
+     proc = subprocess.run(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+     if proc.returncode != 0:
+         raise RuntimeError(f"ffmpeg extract failed: {proc.stderr.decode(errors='ignore')}")
+     return dst
+
+ def split_audio_to_chunks(src_wav: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0) -> List[str]:
+     info = get_audio_info(src_wav)
+     if not info:
+         raise RuntimeError("Cannot read audio info")
+     duration = info["duration"]
+     step = chunk_length_s - overlap_s
+     if step <= 0:
+         raise ValueError("chunk_length_s must be > overlap_s")
+     starts = []
+     t = 0.0
+     while t < duration:
+         starts.append(t)
+         t += step
+     chunks = []
+     for i, s in enumerate(starts):
+         chunk_path = make_temp_path(suffix=f"_chunk{i}.wav")
+         ffmpeg_extract_segment(src_wav, s, min(chunk_length_s, duration - s), chunk_path)
+         chunks.append(chunk_path)
+     return chunks
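A worked example of the chunk timing: with chunk_length_s=30 and overlap_s=5 the step is 25 s, so a 75 s file yields starts at 0, 25, and 50, i.e. three overlapping chunks (the same figures the deleted tests asserted):

```python
# same arithmetic as split_audio_to_chunks, without the ffmpeg calls
duration, chunk_length_s, overlap_s = 75.0, 30.0, 5.0
step = chunk_length_s - overlap_s  # 25.0
starts, t = [], 0.0
while t < duration:
    starts.append(t)
    t += step
print(starts)  # [0.0, 25.0, 50.0] -> chunks [0,30), [25,55), [50,75)
```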
app/infra/metrics.py ADDED
@@ -0,0 +1,32 @@
+ from prometheus_client import Counter, Histogram
+
+ REQUEST_COUNT = Counter(
+     "asr_requests_total",
+     "Total ASR requests",
+     ["endpoint", "status"]
+ )
+
+ REQUEST_LATENCY = Histogram(
+     "asr_request_latency_seconds",
+     "ASR request latency",
+     ["endpoint"]
+ )
+
+ ASR_DURATION = Histogram(
+     "asr_model_duration_seconds",
+     "ASR model inference duration",
+     ["endpoint"]
+ )
+
+ NORMALIZE_DURATION = Histogram(
+     "normalize_duration_seconds",
+     "Text normalization duration",
+     ["endpoint"]
+ )
+
+ ERROR_COUNT = Counter(
+     "asr_error_total",
+     "Total ASR errors",
+     ["endpoint", "error_type"]
+ )
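These metrics are consumed in app/api/transcribe.py; the pattern is the standard prometheus_client label-then-observe idiom, sketched here in isolation:

```python
import time
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY

with REQUEST_LATENCY.labels("/transcribe").time():  # records elapsed seconds
    time.sleep(0.1)  # stand-in for the actual request handling
REQUEST_COUNT.labels("/transcribe", "success").inc()
```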
app/infra/redis_client.py ADDED
@@ -0,0 +1,8 @@
+ import os
+ import redis
+ from app.config.settings import REDIS_URL
+
+ redis_client = redis.Redis.from_url(
+     REDIS_URL,
+     decode_responses=True
+ )
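One caveat worth flagging: decode_responses=True suits the string cache reads in text_normalizer, but RQ expects a binary (non-decoding) connection for its job payloads. A sketch of a separate connection for the queue, under that assumption (this companion connection is not part of the commit):

```python
# hypothetical companion connection for RQ; the decoding client above stays for caching
import redis
from app.config.settings import REDIS_URL

rq_redis = redis.Redis.from_url(REDIS_URL)  # decode_responses defaults to False
```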
app/jobs/transcribe_job.py ADDED
@@ -0,0 +1,27 @@
+ from app.core.asr_engine import load_model, transcribe_file
+ from app.services.text_normalizer import normalize_text
+ from app.services.note_client import NoteServiceClient
+
+ # This function is executed by the RQ worker
+ def transcribe_job(tmp_wav: str, note_id: str):
+     model = load_model()
+     text = transcribe_file(model, tmp_wav, 30.0, 5.0)
+     # normalize_text may be async, but RQ runs jobs synchronously, so drive an event loop when needed
+     import asyncio
+     if asyncio.iscoroutinefunction(normalize_text):
+         normalized = asyncio.run(normalize_text(text))
+     else:
+         normalized = normalize_text(text)
+     note_service = NoteServiceClient()
+     # send the transcript to the Note Service; save_transcript is a coroutine,
+     # so it must likewise be driven to completion here
+     asyncio.run(note_service.save_transcript(
+         note_id=note_id,
+         raw_text=text,
+         normalized_text=normalized,
+         duration=None,
+         sample_rate=None,
+         chunks=None,
+         asr_model="PhoWhisper-base",
+         normalization_model="gemini-1.5"
+     ))
+     return True
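The queued path in /transcribe only completes if a worker is listening on the "asr" queue. A minimal launcher sketch, assuming the binary-connection caveat noted under app/infra/redis_client.py (this file is hypothetical, not part of the commit):

```python
# run_worker.py -- hypothetical entry point
import redis
from rq import Queue, Worker
from app.config.settings import REDIS_URL

conn = redis.Redis.from_url(REDIS_URL)  # RQ wants bytes, not decoded strings
Worker([Queue("asr", connection=conn)], connection=conn).work()
```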
app/main.py CHANGED
@@ -1,16 +1,24 @@
- import os
- import shutil
- from fastapi import FastAPI, UploadFile, File, HTTPException, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from pathlib import Path
  import logging
- from .config import TMP_DIR, MAX_UPLOAD_BYTES, MAX_DURATION_SECS
- from .audio_utils import save_upload_file, download_file_from_url, get_audio_info, ensure_wav_16k_mono, make_temp_path
- from .model import load_model, transcribe_file, transcribe_file_chunks

  app = FastAPI(title="PhoWhisper ASR API")

  # CORS — tighten in prod
  app.add_middleware(
      CORSMiddleware,
@@ -19,99 +27,22 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # load model lazily on first request to avoid startup heavy cost
- MODEL = None

- @app.on_event("startup")
- def on_startup():
-     # ensure tmp dir exists
-     Path(TMP_DIR).mkdir(parents=True, exist_ok=True)
-     # optionally pre-load model (comment if you want lazy)
-     # global MODEL
-     # MODEL = load_model()
-     logging.info("API startup complete")

  @app.get("/health")
  def health():
      return {"status": "ok"}

- @app.post("/transcribe")
- async def transcribe(file: UploadFile = File(...)):
-     # 1. basic size check: UploadFile does not expose size easily, so stream-write and check
-     tmp_in = make_temp_path(suffix=Path(file.filename).suffix or ".wav")
-     try:
-         save_upload_file(file, tmp_in)
-         if os.path.getsize(tmp_in) > MAX_UPLOAD_BYTES:
-             raise HTTPException(status_code=413, detail="File size exceeds limit")
-         info = get_audio_info(tmp_in)
-         if info and info.get("duration") and info["duration"] > MAX_DURATION_SECS:
-             raise HTTPException(status_code=413, detail="Audio duration exceeds limit")
-         # 2. normalize
-         tmp_wav = make_temp_path(suffix=".wav")
-         ensure_wav_16k_mono(tmp_in, tmp_wav)
-         # 3. load model if needed
-         global MODEL
-         if MODEL is None:
-             MODEL = load_model(chunk_length_s=30)
-         text = transcribe_file(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
-         chunks = transcribe_file_chunks(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
-         info2 = get_audio_info(tmp_wav) or {}
-         return JSONResponse({
-             "text": text,
-             "duration": info2.get("duration"),
-             "sample_rate": info2.get("samplerate"),
-             "chunks": chunks
-         })
-     except HTTPException:
-         raise
-     except Exception as e:
-         logging.exception("Transcribe failed")
-         raise HTTPException(status_code=500, detail=str(e))
-     finally:
-         # cleanup temp files
-         for p in [tmp_in, locals().get("tmp_wav")]:
-             try:
-                 if p and os.path.exists(p):
-                     os.remove(p)
-             except Exception:
-                 pass

- @app.post("/transcribe-url")
- async def transcribe_url(payload: dict):
-     audio_url = payload.get("audio_url")
-     if not audio_url:
-         raise HTTPException(status_code=400, detail="audio_url required")
-     tmp_in = make_temp_path(suffix=Path(audio_url).suffix or ".tmp")
-     try:
-         download_file_from_url(audio_url, tmp_in)
-         if os.path.getsize(tmp_in) > MAX_UPLOAD_BYTES:
-             raise HTTPException(status_code=413, detail="File size exceeds limit")
-         info = get_audio_info(tmp_in)
-         if info and info.get("duration") and info["duration"] > MAX_DURATION_SECS:
-             raise HTTPException(status_code=413, detail="Audio duration exceeds limit")
-         tmp_wav = make_temp_path(suffix=".wav")
-         ensure_wav_16k_mono(tmp_in, tmp_wav)
-         global MODEL
-         if MODEL is None:
-             MODEL = load_model(chunk_length_s=30)
-         text = transcribe_file(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
-         chunks = transcribe_file_chunks(MODEL, tmp_wav, max_chunk_length=30.0, overlap_s=5.0)
-         info2 = get_audio_info(tmp_wav) or {}
-         return JSONResponse({
-             "text": text,
-             "duration": info2.get("duration"),
-             "sample_rate": info2.get("samplerate"),
-             "chunks": chunks
-         })
-     except HTTPException:
-         raise
-     except Exception as e:
-         logging.exception("Transcribe-url failed")
-         raise HTTPException(status_code=500, detail=str(e))
-     finally:
-         for p in [tmp_in, locals().get("tmp_wav")]:
-             try:
-                 if p and os.path.exists(p):
-                     os.remove(p)
-             except Exception:
-                 pass

+
+ from fastapi import FastAPI, Response
+ from prometheus_client import generate_latest
+ import asyncio
  import logging
+ from fastapi.middleware.cors import CORSMiddleware
+ from app.api.transcribe import router as transcribe_router
+ from app.core.asr_engine import load_model
+

  app = FastAPI(title="PhoWhisper ASR API")

+ # Preload ASR model at startup
+ @app.on_event("startup")
+ async def preload_asr_model():
+     # Load model in thread to avoid blocking event loop
+     logging.info("Preloading ASR model at startup...")
+     await asyncio.to_thread(load_model, 30)
+     logging.info("ASR model preloaded.")
+
+
  # CORS — tighten in prod
  app.add_middleware(
      CORSMiddleware,
      allow_headers=["*"],
  )

+ # --- OLD LOGIC: moved to app/api/transcribe.py ---
+ # - endpoints were defined directly in this module
+ # - it contained all of the processing logic
+ # - now refactored into a dedicated router with separate core/service layers

+ # Health check (can be kept here if desired)
  @app.get("/health")
  def health():
      return {"status": "ok"}

+ # Expose /metrics endpoint for Prometheus
+ @app.get("/metrics")
+ def metrics():
+     return Response(generate_latest(), media_type="text/plain")
+
+ # Include API routers
+ app.include_router(transcribe_router)
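For local runs outside Docker, the Dockerfile's CMD translates to a one-liner; a sketch:

```python
# equivalent of: uvicorn app.main:app --host 0.0.0.0 --port 7860 --workers 1
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=7860, workers=1)
```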
app/schemas/__init__.py ADDED
File without changes
app/schemas/transcribe.py ADDED
@@ -0,0 +1,14 @@
+ # Request/Response models for transcription
+
+ from pydantic import BaseModel
+ from typing import List, Optional
+
+ class Chunk(BaseModel):
+     start: float
+     end: float
+     text: str
+
+ class TranscribeResponse(BaseModel):
+     note_id: str
+     status: str
+     duration: Optional[float] = None
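The response model shapes the JSON bodies both endpoints return; constructing one directly is a quick sanity check (pydantic v2 shown; on v1 use .dict() instead of .model_dump()):

```python
from app.schemas.transcribe import TranscribeResponse

r = TranscribeResponse(note_id="demo-note", status="queued", duration=134.2)
print(r.model_dump())  # {'note_id': 'demo-note', 'status': 'queued', 'duration': 134.2}
```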
app/services/__init__.py ADDED
File without changes
app/services/note_client.py ADDED
@@ -0,0 +1,48 @@
+ import httpx
+ from app.config.settings import NOTE_SERVICE_URL, HTTPX_TIMEOUT
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
+
+ class NoteServiceClient:
+     def __init__(self, base_url: str = None):
+         self.base_url = (base_url or NOTE_SERVICE_URL).rstrip("/")
+
+     @retry(
+         stop=stop_after_attempt(3),
+         wait=wait_exponential(multiplier=1, min=1, max=8),
+         reraise=True,
+         retry=retry_if_exception(
+             lambda e: (
+                 isinstance(e, httpx.RequestError) or
+                 (isinstance(e, httpx.HTTPStatusError) and 500 <= e.response.status_code < 600)
+             )
+         )
+     )
+     async def save_transcript(self, note_id: str, raw_text: str, normalized_text: str,
+                               duration: float, sample_rate: int, chunks: list,
+                               asr_model: str = "PhoWhisper-base",
+                               normalization_model: str = "gemini-1.5"):
+         url = f"{self.base_url}/notes/{note_id}/transcript"
+         payload = {
+             "raw_text": raw_text,
+             "normalized_text": normalized_text,
+             "duration": duration,
+             "sample_rate": sample_rate,
+             "chunks": chunks,
+             "asr_model": asr_model,
+             "normalization_model": normalization_model
+         }
+         timeout = httpx.Timeout(HTTPX_TIMEOUT)
+         async with httpx.AsyncClient(timeout=timeout) as client:
+             try:
+                 resp = await client.post(url, json=payload)
+                 resp.raise_for_status()
+                 return resp.json()
+             except httpx.HTTPStatusError:
+                 # re-raise; the retry predicate above only retries 5xx responses
+                 raise
+             except httpx.RequestError:
+                 # re-raise; network errors are retried by the predicate
+                 raise
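Calling the client outside FastAPI requires driving an event loop; a hedged sketch (assumes a Note Service listening at NOTE_SERVICE_URL; the sample values are illustrative):

```python
import asyncio
from app.services.note_client import NoteServiceClient

async def main():
    client = NoteServiceClient()
    await client.save_transcript(
        note_id="demo-note",
        raw_text="xin chao cac ban",
        normalized_text="Xin chào các bạn.",
        duration=1.8,
        sample_rate=16000,
        chunks=[{"start": 0.0, "end": 1.8, "text": "xin chao cac ban"}],
    )

asyncio.run(main())
```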
app/services/text_normalizer.py ADDED
@@ -0,0 +1,49 @@
+ from app.infra.redis_client import redis_client
+ from app.utils.hashing import sha256
+
+ CACHE_TTL = 60 * 60 * 24 * 3  # 3 days
+
+ # Simple in-memory cache (could be replaced with Redis etc. later)
+ # _normalize_cache = {}
+
+ # --- Real Gemini client (Google GenerativeAI) ---
+
+ import google.generativeai as genai
+ from app.config.settings import GEMINI_API_KEY
+
+ if GEMINI_API_KEY:
+     genai.configure(api_key=GEMINI_API_KEY)
+     _gemini_model = genai.GenerativeModel("gemini-pro")
+ else:
+     _gemini_model = None
+
+ async def normalize_text(raw_text: str) -> str:
+     cache_key = f"normalize:{sha256(raw_text)}"
+     cached = redis_client.get(cache_key)
+     if cached:
+         return cached
+
+     # Vietnamese prompt; it instructs: "You are a Vietnamese transcript
+     # normalization system. Do NOT add new ideas, keep the content unchanged,
+     # only fix spelling, punctuation, and reasonable line breaks."
+     prompt = f"""
+     Bạn là hệ thống chuẩn hóa transcript tiếng Việt.
+     - KHÔNG thêm ý mới
+     - Giữ nguyên nội dung
+     - Chỉ sửa chính tả, dấu câu, xuống dòng hợp lý
+
+     Văn bản:
+     {raw_text}
+     """
+     result = raw_text
+     if _gemini_model:
+         # Google GenerativeAI Gemini API (synchronous, wrap in thread for async)
+         import asyncio
+         loop = asyncio.get_event_loop()
+         def call_gemini():
+             response = _gemini_model.generate_content(prompt)
+             return response.text.strip() if hasattr(response, 'text') else str(response)
+         result = await loop.run_in_executor(None, call_gemini)
+     else:
+         # Gemini not configured: return the original text
+         result = raw_text
+     result = result.strip()
+     redis_client.setex(cache_key, CACHE_TTL, result)
+     return result
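End to end, the normalizer needs a reachable Redis for the cache and optionally GEMINI_API_KEY; without the key it degrades to returning the input unchanged. A hedged sketch:

```python
# exercises the cache-then-Gemini path (assumes Redis is up per REDIS_URL)
import asyncio
from app.services.text_normalizer import normalize_text

print(asyncio.run(normalize_text("xin chao cac ban hom nay")))
```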
app/utils/hashing.py ADDED
@@ -0,0 +1,7 @@
+
+ # Hashing utilities for cache keys, helpers
+
+ import hashlib
+
+ def sha256(text: str) -> str:
+     return hashlib.sha256(text.encode('utf-8')).hexdigest()
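The helper just wraps hashlib, so cache keys are stable hex digests of the input text:

```python
from app.utils.hashing import sha256

print(sha256("hello"))
# 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824
```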
requirements.txt CHANGED
@@ -5,5 +5,11 @@ torch
  soundfile
  python-multipart
  requests
  numpy
  pytest

  soundfile
  python-multipart
  requests
+ httpx
+ redis
+ rq
+ tenacity
+ prometheus-client
+ google-generativeai
  numpy
  pytest
test/conftest.py CHANGED
@@ -1,57 +1,11 @@
- import numpy as np
- import soundfile as sf
- from pathlib import Path
- import os
- import tempfile
- import re
  import pytest

- # helper to synthesize a sine wave (mono) wav file
- def generate_sine_wav(path: str, duration_s: float, sr: int = 16000, freq: float = 440.0, amplitude: float = 0.2):
-     t = np.linspace(0, duration_s, int(sr * duration_s), endpoint=False)
-     data = amplitude * np.sin(2 * np.pi * freq * t)
-     sf.write(path, data, sr, subtype='PCM_16')
-
- # helper to generate silence (zeros)
- def generate_silence_wav(path: str, duration_s: float, sr: int = 16000):
-     import numpy as np
-     data = np.zeros(int(sr * duration_s), dtype=float)
-     sf.write(path, data, sr, subtype='PCM_16')
-
- @pytest.fixture
- def tmp_audio_dir(tmp_path):
-     d = tmp_path / "audio"
-     d.mkdir()
-     return d
-
- # Fake model: examine file path, if path contains ".chunk{N}." then return predictable text.
- # Example chunk file name "...chunk3.wav" -> "chunk3_text"
- def fake_model_from_path(path: str):
-     name = str(path)
-     m = re.search(r"chunk(\d+)", name)
-     if m:
-         i = int(m.group(1))
-         # produce a text designed to create overlaps:
-         # chunk0 -> "alpha beta gamma"
-         # chunk1 -> "beta gamma delta"
-         # chunk2 -> "gamma delta epsilon"
-         # general pattern derived from i
-         words = [
-             ["alpha", "beta", "gamma"],
-             ["beta", "gamma", "delta"],
-             ["gamma", "delta", "epsilon"],
-             ["delta", "epsilon", "zeta"],
-         ]
-         w = words[i % len(words)]
-         return {"text": " ".join(w)}
-     # fallback: return a simple marker based on filename
-     return {"text": Path(name).stem}
-
- # A fake pipeline object: callable that takes a file path and returns dict {"text": ...}
- class FakePipeline:
-     def __call__(self, path):
-         return fake_model_from_path(path)
-
- @pytest.fixture
- def fake_pipeline():
-     return FakePipeline()

  import pytest
+ import tempfile
+ import os

+ @pytest.fixture(autouse=True)
+ def mock_env(monkeypatch):
+     monkeypatch.setenv("TMP_DIR", tempfile.gettempdir())
+     monkeypatch.setenv("MAX_UPLOAD_BYTES", "1048576")
+     monkeypatch.setenv("MAX_DURATION_SECS", "3600")
+     monkeypatch.setenv("NOTE_SERVICE_URL", "http://note")
+     monkeypatch.setenv("REDIS_URL", "redis://localhost:6379/0")
test/test_long_performance.py DELETED
@@ -1,21 +0,0 @@
- import os
- import time
- import pytest
- from pathlib import Path
- from conftest import generate_sine_wav
- from app.model import transcribe_file
-
- @pytest.mark.skipif(os.getenv("RUN_LONG_TESTS", "0") != "1", reason="long tests disabled by default")
- def test_very_long_audio_runtime(tmp_path, fake_pipeline):
-     # create 10min (600s) wav — heavy; enabled only if RUN_LONG_TESTS=1
-     p = tmp_path / "very_long.wav"
-     generate_sine_wav(str(p), duration_s=600.0)  # 10 minutes
-     # measure time per minute
-     start = time.time()
-     text = transcribe_file(fake_pipeline, str(p), max_chunk_length=30.0, overlap_s=5.0)
-     elapsed = time.time() - start
-     # compute approx seconds per minute of audio processed
-     avg_sec_per_min = elapsed / (600.0 / 60.0)
-     print(f"Elapsed {elapsed:.2f}s; avg seconds per audio-minute: {avg_sec_per_min:.2f}")
-     # Basic assert: completed and returned a string
-     assert isinstance(text, str)
test/test_short_and_chunk.py DELETED
@@ -1,46 +0,0 @@
- import os
- from pathlib import Path
- import pytest
- from app.model import transcribe_file, _split_audio_to_chunks, _merge_transcripts
- from conftest import generate_sine_wav
-
- def test_short_audio_direct(tmp_path, fake_pipeline):
-     # create short wav: 5s
-     p = tmp_path / "short.wav"
-     generate_sine_wav(str(p), duration_s=5.0)
-     # call transcribe_file with chunk threshold 30s -> should not chunk
-     text = transcribe_file(fake_pipeline, str(p), max_chunk_length=30.0, overlap_s=5.0)
-     # fake pipeline returns filename-stem as text for non-chunk files
-     assert "short" in text or len(text) > 0
-
- def test_chunk_split_and_merge(tmp_path):
-     # Create audio ~75s to force chunking into 3 chunks with (L=30, O=5) -> starts 0,25,50
-     p = tmp_path / "long75.wav"
-     # 75s sine; note: generating >60s may be heavy; for CI shorten if needed
-     from conftest import generate_sine_wav
-     generate_sine_wav(str(p), duration_s=75.0)
-     # Use internal split function to inspect chunking
-     chunks = _split_audio_to_chunks(str(p), chunk_length_s=30.0, overlap_s=5.0)
-     # Expect at least 3 chunks
-     assert len(chunks) >= 3
-     # Simulate pipeline outputs for each chunk like in fake_model_from_path
-     simulated_texts = []
-     for idx, c in enumerate(chunks):
-         # derive same pattern as fake_model_from_path: chunk{i} -> list words
-         # We just test merging behavior: create overlapping words
-         if idx == 0:
-             simulated_texts.append("alpha beta gamma")
-         elif idx == 1:
-             simulated_texts.append("beta gamma delta")
-         elif idx == 2:
-             simulated_texts.append("gamma delta epsilon")
-         else:
-             simulated_texts.append(f"chunk{idx}")
-     # Merge one by one
-     merged = ""
-     for t in simulated_texts:
-         merged = _merge_transcripts(merged, t, max_overlap_words=5)
-     # After merging, no duplicate immediate sequences like "beta gamma beta gamma"
-     assert "beta gamma beta" not in merged
-     # Ensure merged contains parts from all chunks
-     assert "alpha" in merged and "epsilon" in merged
test/test_silence_and_overlap.py DELETED
@@ -1,12 +0,0 @@
- from app.model import _merge_transcripts
- import pytest
-
- def test_silence_edge_overlap():
-     # Simulate chunk outputs where chunk1 ends with filler repeated and chunk2 starts with filler
-     a = "hello um um"
-     b = "um um good morning"
-     merged = _merge_transcripts(a, b, max_overlap_words=4)
-     # Should not create triple 'um' (heuristic will remove overlap)
-     assert "um um um" not in merged
-     # Should still contain core words
-     assert "hello" in merged and "good" in merged