colab-user committed on
Commit
fe21ffa
·
1 Parent(s): 8ce75a0

fix model transcription

Browse files
app/api/routes.py CHANGED
@@ -37,7 +37,7 @@ async def get_models():
37
  """Get available Whisper models."""
38
  return {
39
  "models": list(AVAILABLE_MODELS.keys()),
40
- "default": settings.default_whisper_model
41
  }
42
 
43
 
 
37
  """Get available Whisper models."""
38
  return {
39
  "models": list(AVAILABLE_MODELS.keys()),
40
+ "default": settings.whisper_lora_model_dir
41
  }
42
 
43
 
app/core/config.py CHANGED
@@ -30,12 +30,7 @@ class Settings(BaseSettings):
30
  enable_vocal_separation: bool = True
31
  mdx_model: str = "Kim_Vocal_2.onnx" # High quality vocal isolation
32
 
33
- # Available Whisper models
34
- available_whisper_models: Dict[str, str] = {
35
- "EraX-WoW-Turbo": "erax-ai/EraX-WoW-Turbo-V1.1-CT2",
36
- "PhoWhisper Large": "kiendt/PhoWhisper-large-ct2"
37
- }
38
- default_whisper_model: str = "PhoWhisper Large"
39
 
40
  # Diarization model
41
  diarization_model: str = "pyannote/speaker-diarization-community-1"
 
30
  enable_vocal_separation: bool = True
31
  mdx_model: str = "Kim_Vocal_2.onnx" # High quality vocal isolation
32
 
33
+ whisper_lora_model_dir: str = "vyluong/pho-whisper-vi-lora-v5"
 
 
 
 
 
34
 
35
  # Diarization model
36
  diarization_model: str = "pyannote/speaker-diarization-community-1"
app/main.py CHANGED
@@ -35,7 +35,7 @@ async def lifespan(app: FastAPI):
35
  """
36
  logger.info("Starting PrecisionVoice application...")
37
  logger.info(f"Device: {settings.resolved_device}")
38
- logger.info(f"Default Whisper model: {settings.default_whisper_model}")
39
  logger.info(f"Diarization model: {settings.diarization_model}")
40
 
41
  # Preload default Whisper model
 
35
  """
36
  logger.info("Starting PrecisionVoice application...")
37
  logger.info(f"Device: {settings.resolved_device}")
38
+ logger.info(f"Default Whisper model: {settings.whisper_lora_model_dir}")
39
  logger.info(f"Diarization model: {settings.diarization_model}")
40
 
41
  # Preload default Whisper model
app/services/transcription.py CHANGED
@@ -3,11 +3,14 @@ Transcription service using faster-whisper.
3
  Supports multiple Vietnamese Whisper models with caching.
4
  """
5
  import logging
 
6
  from typing import Dict, Optional, List
7
  from dataclasses import dataclass
8
 
9
  import numpy as np
10
- from faster_whisper import WhisperModel
 
 
11
 
12
  from app.core.config import get_settings
13
 
@@ -17,8 +20,9 @@ settings = get_settings()
17
 
18
  # Available Whisper models for Vietnamese
19
  AVAILABLE_MODELS = {
20
- "EraX-WoW-Turbo": "erax-ai/EraX-WoW-Turbo-V1.1-CT2",
21
- "PhoWhisper Large": "kiendt/PhoWhisper-large-ct2"
 
22
  }
23
 
24
 
@@ -36,138 +40,88 @@ class TranscriptionService:
36
  Supports multiple models with caching.
37
  """
38
 
39
- _models: Dict[str, WhisperModel] = {}
 
 
40
 
41
  @classmethod
42
- def get_model(cls, model_name: str = None) -> WhisperModel:
43
- """
44
- Get or load a Whisper model (lazy loading with caching).
45
-
46
- Args:
47
- model_name: Name of the model from AVAILABLE_MODELS
48
-
49
- Returns:
50
- Loaded WhisperModel instance
51
- """
52
-
53
- if model_name is None:
54
- model_name = settings.default_whisper_model
55
-
56
- cache_key = f"{model_name}_{settings.resolved_compute_type}"
57
-
58
- if cache_key in cls._models:
59
- return cls._models[cache_key]
60
-
61
- # Get model path
62
- if model_name in AVAILABLE_MODELS:
63
- model_path = AVAILABLE_MODELS[model_name]
64
- else:
65
- # Fallback to first available model
66
- model_name = list(AVAILABLE_MODELS.keys())[0]
67
- model_path = AVAILABLE_MODELS[model_name]
68
-
69
- logger.info(f"Loading Whisper model: {model_name} ({model_path})")
70
- logger.debug(f"Device: {settings.resolved_device}, Compute type: {settings.resolved_compute_type}")
71
-
72
- model = WhisperModel(
73
- model_path,
74
- device=settings.resolved_device,
75
- compute_type=settings.resolved_compute_type,
76
- )
77
-
78
- cls._models[cache_key] = model
79
- logger.info(f"Whisper model loaded: {model_name}")
80
-
81
- return model
82
 
83
  @classmethod
84
- def is_loaded(cls, model_name: str = None) -> bool:
85
- if model_name is None:
86
- model_name = settings.default_whisper_model
87
- """Check if a model is loaded."""
88
- cache_key = f"{model_name}_{settings.resolved_compute_type}"
89
- return cache_key in cls._models
90
 
91
  @classmethod
92
- def preload_model(cls, model_name: str = None) -> None:
93
- """Preload a model during startup."""
94
- if model_name is None:
95
- model_name = settings.default_whisper_model
96
- try:
97
- cls.get_model(model_name)
98
- except Exception as e:
99
- logger.error(f"Failed to preload Whisper model: {e}")
100
- raise
101
 
102
  @classmethod
103
  def transcribe_with_words(
104
  cls,
105
  audio_array: np.ndarray,
106
- model_name: str = None,
107
  language: str = "vi",
108
- vad_options: Optional[dict] = None,
109
  beam_size: int = 5,
110
- temperature: float = 0.2,
111
- best_of: int = 5,
112
- initial_prompt: Optional[str] = None,
113
  ) -> Dict:
114
- """
115
- Transcribe audio and return word-level timestamps.
116
- """
117
- model = cls.get_model(model_name)
118
 
119
- vad_filter = vad_options if vad_options else False
120
- prompt = initial_prompt.strip() if initial_prompt and initial_prompt.strip() else None
121
 
122
- segments_gen, info = model.transcribe(
123
  audio_array,
124
- language=language if language != "auto" else None,
125
- beam_size=beam_size,
126
- temperature=temperature,
127
- best_of=best_of,
128
-
129
- # QA / Stability
130
- condition_on_previous_text=False,
131
- no_speech_threshold=0.6,
132
-
133
- # hallucination
134
- compression_ratio_threshold=2.4,
135
- log_prob_threshold=-1.0,
136
-
137
- word_timestamps=True,
138
-
139
- # VAD
140
- vad_filter=vad_filter,
141
- vad_parameters=dict(
142
- threshold=settings.vad_threshold,
143
- min_speech_duration_ms=settings.vad_min_speech_duration_ms,
144
- min_silence_duration_ms=settings.vad_min_silence_duration_ms,
145
- ),
146
-
147
- initial_prompt=prompt,
148
- )
149
 
150
- words = []
151
- full_text = []
 
 
152
 
153
- for seg in segments_gen:
154
- if seg.text:
155
- full_text.append(seg.text.strip())
 
 
 
 
 
156
 
157
- if hasattr(seg, "words") and seg.words:
158
- for w in seg.words:
159
- if not w.word.strip():
160
- continue
161
- words.append({
162
- "word": w.word.strip(),
163
- "start": float(w.start),
164
- "end": float(w.end),
165
- })
166
 
167
  return {
168
- "text": " ".join(full_text).strip(),
169
- "words": words,
170
- "info": info,
 
 
 
 
171
  }
172
 
173
 
@@ -175,35 +129,15 @@ class TranscriptionService:
175
  async def transcribe_with_words_async(
176
  cls,
177
  audio_array: np.ndarray,
178
- model_name: str = None,
179
- language: str = "vi",
180
- vad_options: Optional[dict] = None,
181
- beam_size: int = 5,
182
- temperature: float = 0.0,
183
- best_of: int = 5,
184
- initial_prompt: Optional[str] = None,
185
- ) -> str:
186
- """
187
- Async wrapper for transcription (runs in thread pool).
188
- """
189
  import asyncio
190
-
191
  loop = asyncio.get_event_loop()
192
  return await loop.run_in_executor(
193
  None,
194
- lambda: cls.transcribe_with_words(
195
- audio_array,
196
- model_name=model_name,
197
- language=language,
198
- vad_options=vad_options,
199
- beam_size=beam_size,
200
- temperature=temperature,
201
- best_of=best_of,
202
- initial_prompt=initial_prompt
203
- )
204
  )
205
-
206
  @classmethod
207
  def get_available_models(cls) -> Dict[str, str]:
208
- """Return list of available models."""
209
  return AVAILABLE_MODELS.copy()
 
3
  Supports multiple Vietnamese Whisper models with caching.
4
  """
5
  import logging
6
+ import torch
7
  from typing import Dict, Optional, List
8
  from dataclasses import dataclass
9
 
10
  import numpy as np
11
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
12
+ from peft import PeftModel
13
+
14
 
15
  from app.core.config import get_settings
16
 
 
20
 
21
# Available Whisper models for Vietnamese.
# After the LoRA migration this holds a single entry: a display name mapped
# to the Hugging Face model directory configured in Settings
# (settings.whisper_lora_model_dir).
AVAILABLE_MODELS = {
    "Whisper-LoRA": settings.whisper_lora_model_dir
}
27
 
28
 
 
40
  Supports multiple models with caching.
41
  """
42
 
43
    # Process-wide singletons: the model and processor are loaded once by
    # get_model() and reused across all requests.
    _model = None
    _processor = None
    # Inference device; prefer GPU when available.
    _device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
    @classmethod
    def get_model(cls):
        """Lazily load and cache the Whisper + LoRA model and its processor.

        Returns:
            Tuple (model, processor): the PEFT-wrapped
            WhisperForConditionalGeneration in eval mode on cls._device, and
            the matching WhisperProcessor. Subsequent calls return the cached
            pair without reloading.
        """
        # Fast path: reuse the cached pair loaded earlier in this process.
        if cls._model is not None:
            return cls._model, cls._processor

        model_dir = AVAILABLE_MODELS["Whisper-LoRA"]

        logger.info(f"Loading Whisper + LoRA from {model_dir}")
        logger.info(f"Device: {cls._device}")

        # NOTE(review): the same directory is used for both the base weights
        # and the LoRA adapter; this only works if the repo ships full base
        # weights alongside the adapter files — confirm against the model repo.
        base_model = WhisperForConditionalGeneration.from_pretrained(model_dir)
        model = PeftModel.from_pretrained(base_model, model_dir)

        model.to(cls._device)
        model.eval()  # inference only: disables dropout

        processor = WhisperProcessor.from_pretrained(model_dir)

        cls._model = model
        cls._processor = processor

        logger.info("Whisper + LoRA loaded successfully")
        return model, processor
 
71
  @classmethod
72
+ def is_loaded(cls) -> bool:
73
+ return cls._model is not None
 
 
 
 
74
 
75
  @classmethod
76
+ def preload_model(cls) -> None:
77
+ cls.get_model()
 
 
 
 
 
 
 
78
 
79
  @classmethod
80
  def transcribe_with_words(
81
  cls,
82
  audio_array: np.ndarray,
 
83
  language: str = "vi",
 
84
  beam_size: int = 5,
85
+ temperature: float = 0.0,
 
 
86
  ) -> Dict:
87
+ model, processor = cls.get_model()
 
 
 
88
 
89
+ if audio_array.ndim > 1:
90
+ audio_array = np.mean(audio_array, axis=0)
91
 
92
+ inputs = processor(
93
  audio_array,
94
+ sampling_rate=16000,
95
+ return_tensors="pt"
96
+ ).input_features.to(cls._device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
99
+ language=language,
100
+ task="transcribe"
101
+ )
102
 
103
+ with torch.no_grad():
104
+ generated_ids = model.generate(
105
+ inputs,
106
+ forced_decoder_ids=forced_decoder_ids,
107
+ num_beams=beam_size,
108
+ temperature=temperature,
109
+ max_new_tokens=settings.whisper_max_new_tokens,
110
+ )
111
 
112
+ text = processor.batch_decode(
113
+ generated_ids,
114
+ skip_special_tokens=True
115
+ )[0].strip()
 
 
 
 
 
116
 
117
  return {
118
+ "text": text,
119
+ "words": [],
120
+ "info": {
121
+ "engine": "transformers-whisper-lora",
122
+ "language": language,
123
+ "beam_size": beam_size,
124
+ },
125
  }
126
 
127
 
 
129
  async def transcribe_with_words_async(
130
  cls,
131
  audio_array: np.ndarray,
132
+ **kwargs
133
+ ) -> Dict:
 
 
 
 
 
 
 
 
 
134
  import asyncio
 
135
  loop = asyncio.get_event_loop()
136
  return await loop.run_in_executor(
137
  None,
138
+ lambda: cls.transcribe_with_words(audio_array, **kwargs)
 
 
 
 
 
 
 
 
 
139
  )
140
+
141
  @classmethod
142
  def get_available_models(cls) -> Dict[str, str]:
 
143
  return AVAILABLE_MODELS.copy()