colab-user committed on
Commit
c90efd5
·
1 Parent(s): e846326

optimize finetuned model

Browse files
Files changed (2) hide show
  1. app/services/processor.py +311 -125
  2. app/services/transcription.py +131 -70
app/services/processor.py CHANGED
@@ -12,16 +12,22 @@ import librosa
12
  import torch
13
 
14
  from app.core.config import get_settings
15
- from app.services.transcription import TranscriptionService, WordTimestamp
16
  from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
17
- from app.services.alignment import AlignmentService
18
- from app.schemas.models import TranscriptSegment
19
-
20
 
21
  logger = logging.getLogger(__name__)
22
  settings = get_settings()
23
 
24
 
 
 
 
 
 
 
 
 
 
25
 
26
  @dataclass
27
  class ProcessingResult:
@@ -36,45 +42,76 @@ class ProcessingResult:
36
  txt_content: str = ""
37
  csv_content: str = ""
38
 
39
-
40
-
41
- def pad_and_refine_tensor(
42
- waveform: torch.Tensor,
43
- sr: int,
44
- start_s: float,
45
- end_s: float,
46
- pad_ms: int = 200,
47
- ) -> Tuple[float, float]:
48
-
49
- total_len = waveform.shape[1]
50
- s = max(int((start_s - pad_ms / 1000) * sr), 0)
51
- e = min(int((end_s + pad_ms / 1000) * sr), total_len)
52
-
53
- if e <= s:
54
- return start_s, end_s
55
-
56
- return s / sr, e / sr
57
-
58
-
59
- def infer_roles_from_diarization(
60
  diarization_segments: List[SpeakerSegment],
61
- ) -> Dict[str, str]:
62
-
63
- dur = defaultdict(float)
64
-
65
- for s in diarization_segments:
66
- dur[s.speaker] += s.end - s.start
67
-
68
- if not dur:
69
- return {}
70
-
71
- agent = max(dur, key=dur.get)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- return {
74
- spk: ("NV" if spk == agent else "KH")
75
- for spk in dur
76
- }
77
 
 
78
 
79
  def convert_audio_to_wav(audio_path: Path) -> Path:
80
  """Convert any audio to WAV 16kHz Mono using ffmpeg."""
@@ -89,14 +126,59 @@ def convert_audio_to_wav(audio_path: Path) -> Path:
89
  except subprocess.CalledProcessError as e:
90
  logger.error(f"FFmpeg conversion failed: {e}")
91
  return audio_path
92
-
93
-
94
  def format_timestamp(seconds: float) -> str:
95
  m = int(seconds // 60)
96
  s = seconds % 60
97
  return f"{m:02d}:{s:06.3f}"
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # =========================
101
  # Processor
102
  # =========================
@@ -105,9 +187,20 @@ class Processor:
105
  async def process_audio(
106
  cls,
107
  audio_path: Path,
 
108
  language: str = "vi",
 
 
 
 
 
 
 
 
109
  beam_size: int = 5,
110
  temperature: float = 0.0,
 
 
111
  ) -> ProcessingResult:
112
 
113
  import asyncio
@@ -120,97 +213,172 @@ class Processor:
120
 
121
  # 2: Load audio
122
  y, sr = librosa.load(wav_path, sr=16000, mono=True)
123
- waveform = torch.from_numpy(y).unsqueeze(0)
 
 
124
  duration = len(y) / sr
125
-
126
  # 3: Diarization
 
 
127
  diarization: DiarizationResult = await DiarizationService.diarize_async(wav_path)
128
- diarization_segments = diarization.segments or [
129
- SpeakerSegment(0.0, duration, "SPEAKER_0")
130
- ]
131
-
132
- diarization_segments.sort(key=lambda s: s.start)
133
 
134
- diarization_segments = [
135
- SpeakerSegment(
136
- *pad_and_refine_tensor(waveform, sr, s.start, s.end),
137
- speaker=s.speaker,
138
- )
139
- for s in diarization_segments
140
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- # 4. SPEAKER MAP
143
- raw_speakers = sorted({s.speaker for s in diarization_segments})
144
- speaker_map = {s: f"Speaker {i+1}" for i, s in enumerate(raw_speakers)}
145
-
146
- diarization_segments = [
147
- SpeakerSegment(
148
- start=s.start,
149
- end=s.end,
150
- speaker=speaker_map[s.speaker]
151
- )
152
- for s in diarization_segments
153
- ]
154
- # 5. Roles infer
155
- roles = infer_roles_from_diarization(diarization_segments)
156
-
157
- result = await TranscriptionService.transcribe_with_words_async(
158
- audio_array=y,
159
- language=language,
160
- beam_size=beam_size,
161
- temperature=temperature
162
- )
163
-
164
- words: List[WordTimestamp] = [
165
- WordTimestamp(
166
- word=w["word"],
167
- start=float(w["start"]),
168
- end=float(w["end"]),
169
- )
170
- for w in result.get("words", [])
171
- if w.get("word")
172
- ]
173
 
 
 
 
174
 
175
- aligned_segments = AlignmentService.align_precision(
176
- words,
177
- diarization_segments
178
- )
179
 
180
- segments = [
181
- TranscriptSegment(
182
- start=s.start,
183
- end=s.end,
184
- speaker=s.speaker,
185
- role=roles.get(s.speaker, "KH"),
186
- text=s.text,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
- for s in aligned_segments
189
- ]
190
-
191
- segments = cls._filter_segments_with_context(segments)
192
 
193
- segments = cls._merge_adjacent_segments(
194
- segments,
 
 
 
 
 
 
 
 
 
 
 
195
  max_gap_s=0.6
196
  )
197
 
198
- processing_time = time.time() - t0
199
- speakers = sorted({s.speaker for s in segments})
200
 
201
  txt_content = cls._generate_txt(
202
- segments,
203
- len(speaker_map),
204
  processing_time,
205
  duration,
206
  roles
207
  )
208
 
209
- csv_content = cls._generate_csv(segments)
210
 
211
  return ProcessingResult(
212
- segments=segments,
213
- speaker_count=len(speaker_map),
214
  duration=duration,
215
  processing_time=processing_time,
216
  speakers=speakers,
@@ -224,36 +392,44 @@ class Processor:
224
  def _is_meaningful_segment(
225
  seg: TranscriptSegment,
226
  min_duration_s: float = 0.6,
227
- min_words: int = 3,
 
228
  ) -> bool:
229
- if seg.end - seg.start >= min_duration_s:
 
 
 
230
  return True
231
- if len(seg.text.split()) >= min_words:
232
  return True
233
  if seg.role == "KH":
234
  return True
 
235
  return False
236
 
237
  @classmethod
238
  def _filter_segments_with_context(
239
  cls,
240
- segments: List[TranscriptSegment],
241
  ) -> List[TranscriptSegment]:
242
-
243
  if not segments:
244
  return segments
245
-
246
  segments = sorted(segments, key=lambda s: s.start)
247
  result = []
 
248
 
249
  for i, seg in enumerate(segments):
250
- prev = segments[i - 1] if i > 0 else None
251
- next_ = segments[i + 1] if i < len(segments) - 1 else None
252
 
253
  if cls._is_meaningful_segment(seg):
254
  result.append(seg)
255
- elif prev and next_ and prev.speaker == seg.speaker == next_.speaker:
256
- result.append(seg)
 
 
 
257
 
258
  return result
259
 
@@ -261,20 +437,30 @@ class Processor:
261
  @staticmethod
262
  def _merge_adjacent_segments(
263
  segments: List[TranscriptSegment],
264
- max_gap_s: float = 0.5,
265
  ) -> List[TranscriptSegment]:
266
-
 
 
 
 
267
  if not segments:
268
  return segments
269
 
 
270
  merged = [segments[0]]
271
 
272
  for seg in segments[1:]:
273
  prev = merged[-1]
 
274
  gap = seg.start - prev.end
275
 
276
- if seg.speaker == prev.speaker and gap <= max_gap_s:
277
- prev.text += " " + seg.text
 
 
 
 
278
  prev.end = max(prev.end, seg.end)
279
  else:
280
  merged.append(seg)
 
12
  import torch
13
 
14
  from app.core.config import get_settings
15
+ from app.services.transcription import TranscriptionService
16
  from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
 
 
 
17
 
18
  logger = logging.getLogger(__name__)
19
  settings = get_settings()
20
 
21
 
22
@dataclass
class TranscriptSegment:
    """A transcribed segment with speaker info."""
    # Segment start time; the code in this file fills it with seconds.
    start: float
    # Segment end time, same unit as ``start``.
    end: float
    # Speaker label; process_audio stores the normalized "Speaker N" label here.
    speaker: str
    # Role tag; this file assigns "NV" to the speaker with the most talk time
    # and "KH" to everyone else (presumably agent / customer -- confirm).
    role: Optional[str]
    # Transcribed text for the segment.
    text: str
30
+
31
 
32
  @dataclass
33
  class ProcessingResult:
 
42
  txt_content: str = ""
43
  csv_content: str = ""
44
 
45
def assign_speaker_to_word(
    word_start: float,
    diarization_segments: List[SpeakerSegment],
) -> str:
    """
    Assign a speaker to a word using diarization segments.

    The first segment whose ``[start, end]`` interval contains ``word_start``
    wins. If no segment contains it, the segment whose nearest boundary is
    closest to ``word_start`` is used.

    Args:
        word_start: Start time of the word, in the same unit as the segments.
        diarization_segments: Segments exposing ``start``, ``end`` and
            ``speaker`` attributes.

    Returns:
        The ``speaker`` attribute of the selected segment.

    Raises:
        ValueError: If ``diarization_segments`` is empty.
    """
    if not diarization_segments:
        raise ValueError("diarization_segments must not be empty")

    for seg in diarization_segments:
        if seg.start <= word_start <= seg.end:
            return seg.speaker

    # Fallback: nearest segment by distance to its closest edge. Measuring to
    # the midpoint would penalize long segments: a word just past the end of a
    # long turn could be handed to a short, far-away turn instead.
    return min(
        diarization_segments,
        key=lambda s: max(s.start - word_start, word_start - s.end),
    ).speaker
61
+
62
+
63
def group_words_into_segments(
    words: List[dict],
    diarization_segments: List[SpeakerSegment],
    speaker_map: Dict[str, str],
    roles: Dict[str, str],
    max_word_gap_s: float = 0.6,
) -> List[TranscriptSegment]:
    """
    Fold word-level entries into speaker-homogeneous transcript segments.

    Each entry is a dict with "word", "start" and "end" keys. A word extends
    the segment currently being built when it maps to the same speaker label
    and starts no later than ``max_word_gap_s`` after that segment's end;
    otherwise the open segment is closed and a new one is started. Entries
    whose "word" is missing or blank are ignored.

    Args:
        words: Word entries, processed in the order given.
        diarization_segments: Segments used to pick a speaker for each word.
        speaker_map: Raw speaker id -> display label; unmapped ids are kept.
        roles: Display label -> role; labels without an entry get "KH".
        max_word_gap_s: Largest start-to-previous-end gap that still merges.

    Returns:
        The segments in the order they were closed.
    """
    grouped: List[TranscriptSegment] = []
    open_seg: Optional[TranscriptSegment] = None

    for entry in words:
        token = entry.get("word", "").strip()
        if not token:
            continue

        t_start = float(entry["start"])
        t_end = float(entry["end"])

        raw_label = assign_speaker_to_word(t_start, diarization_segments)
        label = speaker_map.get(raw_label, raw_label)
        role = roles.get(label, "KH")

        continues_open = (
            open_seg is not None
            and label == open_seg.speaker
            and t_start - open_seg.end <= max_word_gap_s
        )
        if continues_open:
            open_seg.text += " " + token
            open_seg.end = max(open_seg.end, t_end)
            continue

        # Speaker change, long pause, or very first word: flush and restart.
        if open_seg is not None:
            grouped.append(open_seg)
        open_seg = TranscriptSegment(
            start=t_start,
            end=t_end,
            speaker=label,
            role=role,
            text=token,
        )

    if open_seg is not None:
        grouped.append(open_seg)

    return grouped
115
 
116
  def convert_audio_to_wav(audio_path: Path) -> Path:
117
  """Convert any audio to WAV 16kHz Mono using ffmpeg."""
 
126
  except subprocess.CalledProcessError as e:
127
  logger.error(f"FFmpeg conversion failed: {e}")
128
  return audio_path
129
+
130
+
131
def format_timestamp(seconds: float) -> str:
    """
    Format a time offset in seconds as ``MM:SS.mmm``.

    The value is rounded to whole milliseconds before it is split into
    minutes and seconds. Rounding after the split would let inputs such as
    59.9996 print as "00:60.000" instead of "01:00.000".

    Args:
        seconds: Offset in seconds.

    Returns:
        Zero-padded minutes, a colon, and seconds with three decimals.
        Minutes are not wrapped into hours.
    """
    total_ms = int(round(seconds * 1000))
    minutes, ms = divmod(total_ms, 60000)
    return f"{minutes:02d}:{ms / 1000:06.3f}"
135
 
136
 
137
def pad_and_refine_tensor(
    waveform: torch.Tensor,
    sr: int,
    start_s: float,
    end_s: float,
    pad_ms: int = 200,
    silence_db_delta: float = 16,
    min_duration_ms: int = 150,
) -> Optional[Tuple[int, int]]:
    """
    Tighten a time span to the part of the waveform that carries energy.

    The span is first widened by ``pad_ms`` on both sides and clamped to the
    waveform, then trimmed to the first and last samples whose absolute
    amplitude exceeds ``rms / silence_db_delta``, where ``rms`` is the RMS of
    the padded window.

    Args:
        waveform: 2-D tensor sliced as ``waveform[:, start:end]``, i.e. the
            second dimension is the sample axis.
        sr: Sample rate used to convert seconds to sample indices.
        start_s: Span start in seconds.
        end_s: Span end in seconds.
        pad_ms: Padding added on each side before trimming, in milliseconds.
        silence_db_delta: Divisor applied to the RMS to get the threshold.
            Despite the name it is used as a linear ratio, not as decibels.
        min_duration_ms: Minimum length of the trimmed span in milliseconds.

    Returns:
        ``(first_index, last_index)`` as absolute sample indices of the first
        and last above-threshold samples, or ``None`` when the window is
        empty, nothing exceeds the threshold, or the trimmed span is shorter
        than ``min_duration_ms``.
    """
    total_len = waveform.shape[1]
    pad_s = pad_ms / 1000

    start_idx = max(int((start_s - pad_s) * sr), 0)
    end_idx = min(int((end_s + pad_s) * sr), total_len)

    if end_idx <= start_idx:
        return None

    segment = waveform[:, start_idx:end_idx]
    if segment.numel() == 0:
        return None

    # RMS energy of the padded window; epsilon keeps sqrt away from exact 0.
    rms = torch.sqrt(torch.mean(segment ** 2) + 1e-9)
    threshold = rms / silence_db_delta

    # Collapse the channel axis so indices refer to SAMPLE positions. Taking
    # the first index tensor of torch.where on the 2-D mask yields channel
    # indices (all zeros for mono), which made every span zero-length.
    energy = segment.abs().amax(dim=0)
    valid = torch.nonzero(energy > threshold).flatten()

    if valid.numel() == 0:
        return None

    new_start = start_idx + int(valid[0].item())
    new_end = start_idx + int(valid[-1].item())

    if new_end - new_start < int(min_duration_ms / 1000 * sr):
        return None

    return new_start, new_end
180
+
181
+
182
  # =========================
183
  # Processor
184
  # =========================
 
187
  async def process_audio(
188
  cls,
189
  audio_path: Path,
190
+ model_name: str = "PhoWhisper VI Finetuned",
191
  language: str = "vi",
192
+
193
+ # VAD options
194
+ vad_filter: bool = True,
195
+ vad_min_silence_ms: int = 1000,
196
+ vad_speech_pad_ms: int = 400,
197
+ vad_min_speech_ms: int = 250,
198
+ vad_threshold: float = 0.5,
199
+ # Generation options
200
  beam_size: int = 5,
201
  temperature: float = 0.0,
202
+ best_of: int = 5,
203
+ initial_prompt: Optional[str] = None,
204
  ) -> ProcessingResult:
205
 
206
  import asyncio
 
213
 
214
  # 2: Load audio
215
  y, sr = librosa.load(wav_path, sr=16000, mono=True)
216
+ if y.size == 0:
217
+ raise ValueError("Empty audio")
218
+ waveform = torch.from_numpy(y).unsqueeze(0).float()
219
  duration = len(y) / sr
220
+
221
  # 3: Diarization
222
+ logger.info("Step 3: Running diarization...")
223
+
224
  diarization: DiarizationResult = await DiarizationService.diarize_async(wav_path)
 
 
 
 
 
225
 
226
+ diarization_segments = diarization.segments or []
227
+ speakers = diarization.speakers or []
228
+ roles = diarization.roles or {}
229
+
230
+ if not diarization_segments:
231
+ diarization_segments = [SpeakerSegment(0.0, duration, "SPEAKER_0")]
232
+ speakers = ["SPEAKER_0"]
233
+ roles = {"SPEAKER_0": "KH"}
234
+
235
+ diarization_segments.sort(key=lambda x: x.start)
236
+
237
+ # 4: Refine segment boundaries
238
+ refined_segments: List[SpeakerSegment] = []
239
+
240
+ for seg in diarization_segments:
241
+ refined = pad_and_refine_tensor(waveform, sr, seg.start, seg.end)
242
+
243
+ if refined:
244
+ s, e = refined
245
+ if e > s:
246
+ refined_segments.append(
247
+ SpeakerSegment(
248
+ start=s / sr,
249
+ end=e / sr,
250
+ speaker=seg.speaker,
251
+ )
252
+ )
253
+ else:
254
+ refined_segments.append(seg)
255
+ else:
256
+ refined_segments.append(seg)
257
 
258
+ if not refined_segments:
259
+ refined_segments = diarization_segments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
+
262
+ # 5. Normalize speakers
263
+ raw_speakers = sorted({seg.speaker for seg in refined_segments})
264
 
265
+ speaker_map = {
266
+ spk: f"Speaker {i+1}"
267
+ for i, spk in enumerate(raw_speakers)
268
+ }
269
 
270
+ speakers = list(speaker_map.values())
271
+
272
+ # 6. NORMALIZE ROLES
273
+ speaker_duration = defaultdict(float)
274
+ for seg in refined_segments:
275
+ speaker_duration[seg.speaker] += seg.end - seg.start
276
+
277
+ logger.info(f"speaker_duration(raw) = {speaker_duration}")
278
+
279
+ if speaker_duration:
280
+ agent_raw = max(speaker_duration, key=speaker_duration.get)
281
+
282
+ roles = {
283
+ speaker_map[spk]: ("NV" if spk == agent_raw else "KH")
284
+ for spk in speaker_duration
285
+ }
286
+ else:
287
+ roles = {}
288
+
289
+ # Default fallback
290
+ for label in speakers:
291
+ roles.setdefault(label, "KH")
292
+
293
+ logger.info(f"roles(mapped) = {roles}")
294
+
295
+
296
+
297
+ # 7: Transcribe
298
+ vad_options = None
299
+ if vad_filter:
300
+ vad_options = {
301
+ "min_silence_duration_ms": vad_min_silence_ms,
302
+ "speech_pad_ms": vad_speech_pad_ms,
303
+ "min_speech_duration_ms": vad_min_speech_ms,
304
+ "threshold": vad_threshold
305
+ }
306
+
307
+ processed_segments: List[TranscriptSegment] = []
308
+
309
+ for seg in refined_segments:
310
+ start = int(seg.start * sr)
311
+ end = int(seg.end * sr)
312
+
313
+ if end <= start:
314
+ continue
315
+
316
+ audio_slice = y[start:end]
317
+ if audio_slice.size < sr * 0.25:
318
+ continue
319
+
320
+ try:
321
+ text = await TranscriptionService.transcribe_with_words_async(
322
+ audio_array=audio_slice,
323
+ model_name=model_name,
324
+ language=language,
325
+ vad_options=vad_options,
326
+ beam_size=beam_size,
327
+ temperature=temperature,
328
+ best_of=best_of,
329
+ initial_prompt=initial_prompt,
330
+ )
331
+ except Exception as e:
332
+ logger.error(f"Transcribe error: {e}")
333
+ continue
334
+
335
+ if not text or not text.strip():
336
+ continue
337
+
338
+ label = speaker_map.get(seg.speaker, seg.speaker)
339
+
340
+ processed_segments.append(
341
+ TranscriptSegment(
342
+ start=seg.start,
343
+ end=seg.end,
344
+ speaker=label,
345
+ role=roles[label],
346
+ text=text.strip(),
347
+ )
348
  )
 
 
 
 
349
 
350
+ if not processed_segments:
351
+ processed_segments = [
352
+ TranscriptSegment(
353
+ start=0.0,
354
+ end=duration,
355
+ speaker=speakers[0],
356
+ role=roles[speakers[0]],
357
+ text="(No speech detected)"
358
+ )
359
+ ]
360
+
361
+ processed_segments = cls._merge_adjacent_segments(
362
+ processed_segments,
363
  max_gap_s=0.6
364
  )
365
 
366
+ processed_segments = cls._filter_segments_with_context(processed_segments)
367
+ processing_time = time.time() - t0
368
 
369
  txt_content = cls._generate_txt(
370
+ processed_segments,
371
+ len(speakers),
372
  processing_time,
373
  duration,
374
  roles
375
  )
376
 
377
+ csv_content = cls._generate_csv(processed_segments)
378
 
379
  return ProcessingResult(
380
+ segments=processed_segments,
381
+ speaker_count=len(speakers),
382
  duration=duration,
383
  processing_time=processing_time,
384
  speakers=speakers,
 
392
  def _is_meaningful_segment(
393
  seg: TranscriptSegment,
394
  min_duration_s: float = 0.6,
395
+ eps: float = 0.05,
396
+ min_words: int = 3
397
  ) -> bool:
398
+ duration = seg.end - seg.start
399
+ word_count = len(seg.text.split())
400
+
401
+ if duration + eps >= min_duration_s:
402
  return True
403
+ if word_count >= min_words:
404
  return True
405
  if seg.role == "KH":
406
  return True
407
+
408
  return False
409
 
410
  @classmethod
411
  def _filter_segments_with_context(
412
  cls,
413
+ segments: List[TranscriptSegment]
414
  ) -> List[TranscriptSegment]:
 
415
  if not segments:
416
  return segments
417
+
418
  segments = sorted(segments, key=lambda s: s.start)
419
  result = []
420
+ n = len(segments)
421
 
422
  for i, seg in enumerate(segments):
423
+ prev_seg = segments[i - 1] if i > 0 else None
424
+ next_seg = segments[i + 1] if i < n - 1 else None
425
 
426
  if cls._is_meaningful_segment(seg):
427
  result.append(seg)
428
+ continue
429
+
430
+ if prev_seg and next_seg:
431
+ if prev_seg.speaker == seg.speaker == next_seg.speaker:
432
+ result.append(seg)
433
 
434
  return result
435
 
 
437
  @staticmethod
438
  def _merge_adjacent_segments(
439
  segments: List[TranscriptSegment],
440
+ max_gap_s: float = 0.5
441
  ) -> List[TranscriptSegment]:
442
+ """
443
+ Merge adjacent segments if:
444
+ - same speaker
445
+ - gap <= max_gap_s
446
+ """
447
  if not segments:
448
  return segments
449
 
450
+ segments = sorted(segments, key=lambda s: s.start)
451
  merged = [segments[0]]
452
 
453
  for seg in segments[1:]:
454
  prev = merged[-1]
455
+
456
  gap = seg.start - prev.end
457
 
458
+ if (
459
+ seg.speaker == prev.speaker
460
+ and gap <= max_gap_s
461
+ ):
462
+ # MERGE
463
+ prev.text = f"{prev.text} {seg.text}".strip()
464
  prev.end = max(prev.end, seg.end)
465
  else:
466
  merged.append(seg)
app/services/transcription.py CHANGED
@@ -3,14 +3,11 @@ Transcription service using faster-whisper.
3
  Supports multiple Vietnamese Whisper models with caching.
4
  """
5
  import logging
6
- import torch
7
  from typing import Dict, Optional, List
8
  from dataclasses import dataclass
9
 
10
  import numpy as np
11
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
12
- from peft import PeftModel
13
-
14
 
15
  from app.core.config import get_settings
16
 
@@ -20,9 +17,7 @@ settings = get_settings()
20
 
21
  # Available Whisper models for Vietnamese
22
  AVAILABLE_MODELS = {
23
-
24
- "Whisper-LoRA": settings.whisper_lora_model_dir
25
-
26
  }
27
 
28
 
@@ -40,88 +35,134 @@ class TranscriptionService:
40
  Supports multiple models with caching.
41
  """
42
 
43
- _model = None
44
- _processor = None
45
- _device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
  @classmethod
48
- def get_model(cls):
49
- if cls._model is not None:
50
- return cls._model, cls._processor
51
-
52
- model_dir = AVAILABLE_MODELS["Whisper-LoRA"]
53
-
54
- logger.info(f"Loading Whisper + LoRA from {model_dir}")
55
- logger.info(f"Device: {cls._device}")
56
-
57
- base_model = WhisperForConditionalGeneration.from_pretrained(model_dir)
58
- model = PeftModel.from_pretrained(base_model, model_dir)
59
-
60
- model.to(cls._device)
61
- model.eval()
62
-
63
- processor = WhisperProcessor.from_pretrained(model_dir)
64
-
65
- cls._model = model
66
- cls._processor = processor
67
-
68
- logger.info("Whisper + LoRA loaded successfully")
69
- return model, processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  @classmethod
72
- def is_loaded(cls) -> bool:
73
- return cls._model is not None
 
 
 
 
74
 
75
  @classmethod
76
- def preload_model(cls) -> None:
77
- cls.get_model()
 
 
 
 
 
 
 
78
 
79
  @classmethod
80
  def transcribe_with_words(
81
  cls,
82
  audio_array: np.ndarray,
 
83
  language: str = "vi",
 
84
  beam_size: int = 5,
85
  temperature: float = 0.0,
 
 
86
  ) -> Dict:
87
- model, processor = cls.get_model()
 
 
 
88
 
89
- if audio_array.ndim > 1:
90
- audio_array = np.mean(audio_array, axis=0)
91
 
92
- inputs = processor(
93
  audio_array,
94
- sampling_rate=16000,
95
- return_tensors="pt"
96
- ).input_features.to(cls._device)
97
-
98
- forced_decoder_ids = processor.get_decoder_prompt_ids(
99
- language=language,
100
- task="transcribe"
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
 
103
- with torch.no_grad():
104
- generated_ids = model.generate(
105
- inputs,
106
- forced_decoder_ids=forced_decoder_ids,
107
- num_beams=beam_size,
108
- temperature=temperature,
109
- max_new_tokens=settings.whisper_max_new_tokens,
110
- )
111
 
112
- text = processor.batch_decode(
113
- generated_ids,
114
- skip_special_tokens=True
115
- )[0].strip()
 
 
 
 
 
116
 
117
  return {
118
- "text": text,
119
- "words": [],
120
- "info": {
121
- "engine": "transformers-whisper-lora",
122
- "language": language,
123
- "beam_size": beam_size,
124
- },
125
  }
126
 
127
 
@@ -129,15 +170,35 @@ class TranscriptionService:
129
  async def transcribe_with_words_async(
130
  cls,
131
  audio_array: np.ndarray,
132
- **kwargs
133
- ) -> Dict:
 
 
 
 
 
 
 
 
 
134
  import asyncio
 
135
  loop = asyncio.get_event_loop()
136
  return await loop.run_in_executor(
137
  None,
138
- lambda: cls.transcribe_with_words(audio_array, **kwargs)
 
 
 
 
 
 
 
 
 
139
  )
140
-
141
  @classmethod
142
  def get_available_models(cls) -> Dict[str, str]:
 
143
  return AVAILABLE_MODELS.copy()
 
3
  Supports multiple Vietnamese Whisper models with caching.
4
  """
5
  import logging
 
6
  from typing import Dict, Optional, List
7
  from dataclasses import dataclass
8
 
9
  import numpy as np
10
+ from faster_whisper import WhisperModel
 
 
11
 
12
  from app.core.config import get_settings
13
 
 
17
 
18
  # Available Whisper models for Vietnamese
19
  AVAILABLE_MODELS = {
20
+ "PhoWhisper VI Finetuned": settings.default_whisper_model
 
 
21
  }
22
 
23
 
 
35
  Supports multiple models with caching.
36
  """
37
 
38
+ _models: Dict[str, WhisperModel] = {}
 
 
39
 
40
  @classmethod
41
+ def get_model(cls, model_name: str = None) -> WhisperModel:
42
+ """
43
+ Get or load a Whisper model (lazy loading with caching).
44
+
45
+ Args:
46
+ model_name: Name of the model from AVAILABLE_MODELS
47
+
48
+ Returns:
49
+ Loaded WhisperModel instance
50
+ """
51
+
52
+ if model_name is None:
53
+ model_name = settings.default_whisper_model
54
+
55
+ cache_key = f"{model_name}_{settings.resolved_compute_type}"
56
+
57
+ if cache_key in cls._models:
58
+ return cls._models[cache_key]
59
+
60
+ # Get model path
61
+ if model_name in AVAILABLE_MODELS:
62
+ model_path = AVAILABLE_MODELS[model_name]
63
+ else:
64
+ # Fallback to first available model
65
+ model_name = list(AVAILABLE_MODELS.keys())[0]
66
+ model_path = AVAILABLE_MODELS[model_name]
67
+
68
+ logger.info(f"Loading Whisper model: {model_name} ({model_path})")
69
+ logger.debug(f"Device: {settings.resolved_device}, Compute type: {settings.resolved_compute_type}")
70
+
71
+ model = WhisperModel(
72
+ model_path,
73
+ device=settings.resolved_device,
74
+ compute_type=settings.resolved_compute_type,
75
+ )
76
+
77
+ cls._models[cache_key] = model
78
+ logger.info(f"Whisper model loaded: {model_name}")
79
+
80
+ return model
81
 
82
  @classmethod
83
+ def is_loaded(cls, model_name: str = None) -> bool:
84
+ if model_name is None:
85
+ model_name = settings.default_whisper_model
86
+ """Check if a model is loaded."""
87
+ cache_key = f"{model_name}_{settings.resolved_compute_type}"
88
+ return cache_key in cls._models
89
 
90
  @classmethod
91
+ def preload_model(cls, model_name: str = None) -> None:
92
+ """Preload a model during startup."""
93
+ if model_name is None:
94
+ model_name = settings.default_whisper_model
95
+ try:
96
+ cls.get_model(model_name)
97
+ except Exception as e:
98
+ logger.error(f"Failed to preload Whisper model: {e}")
99
+ raise
100
 
101
  @classmethod
102
  def transcribe_with_words(
103
  cls,
104
  audio_array: np.ndarray,
105
+ model_name: str = None,
106
  language: str = "vi",
107
+ vad_options: Optional[dict] = None,
108
  beam_size: int = 5,
109
  temperature: float = 0.0,
110
+ best_of: int = 5,
111
+ initial_prompt: Optional[str] = None,
112
  ) -> Dict:
113
+ """
114
+ Transcribe audio and return word-level timestamps.
115
+ """
116
+ model = cls.get_model(model_name)
117
 
118
+ vad_filter = vad_options if vad_options else False
119
+ prompt = initial_prompt.strip() if initial_prompt and initial_prompt.strip() else None
120
 
121
+ segments_gen, info = model.transcribe(
122
  audio_array,
123
+ language=language if language != "auto" else None,
124
+ beam_size=beam_size,
125
+ temperature=temperature,
126
+ best_of=best_of,
127
+
128
+ # QA / Stability
129
+ condition_on_previous_text=False,
130
+ no_speech_threshold=0.6,
131
+
132
+ word_timestamps=True,
133
+
134
+ # VAD
135
+ vad_filter=vad_filter,
136
+ vad_parameters=dict(
137
+ threshold=settings.vad_threshold,
138
+ min_speech_duration_ms=settings.vad_min_speech_duration_ms,
139
+ min_silence_duration_ms=settings.vad_min_silence_duration_ms,
140
+ ),
141
+
142
+ initial_prompt=prompt,
143
  )
144
 
145
+ words = []
146
+ full_text = []
147
+
148
+ for seg in segments_gen:
149
+ if seg.text:
150
+ full_text.append(seg.text.strip())
 
 
151
 
152
+ if hasattr(seg, "words") and seg.words:
153
+ for w in seg.words:
154
+ if not w.word.strip():
155
+ continue
156
+ words.append({
157
+ "word": w.word.strip(),
158
+ "start": float(w.start),
159
+ "end": float(w.end),
160
+ })
161
 
162
  return {
163
+ "text": " ".join(full_text).strip(),
164
+ "words": words,
165
+ "info": info,
 
 
 
 
166
  }
167
 
168
 
 
170
  async def transcribe_with_words_async(
171
  cls,
172
  audio_array: np.ndarray,
173
+ model_name: str = None,
174
+ language: str = "vi",
175
+ vad_options: Optional[dict] = None,
176
+ beam_size: int = 5,
177
+ temperature: float = 0.0,
178
+ best_of: int = 5,
179
+ initial_prompt: Optional[str] = None,
180
+ ) -> str:
181
+ """
182
+ Async wrapper for transcription (runs in thread pool).
183
+ """
184
  import asyncio
185
+
186
  loop = asyncio.get_event_loop()
187
  return await loop.run_in_executor(
188
  None,
189
+ lambda: cls.transcribe_with_words(
190
+ audio_array,
191
+ model_name=model_name,
192
+ language=language,
193
+ vad_options=vad_options,
194
+ beam_size=beam_size,
195
+ temperature=temperature,
196
+ best_of=best_of,
197
+ initial_prompt=initial_prompt
198
+ )
199
  )
200
+
201
  @classmethod
202
  def get_available_models(cls) -> Dict[str, str]:
203
+ """Return list of available models."""
204
  return AVAILABLE_MODELS.copy()