bichnhan2701 commited on
Commit
7701a0c
·
1 Parent(s): 7158b5e

update phowhisper version

Browse files
app/config/settings.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
4
  MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
@@ -23,3 +24,10 @@ REDIS_URL = os.getenv(
23
  )
24
 
25
  HTTPX_TIMEOUT = float(os.getenv("HTTPX_TIMEOUT", "10.0"))
 
 
 
 
 
 
 
 
1
  import os
2
+ from pydantic import BaseSettings
3
 
4
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
5
  MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
 
24
  )
25
 
26
  HTTPX_TIMEOUT = float(os.getenv("HTTPX_TIMEOUT", "10.0"))
27
+
28
class Settings(BaseSettings):
    """
    Cloudinary credentials, populated from environment variables by pydantic.

    NOTE(review): `BaseSettings` moved to the separate `pydantic-settings`
    package in pydantic v2 — confirm the pinned pydantic version supports
    this import.
    """

    CLOUDINARY_CLOUD_NAME: str
    CLOUDINARY_API_KEY: str
    CLOUDINARY_API_SECRET: str


# Instantiated at import time; raises a validation error if any of the
# required environment variables is missing.
settings = Settings()
app/core/asr_engine.py CHANGED
@@ -1,189 +1,112 @@
1
- # PhoWhisper inference engine
2
-
3
  import logging
 
 
 
4
  from transformers import pipeline
5
- from app.config.settings import MODEL_NAME
6
- from app.core.chunking import split_audio_to_chunks, ffmpeg_extract_segment
7
- from app.core.audio_utils import make_temp_path
8
- import os
9
- from concurrent.futures import ThreadPoolExecutor, as_completed
10
- from app.core.audio_utils import get_audio_info, make_temp_path
11
-
12
- _model = None
13
-
14
- def load_model(chunk_length_s: int = None):
15
- global _model
16
- if _model is None:
17
- logging.info(f"Loading ASR model {MODEL_NAME} ...")
18
- kwargs = {}
19
- if chunk_length_s is not None:
20
- kwargs["chunk_length_s"] = chunk_length_s
21
- _model = pipeline("automatic-speech-recognition", MODEL_NAME, **kwargs)
22
- logging.info("Model loaded")
23
- return _model
24
-
25
- def merge_chunks(chunks, max_overlap_words=12):
26
- merged = []
27
-
28
- for ch in chunks:
29
- if not merged:
30
- merged.append(ch)
31
- continue
32
 
33
- prev = merged[-1]
34
- merged_text = merge_transcripts(
35
- prev["text"],
36
- ch["text"],
37
- max_overlap_words=max_overlap_words
38
- )
39
 
40
- if merged_text != prev["text"]:
41
- prev["text"] = merged_text
42
- prev["end"] = ch["end"]
43
- else:
44
- merged.append(ch)
45
 
46
- return merged
47
 
48
- def normalize_chunks(chunks):
49
- normalized = []
50
- last_end = 0.0
 
 
 
 
 
51
 
52
- for ch in chunks:
53
- start = max(ch["start"], last_end)
54
- end = max(start, ch["end"])
55
 
56
- text = ch["text"].strip()
57
- if not text:
58
- continue
59
 
60
- normalized.append({
61
- "start": round(start, 3),
62
- "end": round(end, 3),
63
- "text": text
64
- })
65
-
66
- last_end = end
67
-
68
- return normalized
69
-
70
- # Heuristic merge for chunked transcripts
71
- def merge_transcripts(prev_text: str, new_text: str, max_overlap_words: int = 8) -> str:
72
- if not prev_text:
73
- return new_text
74
- p_words = prev_text.strip().split()
75
- n_words = new_text.strip().split()
76
- max_ol = min(max_overlap_words, len(p_words), len(n_words))
77
- best_k = 0
78
- for k in range(max_ol, 0, -1):
79
- if p_words[-k:] == n_words[:k]:
80
- best_k = k
81
- break
82
- if best_k > 0:
83
- merged = " ".join(p_words + n_words[best_k:])
84
- return merged
85
- for k in range(max_ol, 1, -1):
86
- seq = " ".join(p_words[-k:])
87
- if seq in new_text:
88
- idx = new_text.find(seq)
89
- merged = " ".join(p_words + new_text[idx + len(seq):].strip().split())
90
- return merged
91
- return prev_text.rstrip() + " " + new_text.lstrip()
92
-
93
- def transcribe_long_audio(model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, parallel: bool = False) -> str:
94
- chunks = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
95
- logging.info(f"Split into {len(chunks)} chunks")
96
- texts = []
97
- if parallel:
98
- def process_chunk(path):
99
- try:
100
- out = model(path)
101
- if isinstance(out, dict):
102
- return out.get("text", "")
103
- return str(out)
104
- except Exception as e:
105
- logging.exception("Chunk inference failed")
106
- return ""
107
- with ThreadPoolExecutor(max_workers=2) as ex:
108
- futures = {ex.submit(process_chunk, c): c for c in chunks}
109
- for fut in as_completed(futures):
110
- texts.append(fut.result() or "")
111
- else:
112
- for c in chunks:
113
- out = model(c)
114
- if isinstance(out, dict):
115
- texts.append(out.get("text", "") or "")
116
- else:
117
- texts.append(str(out) or "")
118
- merged = ""
119
- for t in texts:
120
- merged = merge_transcripts(merged, t, max_overlap_words=12)
121
- for c in chunks:
122
- try:
123
- os.remove(c)
124
- except Exception:
125
- pass
126
- return merged
127
-
128
- def transcribe_file(model, wav_path: str, max_chunk_length: float = 30.0, overlap_s: float = 5.0):
129
- info = get_audio_info(wav_path) or {}
130
- duration = info.get("duration", 0.0)
131
- if duration and duration > max_chunk_length * 1.1:
132
- logging.info(f"Long audio detected ({duration}s) -> chunking")
133
- return transcribe_long_audio(model, wav_path, chunk_length_s=max_chunk_length, overlap_s=overlap_s)
134
- out = model(wav_path)
135
- if isinstance(out, dict):
136
- return out.get("text") or ""
137
- return str(out)
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def transcribe_file_chunks(
140
  model,
141
  wav_path: str,
142
- max_chunk_length: float = 30.0,
143
- overlap_s: float = 5.0,
144
- ):
145
- info = get_audio_info(wav_path) or {}
146
- duration = info.get("duration", 0.0)
147
-
148
- step = max_chunk_length - overlap_s
149
- if step <= 0:
150
- raise ValueError("max_chunk_length must be > overlap_s")
151
-
152
- starts = []
153
- t = 0.0
154
- while t < duration:
155
- starts.append(t)
156
- t += step
157
-
158
- raw_chunks = []
159
-
160
- for i, s in enumerate(starts):
161
- chunk_end = min(s + max_chunk_length, duration)
162
- dst = make_temp_path(suffix=f".chunk{i}.wav")
163
-
164
- ffmpeg_extract_segment(wav_path, s, chunk_end - s, dst)
165
-
166
- out = model(dst)
167
- text = out.get("text", "") if isinstance(out, dict) else str(out)
168
-
169
- raw_chunks.append({
170
- "start": s,
171
- "end": chunk_end,
172
- "text": text
173
- })
174
-
175
- try:
176
- os.remove(dst)
177
- except Exception:
178
- pass
179
-
180
- # 🔽 CHUỖI XỬ LÝ CHUẨN
181
- merged = merge_chunks(raw_chunks)
182
- normalized = normalize_chunks(merged)
183
- logging.info(
184
- "ASR result: raw=%d merged=%d normalized=%d",
185
- len(raw_chunks),
186
- len(merged),
187
- len(normalized),
188
  )
189
- return normalized
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import List, Dict
3
+
4
+ import torch
5
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ logger = logging.getLogger(__name__)
 
 
 
 
 
8
 
9
+ # ===============================
10
+ # Global model cache
11
+ # ===============================
12
+ _ASR_MODEL = None
 
13
 
 
14
 
15
def load_model(
    chunk_length_s: float = 30.0,
    model_name: str = "vinai/PhoWhisper-base",
):
    """
    Load the ASR pipeline once and cache it at module level.

    Safe to call multiple times. NOTE: arguments are only honored on the
    first call — subsequent calls return the already-cached pipeline
    regardless of the values passed.

    Args:
        chunk_length_s: chunk size in seconds used by the pipeline.
        model_name: Hugging Face model id to load. Parameterized (was
            hard-coded) so alternative PhoWhisper checkpoints can be used;
            the default preserves the previous behavior.

    Returns:
        A `transformers` automatic-speech-recognition pipeline.
    """
    global _ASR_MODEL
    if _ASR_MODEL is not None:
        return _ASR_MODEL

    logger.info("Loading ASR model %s", model_name)

    # Prefer GPU with half precision when available; fall back to CPU fp32.
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1
    torch_dtype = torch.float16 if use_cuda else torch.float32

    _ASR_MODEL = pipeline(
        task="automatic-speech-recognition",
        model=model_name,
        device=device,
        torch_dtype=torch_dtype,
        chunk_length_s=chunk_length_s,
        return_timestamps=True,
    )

    logger.info(
        "ASR model loaded (device=%s)", "cuda" if use_cuda else "cpu"
    )
    return _ASR_MODEL
42
+
43
+
44
+ # ===============================
45
+ # Transcribe full text
46
+ # ===============================
47
def transcribe_file(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> str:
    """
    Run the ASR pipeline over a file and return the full transcript text.

    Args:
        model: a callable ASR pipeline accepting (path, chunk_length_s=...,
            stride_length_s=...) and returning a dict with a "text" key.
        wav_path: path to the audio file; empty/falsy paths short-circuit.
        chunk_length_s: chunk size in seconds forwarded to the pipeline.
        stride_s: stride/overlap in seconds forwarded to the pipeline.

    Returns:
        The stripped transcript, or "" when the path is empty or the
        pipeline produced no text.
    """
    # Guard clause: nothing to transcribe.
    if not wav_path:
        return ""

    result = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
    )

    transcript = result.get("text", "")
    if not transcript:
        return ""
    return transcript.strip()
67
+
68
+
69
+ # ===============================
70
+ # Transcribe chunks with timestamps
71
+ # ===============================
72
def transcribe_file_chunks(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> List[Dict]:
    """
    Run the ASR pipeline with timestamps and return a list of segments.

    Each segment has the shape ``{"start": float, "end": float, "text": str}``.
    Segments with empty text or a missing start/end timestamp are dropped.

    Args:
        model: a callable ASR pipeline returning a dict with a "chunks" list,
            where each chunk carries "timestamp" (start, end) and "text".
        wav_path: path to the audio file; empty/falsy paths short-circuit.
        chunk_length_s: chunk size in seconds forwarded to the pipeline.
        stride_s: stride/overlap in seconds forwarded to the pipeline.
    """
    # Guard clause: nothing to transcribe.
    if not wav_path:
        return []

    result = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
        return_timestamps=True,
    )

    segments = []
    for raw in result.get("chunks", []) or []:
        span = raw.get("timestamp") or [None, None]
        if len(span) == 2:
            begin, finish = span
        else:
            begin, finish = None, None

        cleaned = (raw.get("text") or "").strip()
        if not cleaned:
            continue

        # Skip open-ended segments (e.g. a trailing chunk without an end).
        if begin is None or finish is None:
            continue

        segments.append(
            {
                "start": float(begin),
                "end": float(finish),
                "text": cleaned,
            }
        )

    return segments
app/core/audio_utils.py CHANGED
@@ -1,11 +1,15 @@
1
  # Audio utilities: ffmpeg, normalization, etc.
 
2
  import subprocess
3
  import shlex
4
  import uuid
5
  import requests
6
  from pathlib import Path
7
  import soundfile as sf
8
- from app.config.settings import TMP_DIR, MAX_UPLOAD_BYTES
 
 
 
9
 
10
  def save_upload_file(upload_file, dest_path: str):
11
  """Save FastAPI UploadFile to dest_path (streaming)."""
@@ -64,3 +68,45 @@ def ensure_wav_16k_mono(src_path: str, dest_path: str):
64
  def make_temp_path(suffix=".wav"):
65
  """Generate unique temp file path under TMP_DIR."""
66
  return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Audio utilities: ffmpeg, normalization, etc.
2
+ from asyncio.log import logger
3
  import subprocess
4
  import shlex
5
  import uuid
6
  import requests
7
  from pathlib import Path
8
  import soundfile as sf
9
+ from app.config.settings import TMP_DIR, MAX_UPLOAD_BYTES, settings
10
+ import cloudinary
11
+ import cloudinary.uploader
12
+ import os
13
 
14
  def save_upload_file(upload_file, dest_path: str):
15
  """Save FastAPI UploadFile to dest_path (streaming)."""
 
68
  def make_temp_path(suffix=".wav"):
69
  """Generate unique temp file path under TMP_DIR."""
70
  return str(Path(TMP_DIR) / f"{uuid.uuid4().hex}{suffix}")
71
+
72
# Use a logger named for this module instead of the file-level
# `from asyncio.log import logger`, which borrows asyncio's internal logger.
import logging

_logger = logging.getLogger(__name__)

# Configure Cloudinary once at import time using validated settings.
cloudinary.config(
    cloud_name=settings.CLOUDINARY_CLOUD_NAME,
    api_key=settings.CLOUDINARY_API_KEY,
    api_secret=settings.CLOUDINARY_API_SECRET,
    secure=True,
)


def upload_temp_audio(
    local_path: str,
    *,
    folder: str = "asr_uploads",
    public_id: str | None = None,
    ttl: int = 3600,
) -> str:
    """
    Upload an audio file to Cloudinary and return its public HTTPS URL.

    The local file can safely be deleted after this returns.

    Args:
        local_path: path to the audio file on disk.
        folder: Cloudinary folder to upload into.
        public_id: explicit Cloudinary public id; auto-generated when None.
        ttl: intended lifetime in seconds. NOTE(review): currently unused —
            no expiry is applied to the upload; wire this up or remove it.

    Returns:
        The secure URL of the uploaded asset.

    Raises:
        FileNotFoundError: if local_path does not exist.
        RuntimeError: if Cloudinary does not return a secure URL.
    """
    if not os.path.exists(local_path):
        raise FileNotFoundError(local_path)

    _logger.info("Uploading audio to Cloudinary: %s", local_path)

    result = cloudinary.uploader.upload(
        local_path,
        resource_type="video",  # Cloudinary requires "video" for audio assets
        folder=folder,
        public_id=public_id,
        overwrite=True,
        invalidate=True,
    )

    url = result.get("secure_url")
    if not url:
        raise RuntimeError("Cloudinary upload failed")

    _logger.info("Uploaded audio -> %s", url)
    return url
app/jobs/transcribe_job.py CHANGED
@@ -1,41 +1,65 @@
1
  import asyncio
 
 
 
 
2
  from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
3
  from app.services.note_client import NoteServiceClient
4
  from app.core.audio_utils import get_audio_info
5
 
6
- def transcribe_job(wav_path: str, note_id: str, user_id: str | None = None):
 
 
 
 
 
 
 
 
 
 
 
 
7
  model = load_model()
8
 
9
- # 🔥 ASR giống hệt API sync
10
- text = transcribe_file(model, wav_path, 30.0, 5.0)
11
- chunks = transcribe_file_chunks(model, wav_path, 30.0, 5.0)
12
-
13
- # drop invalid chunks (defensive)
14
- chunks = [
15
- c for c in chunks
16
- if c.get("text", "").strip() and c.get("end", 0) > c.get("start", 0)
17
- ]
18
-
19
- note_status = "transcribed" if chunks else "error"
20
-
21
- info = get_audio_info(wav_path) or {}
22
-
23
- payload = {
24
- "note_id": note_id,
25
- "type": "audio",
26
- "status": note_status,
27
- "raw_text": text,
28
- "metadata": {
29
- "audio": {
30
- "duration": info.get("duration"),
31
- "sample_rate": info.get("samplerate"),
32
- "chunks": chunks,
33
- "asr_model": "PhoWhisper-base",
 
 
 
 
 
34
  },
35
- "client": {"user_id": user_id},
36
- },
37
- "generate": ["normalize", "keywords", "summary", "mindmap"],
38
- }
 
39
 
40
- client = NoteServiceClient()
41
- asyncio.run(client.create_audio_note(payload))
 
 
 
1
  import asyncio
2
+ import tempfile
3
+ import os
4
+ import requests
5
+
6
  from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks
7
  from app.services.note_client import NoteServiceClient
8
  from app.core.audio_utils import get_audio_info
9
 
10
+
11
def download_audio(audio_url: str) -> str:
    """
    Download audio from a URL into a local temp file and return its path.

    Streams the response to disk in chunks instead of buffering the whole
    body in memory (the previous `r.content` approach holds the entire
    file in RAM, which is costly for long recordings). The temp file is
    removed if the download fails part-way, so failures don't leak files.

    Args:
        audio_url: HTTP(S) URL of the audio to fetch.

    Returns:
        Path to a temp ``.wav`` file; the caller owns and must delete it.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection/timeout errors.
    """
    with requests.get(audio_url, timeout=30, stream=True) as resp:
        resp.raise_for_status()

        fd, path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        except Exception:
            # Don't leave a partial temp file behind on failure.
            try:
                os.remove(path)
            except OSError:
                pass
            raise

    return path
20
+
21
+
22
def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
    """
    Background job: fetch audio from a URL, run ASR, and push the result
    to the note service.

    The downloaded temp file is always removed, even when transcription
    or the service call fails.

    Args:
        audio_url: public URL of the audio to transcribe.
        note_id: id of the note the transcript belongs to.
        user_id: optional id of the requesting user, forwarded as metadata.
    """
    model = load_model()

    wav_path = None
    try:
        # The worker fetches the audio itself and owns the temp file.
        wav_path = download_audio(audio_url)

        # Run ASR for the full transcript and for timestamped segments.
        text = transcribe_file(model, wav_path, 30.0, 5.0)
        segments = transcribe_file_chunks(model, wav_path, 30.0, 5.0)

        # Defensive filtering: keep only non-empty, well-ordered segments.
        segments = [
            seg for seg in segments
            if seg.get("text", "").strip() and seg.get("end", 0) > seg.get("start", 0)
        ]

        note_status = "transcribed" if segments else "error"
        info = get_audio_info(wav_path) or {}

        payload = {
            "note_id": note_id,
            "type": "audio",
            "status": note_status,
            "raw_text": text,
            "metadata": {
                "audio": {
                    "duration": info.get("duration"),
                    "sample_rate": info.get("samplerate"),
                    "chunks": segments,
                    "asr_model": "PhoWhisper-base",
                },
                "client": {"user_id": user_id},
            },
            "generate": ["normalize", "keywords", "summary", "mindmap"],
        }

        asyncio.run(NoteServiceClient().create_audio_note(payload))
    finally:
        # Clean up the downloaded temp file regardless of outcome.
        if wav_path and os.path.exists(wav_path):
            os.remove(wav_path)
+ os.remove(wav_path)
requirements.txt CHANGED
@@ -13,4 +13,5 @@ prometheus-client
13
  google-generativeai
14
  google-genai
15
  numpy
16
- pytest
 
 
13
  google-generativeai
14
  google-genai
15
  numpy
16
+ pytest
17
+ cloudinary