Spaces:
Sleeping
Sleeping
colab-user commited on
Commit ·
fe21ffa
1
Parent(s): 8ce75a0
fix model transcription
Browse files- app/api/routes.py +1 -1
- app/core/config.py +1 -6
- app/main.py +1 -1
- app/services/transcription.py +71 -137
app/api/routes.py
CHANGED
|
@@ -37,7 +37,7 @@ async def get_models():
|
|
| 37 |
"""Get available Whisper models."""
|
| 38 |
return {
|
| 39 |
"models": list(AVAILABLE_MODELS.keys()),
|
| 40 |
-
"default": settings.
|
| 41 |
}
|
| 42 |
|
| 43 |
|
|
|
|
| 37 |
"""Get available Whisper models."""
|
| 38 |
return {
|
| 39 |
"models": list(AVAILABLE_MODELS.keys()),
|
| 40 |
+
"default": settings.whisper_lora_model_dir
|
| 41 |
}
|
| 42 |
|
| 43 |
|
app/core/config.py
CHANGED
|
@@ -30,12 +30,7 @@ class Settings(BaseSettings):
|
|
| 30 |
enable_vocal_separation: bool = True
|
| 31 |
mdx_model: str = "Kim_Vocal_2.onnx" # High quality vocal isolation
|
| 32 |
|
| 33 |
-
|
| 34 |
-
available_whisper_models: Dict[str, str] = {
|
| 35 |
-
"EraX-WoW-Turbo": "erax-ai/EraX-WoW-Turbo-V1.1-CT2",
|
| 36 |
-
"PhoWhisper Large": "kiendt/PhoWhisper-large-ct2"
|
| 37 |
-
}
|
| 38 |
-
default_whisper_model: str = "PhoWhisper Large"
|
| 39 |
|
| 40 |
# Diarization model
|
| 41 |
diarization_model: str = "pyannote/speaker-diarization-community-1"
|
|
|
|
| 30 |
enable_vocal_separation: bool = True
|
| 31 |
mdx_model: str = "Kim_Vocal_2.onnx" # High quality vocal isolation
|
| 32 |
|
| 33 |
+
whisper_lora_model_dir: str = "vyluong/pho-whisper-vi-lora-v5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
# Diarization model
|
| 36 |
diarization_model: str = "pyannote/speaker-diarization-community-1"
|
app/main.py
CHANGED
|
@@ -35,7 +35,7 @@ async def lifespan(app: FastAPI):
|
|
| 35 |
"""
|
| 36 |
logger.info("Starting PrecisionVoice application...")
|
| 37 |
logger.info(f"Device: {settings.resolved_device}")
|
| 38 |
-
logger.info(f"Default Whisper model: {settings.
|
| 39 |
logger.info(f"Diarization model: {settings.diarization_model}")
|
| 40 |
|
| 41 |
# Preload default Whisper model
|
|
|
|
| 35 |
"""
|
| 36 |
logger.info("Starting PrecisionVoice application...")
|
| 37 |
logger.info(f"Device: {settings.resolved_device}")
|
| 38 |
+
logger.info(f"Default Whisper model: {settings.whisper_lora_model_dir}")
|
| 39 |
logger.info(f"Diarization model: {settings.diarization_model}")
|
| 40 |
|
| 41 |
# Preload default Whisper model
|
app/services/transcription.py
CHANGED
|
@@ -3,11 +3,14 @@ Transcription service using faster-whisper.
|
|
| 3 |
Supports multiple Vietnamese Whisper models with caching.
|
| 4 |
"""
|
| 5 |
import logging
|
|
|
|
| 6 |
from typing import Dict, Optional, List
|
| 7 |
from dataclasses import dataclass
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
-
from
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from app.core.config import get_settings
|
| 13 |
|
|
@@ -17,8 +20,9 @@ settings = get_settings()
|
|
| 17 |
|
| 18 |
# Available Whisper models for Vietnamese
|
| 19 |
AVAILABLE_MODELS = {
|
| 20 |
-
|
| 21 |
-
"
|
|
|
|
| 22 |
}
|
| 23 |
|
| 24 |
|
|
@@ -36,138 +40,88 @@ class TranscriptionService:
|
|
| 36 |
Supports multiple models with caching.
|
| 37 |
"""
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
@classmethod
|
| 42 |
-
def get_model(cls
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
else:
|
| 65 |
-
# Fallback to first available model
|
| 66 |
-
model_name = list(AVAILABLE_MODELS.keys())[0]
|
| 67 |
-
model_path = AVAILABLE_MODELS[model_name]
|
| 68 |
-
|
| 69 |
-
logger.info(f"Loading Whisper model: {model_name} ({model_path})")
|
| 70 |
-
logger.debug(f"Device: {settings.resolved_device}, Compute type: {settings.resolved_compute_type}")
|
| 71 |
-
|
| 72 |
-
model = WhisperModel(
|
| 73 |
-
model_path,
|
| 74 |
-
device=settings.resolved_device,
|
| 75 |
-
compute_type=settings.resolved_compute_type,
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
cls._models[cache_key] = model
|
| 79 |
-
logger.info(f"Whisper model loaded: {model_name}")
|
| 80 |
-
|
| 81 |
-
return model
|
| 82 |
|
| 83 |
@classmethod
|
| 84 |
-
def is_loaded(cls
|
| 85 |
-
|
| 86 |
-
model_name = settings.default_whisper_model
|
| 87 |
-
"""Check if a model is loaded."""
|
| 88 |
-
cache_key = f"{model_name}_{settings.resolved_compute_type}"
|
| 89 |
-
return cache_key in cls._models
|
| 90 |
|
| 91 |
@classmethod
|
| 92 |
-
def preload_model(cls
|
| 93 |
-
|
| 94 |
-
if model_name is None:
|
| 95 |
-
model_name = settings.default_whisper_model
|
| 96 |
-
try:
|
| 97 |
-
cls.get_model(model_name)
|
| 98 |
-
except Exception as e:
|
| 99 |
-
logger.error(f"Failed to preload Whisper model: {e}")
|
| 100 |
-
raise
|
| 101 |
|
| 102 |
@classmethod
|
| 103 |
def transcribe_with_words(
|
| 104 |
cls,
|
| 105 |
audio_array: np.ndarray,
|
| 106 |
-
model_name: str = None,
|
| 107 |
language: str = "vi",
|
| 108 |
-
vad_options: Optional[dict] = None,
|
| 109 |
beam_size: int = 5,
|
| 110 |
-
temperature: float = 0.
|
| 111 |
-
best_of: int = 5,
|
| 112 |
-
initial_prompt: Optional[str] = None,
|
| 113 |
) -> Dict:
|
| 114 |
-
|
| 115 |
-
Transcribe audio and return word-level timestamps.
|
| 116 |
-
"""
|
| 117 |
-
model = cls.get_model(model_name)
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
|
| 122 |
-
|
| 123 |
audio_array,
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
best_of=best_of,
|
| 128 |
-
|
| 129 |
-
# QA / Stability
|
| 130 |
-
condition_on_previous_text=False,
|
| 131 |
-
no_speech_threshold=0.6,
|
| 132 |
-
|
| 133 |
-
# hallucination
|
| 134 |
-
compression_ratio_threshold=2.4,
|
| 135 |
-
log_prob_threshold=-1.0,
|
| 136 |
-
|
| 137 |
-
word_timestamps=True,
|
| 138 |
-
|
| 139 |
-
# VAD
|
| 140 |
-
vad_filter=vad_filter,
|
| 141 |
-
vad_parameters=dict(
|
| 142 |
-
threshold=settings.vad_threshold,
|
| 143 |
-
min_speech_duration_ms=settings.vad_min_speech_duration_ms,
|
| 144 |
-
min_silence_duration_ms=settings.vad_min_silence_duration_ms,
|
| 145 |
-
),
|
| 146 |
-
|
| 147 |
-
initial_prompt=prompt,
|
| 148 |
-
)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
words.append({
|
| 162 |
-
"word": w.word.strip(),
|
| 163 |
-
"start": float(w.start),
|
| 164 |
-
"end": float(w.end),
|
| 165 |
-
})
|
| 166 |
|
| 167 |
return {
|
| 168 |
-
"text":
|
| 169 |
-
"words":
|
| 170 |
-
"info":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
}
|
| 172 |
|
| 173 |
|
|
@@ -175,35 +129,15 @@ class TranscriptionService:
|
|
| 175 |
async def transcribe_with_words_async(
|
| 176 |
cls,
|
| 177 |
audio_array: np.ndarray,
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
vad_options: Optional[dict] = None,
|
| 181 |
-
beam_size: int = 5,
|
| 182 |
-
temperature: float = 0.0,
|
| 183 |
-
best_of: int = 5,
|
| 184 |
-
initial_prompt: Optional[str] = None,
|
| 185 |
-
) -> str:
|
| 186 |
-
"""
|
| 187 |
-
Async wrapper for transcription (runs in thread pool).
|
| 188 |
-
"""
|
| 189 |
import asyncio
|
| 190 |
-
|
| 191 |
loop = asyncio.get_event_loop()
|
| 192 |
return await loop.run_in_executor(
|
| 193 |
None,
|
| 194 |
-
lambda: cls.transcribe_with_words(
|
| 195 |
-
audio_array,
|
| 196 |
-
model_name=model_name,
|
| 197 |
-
language=language,
|
| 198 |
-
vad_options=vad_options,
|
| 199 |
-
beam_size=beam_size,
|
| 200 |
-
temperature=temperature,
|
| 201 |
-
best_of=best_of,
|
| 202 |
-
initial_prompt=initial_prompt
|
| 203 |
-
)
|
| 204 |
)
|
| 205 |
-
|
| 206 |
@classmethod
|
| 207 |
def get_available_models(cls) -> Dict[str, str]:
|
| 208 |
-
"""Return list of available models."""
|
| 209 |
return AVAILABLE_MODELS.copy()
|
|
|
|
| 3 |
Supports multiple Vietnamese Whisper models with caching.
|
| 4 |
"""
|
| 5 |
import logging
|
| 6 |
+
import torch
|
| 7 |
from typing import Dict, Optional, List
|
| 8 |
from dataclasses import dataclass
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
| 12 |
+
from peft import PeftModel
|
| 13 |
+
|
| 14 |
|
| 15 |
from app.core.config import get_settings
|
| 16 |
|
|
|
|
| 20 |
|
| 21 |
# Available Whisper models for Vietnamese
|
| 22 |
AVAILABLE_MODELS = {
|
| 23 |
+
|
| 24 |
+
"Whisper-LoRA": settings.whisper_lora_model_dir
|
| 25 |
+
|
| 26 |
}
|
| 27 |
|
| 28 |
|
|
|
|
| 40 |
Supports multiple models with caching.
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
_model = None
|
| 44 |
+
_processor = None
|
| 45 |
+
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
|
| 47 |
@classmethod
|
| 48 |
+
def get_model(cls):
|
| 49 |
+
if cls._model is not None:
|
| 50 |
+
return cls._model, cls._processor
|
| 51 |
+
|
| 52 |
+
model_dir = AVAILABLE_MODELS["Whisper-LoRA"]
|
| 53 |
+
|
| 54 |
+
logger.info(f"Loading Whisper + LoRA from {model_dir}")
|
| 55 |
+
logger.info(f"Device: {cls._device}")
|
| 56 |
+
|
| 57 |
+
base_model = WhisperForConditionalGeneration.from_pretrained(model_dir)
|
| 58 |
+
model = PeftModel.from_pretrained(base_model, model_dir)
|
| 59 |
+
|
| 60 |
+
model.to(cls._device)
|
| 61 |
+
model.eval()
|
| 62 |
+
|
| 63 |
+
processor = WhisperProcessor.from_pretrained(model_dir)
|
| 64 |
+
|
| 65 |
+
cls._model = model
|
| 66 |
+
cls._processor = processor
|
| 67 |
+
|
| 68 |
+
logger.info("Whisper + LoRA loaded successfully")
|
| 69 |
+
return model, processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
@classmethod
|
| 72 |
+
def is_loaded(cls) -> bool:
|
| 73 |
+
return cls._model is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
@classmethod
|
| 76 |
+
def preload_model(cls) -> None:
|
| 77 |
+
cls.get_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
@classmethod
|
| 80 |
def transcribe_with_words(
|
| 81 |
cls,
|
| 82 |
audio_array: np.ndarray,
|
|
|
|
| 83 |
language: str = "vi",
|
|
|
|
| 84 |
beam_size: int = 5,
|
| 85 |
+
temperature: float = 0.0,
|
|
|
|
|
|
|
| 86 |
) -> Dict:
|
| 87 |
+
model, processor = cls.get_model()
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
if audio_array.ndim > 1:
|
| 90 |
+
audio_array = np.mean(audio_array, axis=0)
|
| 91 |
|
| 92 |
+
inputs = processor(
|
| 93 |
audio_array,
|
| 94 |
+
sampling_rate=16000,
|
| 95 |
+
return_tensors="pt"
|
| 96 |
+
).input_features.to(cls._device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(
|
| 99 |
+
language=language,
|
| 100 |
+
task="transcribe"
|
| 101 |
+
)
|
| 102 |
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
generated_ids = model.generate(
|
| 105 |
+
inputs,
|
| 106 |
+
forced_decoder_ids=forced_decoder_ids,
|
| 107 |
+
num_beams=beam_size,
|
| 108 |
+
temperature=temperature,
|
| 109 |
+
max_new_tokens=settings.whisper_max_new_tokens,
|
| 110 |
+
)
|
| 111 |
|
| 112 |
+
text = processor.batch_decode(
|
| 113 |
+
generated_ids,
|
| 114 |
+
skip_special_tokens=True
|
| 115 |
+
)[0].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
return {
|
| 118 |
+
"text": text,
|
| 119 |
+
"words": [],
|
| 120 |
+
"info": {
|
| 121 |
+
"engine": "transformers-whisper-lora",
|
| 122 |
+
"language": language,
|
| 123 |
+
"beam_size": beam_size,
|
| 124 |
+
},
|
| 125 |
}
|
| 126 |
|
| 127 |
|
|
|
|
| 129 |
async def transcribe_with_words_async(
|
| 130 |
cls,
|
| 131 |
audio_array: np.ndarray,
|
| 132 |
+
**kwargs
|
| 133 |
+
) -> Dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
import asyncio
|
|
|
|
| 135 |
loop = asyncio.get_event_loop()
|
| 136 |
return await loop.run_in_executor(
|
| 137 |
None,
|
| 138 |
+
lambda: cls.transcribe_with_words(audio_array, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
)
|
| 140 |
+
|
| 141 |
@classmethod
|
| 142 |
def get_available_models(cls) -> Dict[str, str]:
|
|
|
|
| 143 |
return AVAILABLE_MODELS.copy()
|