Spaces:
Running
Running
colab-user
committed on
Commit
·
c90efd5
1
Parent(s):
e846326
optimize finetuned model
Browse files- app/services/processor.py +311 -125
- app/services/transcription.py +131 -70
app/services/processor.py
CHANGED
|
@@ -12,16 +12,22 @@ import librosa
|
|
| 12 |
import torch
|
| 13 |
|
| 14 |
from app.core.config import get_settings
|
| 15 |
-
from app.services.transcription import TranscriptionService
|
| 16 |
from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
|
| 17 |
-
from app.services.alignment import AlignmentService
|
| 18 |
-
from app.schemas.models import TranscriptSegment
|
| 19 |
-
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
settings = get_settings()
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
@dataclass
|
| 27 |
class ProcessingResult:
|
|
@@ -36,45 +42,76 @@ class ProcessingResult:
|
|
| 36 |
txt_content: str = ""
|
| 37 |
csv_content: str = ""
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def pad_and_refine_tensor(
|
| 42 |
-
waveform: torch.Tensor,
|
| 43 |
-
sr: int,
|
| 44 |
-
start_s: float,
|
| 45 |
-
end_s: float,
|
| 46 |
-
pad_ms: int = 200,
|
| 47 |
-
) -> Tuple[float, float]:
|
| 48 |
-
|
| 49 |
-
total_len = waveform.shape[1]
|
| 50 |
-
s = max(int((start_s - pad_ms / 1000) * sr), 0)
|
| 51 |
-
e = min(int((end_s + pad_ms / 1000) * sr), total_len)
|
| 52 |
-
|
| 53 |
-
if e <= s:
|
| 54 |
-
return start_s, end_s
|
| 55 |
-
|
| 56 |
-
return s / sr, e / sr
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def infer_roles_from_diarization(
|
| 60 |
diarization_segments: List[SpeakerSegment],
|
| 61 |
-
) ->
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
for
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
for spk in dur
|
| 76 |
-
}
|
| 77 |
|
|
|
|
| 78 |
|
| 79 |
def convert_audio_to_wav(audio_path: Path) -> Path:
|
| 80 |
"""Convert any audio to WAV 16kHz Mono using ffmpeg."""
|
|
@@ -89,14 +126,59 @@ def convert_audio_to_wav(audio_path: Path) -> Path:
|
|
| 89 |
except subprocess.CalledProcessError as e:
|
| 90 |
logger.error(f"FFmpeg conversion failed: {e}")
|
| 91 |
return audio_path
|
| 92 |
-
|
| 93 |
-
|
| 94 |
def format_timestamp(seconds: float) -> str:
|
| 95 |
m = int(seconds // 60)
|
| 96 |
s = seconds % 60
|
| 97 |
return f"{m:02d}:{s:06.3f}"
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# =========================
|
| 101 |
# Processor
|
| 102 |
# =========================
|
|
@@ -105,9 +187,20 @@ class Processor:
|
|
| 105 |
async def process_audio(
|
| 106 |
cls,
|
| 107 |
audio_path: Path,
|
|
|
|
| 108 |
language: str = "vi",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
beam_size: int = 5,
|
| 110 |
temperature: float = 0.0,
|
|
|
|
|
|
|
| 111 |
) -> ProcessingResult:
|
| 112 |
|
| 113 |
import asyncio
|
|
@@ -120,97 +213,172 @@ class Processor:
|
|
| 120 |
|
| 121 |
# 2: Load audio
|
| 122 |
y, sr = librosa.load(wav_path, sr=16000, mono=True)
|
| 123 |
-
|
|
|
|
|
|
|
| 124 |
duration = len(y) / sr
|
| 125 |
-
|
| 126 |
# 3: Diarization
|
|
|
|
|
|
|
| 127 |
diarization: DiarizationResult = await DiarizationService.diarize_async(wav_path)
|
| 128 |
-
diarization_segments = diarization.segments or [
|
| 129 |
-
SpeakerSegment(0.0, duration, "SPEAKER_0")
|
| 130 |
-
]
|
| 131 |
-
|
| 132 |
-
diarization_segments.sort(key=lambda s: s.start)
|
| 133 |
|
| 134 |
-
diarization_segments = [
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
speaker_map = {s: f"Speaker {i+1}" for i, s in enumerate(raw_speakers)}
|
| 145 |
-
|
| 146 |
-
diarization_segments = [
|
| 147 |
-
SpeakerSegment(
|
| 148 |
-
start=s.start,
|
| 149 |
-
end=s.end,
|
| 150 |
-
speaker=speaker_map[s.speaker]
|
| 151 |
-
)
|
| 152 |
-
for s in diarization_segments
|
| 153 |
-
]
|
| 154 |
-
# 5. Roles infer
|
| 155 |
-
roles = infer_roles_from_diarization(diarization_segments)
|
| 156 |
-
|
| 157 |
-
result = await TranscriptionService.transcribe_with_words_async(
|
| 158 |
-
audio_array=y,
|
| 159 |
-
language=language,
|
| 160 |
-
beam_size=beam_size,
|
| 161 |
-
temperature=temperature
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
words: List[WordTimestamp] = [
|
| 165 |
-
WordTimestamp(
|
| 166 |
-
word=w["word"],
|
| 167 |
-
start=float(w["start"]),
|
| 168 |
-
end=float(w["end"]),
|
| 169 |
-
)
|
| 170 |
-
for w in result.get("words", [])
|
| 171 |
-
if w.get("word")
|
| 172 |
-
]
|
| 173 |
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
-
for s in aligned_segments
|
| 189 |
-
]
|
| 190 |
-
|
| 191 |
-
segments = cls._filter_segments_with_context(segments)
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
max_gap_s=0.6
|
| 196 |
)
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
|
| 201 |
txt_content = cls._generate_txt(
|
| 202 |
-
|
| 203 |
-
len(
|
| 204 |
processing_time,
|
| 205 |
duration,
|
| 206 |
roles
|
| 207 |
)
|
| 208 |
|
| 209 |
-
csv_content = cls._generate_csv(
|
| 210 |
|
| 211 |
return ProcessingResult(
|
| 212 |
-
segments=
|
| 213 |
-
speaker_count=len(
|
| 214 |
duration=duration,
|
| 215 |
processing_time=processing_time,
|
| 216 |
speakers=speakers,
|
|
@@ -224,36 +392,44 @@ class Processor:
|
|
| 224 |
def _is_meaningful_segment(
|
| 225 |
seg: TranscriptSegment,
|
| 226 |
min_duration_s: float = 0.6,
|
| 227 |
-
|
|
|
|
| 228 |
) -> bool:
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
| 230 |
return True
|
| 231 |
-
if
|
| 232 |
return True
|
| 233 |
if seg.role == "KH":
|
| 234 |
return True
|
|
|
|
| 235 |
return False
|
| 236 |
|
| 237 |
@classmethod
|
| 238 |
def _filter_segments_with_context(
|
| 239 |
cls,
|
| 240 |
-
segments: List[TranscriptSegment]
|
| 241 |
) -> List[TranscriptSegment]:
|
| 242 |
-
|
| 243 |
if not segments:
|
| 244 |
return segments
|
| 245 |
-
|
| 246 |
segments = sorted(segments, key=lambda s: s.start)
|
| 247 |
result = []
|
|
|
|
| 248 |
|
| 249 |
for i, seg in enumerate(segments):
|
| 250 |
-
|
| 251 |
-
|
| 252 |
|
| 253 |
if cls._is_meaningful_segment(seg):
|
| 254 |
result.append(seg)
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
return result
|
| 259 |
|
|
@@ -261,20 +437,30 @@ class Processor:
|
|
| 261 |
@staticmethod
|
| 262 |
def _merge_adjacent_segments(
|
| 263 |
segments: List[TranscriptSegment],
|
| 264 |
-
max_gap_s: float = 0.5
|
| 265 |
) -> List[TranscriptSegment]:
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
if not segments:
|
| 268 |
return segments
|
| 269 |
|
|
|
|
| 270 |
merged = [segments[0]]
|
| 271 |
|
| 272 |
for seg in segments[1:]:
|
| 273 |
prev = merged[-1]
|
|
|
|
| 274 |
gap = seg.start - prev.end
|
| 275 |
|
| 276 |
-
if
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
prev.end = max(prev.end, seg.end)
|
| 279 |
else:
|
| 280 |
merged.append(seg)
|
|
|
|
| 12 |
import torch
|
| 13 |
|
| 14 |
from app.core.config import get_settings
|
| 15 |
+
from app.services.transcription import TranscriptionService
|
| 16 |
from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
settings = get_settings()
|
| 20 |
|
| 21 |
|
| 22 |
+
@dataclass
class TranscriptSegment:
    """One contiguous stretch of transcribed speech attributed to a speaker.

    Attributes:
        start: Start time of the segment (seconds).
        end: End time of the segment (seconds).
        speaker: Display label of the speaker, e.g. ``"Speaker 1"``.
        role: Role tag attached to the speaker (the pipeline uses ``"NV"``
            and ``"KH"``), or ``None`` when no role is known.
        text: Transcribed text for the segment.
    """

    start: float
    end: float
    speaker: str
    role: Optional[str]
    text: str
|
| 30 |
+
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
class ProcessingResult:
|
|
|
|
| 42 |
txt_content: str = ""
|
| 43 |
csv_content: str = ""
|
| 44 |
|
| 45 |
+
def assign_speaker_to_word(
    word_start: float,
    diarization_segments: List[SpeakerSegment],
) -> str:
    """Return the speaker label of the diarization segment that owns a word.

    A word belongs to the first segment whose ``[start, end]`` interval
    contains ``word_start``. If the word starts in a gap between segments,
    the segment whose nearest boundary is closest in time is used.

    Args:
        word_start: Start time of the word, in the same unit as the
            segments' ``start``/``end`` values.
        diarization_segments: Candidate speaker segments; must be non-empty.

    Returns:
        The ``speaker`` attribute of the selected segment.

    Raises:
        ValueError: If ``diarization_segments`` is empty.
    """
    if not diarization_segments:
        raise ValueError("diarization_segments must not be empty")

    for seg in diarization_segments:
        if seg.start <= word_start <= seg.end:
            return seg.speaker

    # The word lies outside every segment. Measure the distance to each
    # segment's nearest edge rather than to its midpoint: midpoint distance
    # grows with segment length, so a word spoken just after a long turn
    # would otherwise be handed to a shorter, farther-away segment.
    def _edge_distance(seg) -> float:
        return min(abs(seg.start - word_start), abs(seg.end - word_start))

    return min(diarization_segments, key=_edge_distance).speaker
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def group_words_into_segments(
    words: List[dict],
    diarization_segments: List[SpeakerSegment],
    speaker_map: Dict[str, str],
    roles: Dict[str, str],
    max_word_gap_s: float = 0.6,
) -> List[TranscriptSegment]:
    """Fold word-level timestamps into speaker-attributed segments.

    Words are consumed in the order given. A word extends the segment being
    built when it maps to the same speaker and starts no more than
    ``max_word_gap_s`` after that segment's current end; otherwise the
    segment is closed and a new one is opened.

    Args:
        words: Dicts with ``"word"``, ``"start"`` and ``"end"`` entries.
            Entries whose word is blank after stripping are ignored.
        diarization_segments: Speaker segments used to attribute each word.
        speaker_map: Raw diarization label -> display label. Labels missing
            from the map are used unchanged.
        roles: Display label -> role. Missing labels default to ``"KH"``.
        max_word_gap_s: Largest pause that still continues a segment.

    Returns:
        The grouped segments, in the order they were opened.
    """
    grouped: List[TranscriptSegment] = []
    open_seg: Optional[TranscriptSegment] = None

    for item in words:
        token = item.get("word", "").strip()
        if not token:
            continue

        begin = float(item["start"])
        finish = float(item["end"])

        raw_label = assign_speaker_to_word(begin, diarization_segments)
        label = speaker_map.get(raw_label, raw_label)

        continues_open_segment = (
            open_seg is not None
            and label == open_seg.speaker
            and begin - open_seg.end <= max_word_gap_s
        )
        if continues_open_segment:
            open_seg.text += " " + token
            open_seg.end = max(open_seg.end, finish)
            continue

        # Speaker changed or the pause was too long: close the segment in
        # progress (if any) and start a new one with this word.
        if open_seg is not None:
            grouped.append(open_seg)
        open_seg = TranscriptSegment(
            start=begin,
            end=finish,
            speaker=label,
            role=roles.get(label, "KH"),
            text=token,
        )

    if open_seg is not None:
        grouped.append(open_seg)

    return grouped
|
| 115 |
|
| 116 |
def convert_audio_to_wav(audio_path: Path) -> Path:
|
| 117 |
"""Convert any audio to WAV 16kHz Mono using ffmpeg."""
|
|
|
|
| 126 |
except subprocess.CalledProcessError as e:
|
| 127 |
logger.error(f"FFmpeg conversion failed: {e}")
|
| 128 |
return audio_path
|
| 129 |
+
|
| 130 |
+
|
| 131 |
def format_timestamp(seconds: float) -> str:
    """Format a duration in seconds as ``MM:SS.mmm``.

    The value is rounded to whole milliseconds before it is split into
    minutes and seconds. Splitting first and rounding only the seconds field
    would let values such as 59.9996 render as ``00:60.000`` instead of
    carrying into the minutes.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted timestamp, e.g. ``"01:15.500"`` for ``75.5``.
    """
    total_ms = int(round(float(seconds) * 1000))
    minutes, rem_ms = divmod(total_ms, 60_000)
    return f"{minutes:02d}:{rem_ms / 1000:06.3f}"
|
| 135 |
|
| 136 |
|
| 137 |
+
def pad_and_refine_tensor(
    waveform: torch.Tensor,
    sr: int,
    start_s: float,
    end_s: float,
    pad_ms: int = 200,
    silence_db_delta: float = 16,
    min_duration_ms: int = 150,
) -> Optional[Tuple[int, int]]:
    """Tighten a time window to the part of it that carries signal energy.

    The window ``[start_s, end_s]`` is widened by ``pad_ms`` on both sides
    and clamped to the waveform. Within it, the first and last samples whose
    absolute amplitude exceeds ``rms / silence_db_delta`` (``rms`` being the
    RMS of the padded window) become the refined boundaries.

    Args:
        waveform: 2-D tensor indexed as ``[channel, sample]``.
        sr: Sample rate used to convert seconds to sample indices.
        start_s: Window start in seconds.
        end_s: Window end in seconds.
        pad_ms: Padding added on each side before refinement.
        silence_db_delta: Divisor applied to the window RMS to obtain the
            activity threshold. NOTE(review): despite the name this is used
            as a linear ratio, not a decibel value.
        min_duration_ms: Refined spans shorter than this are rejected.

    Returns:
        ``(first_active_sample, last_active_sample)`` as absolute sample
        indices into ``waveform``, or ``None`` if the window is empty,
        contains no sample above the threshold, or the refined span is
        shorter than ``min_duration_ms``.
    """
    total_len = waveform.shape[1]

    start_idx = max(int((start_s - pad_ms / 1000) * sr), 0)
    end_idx = min(int((end_s + pad_ms / 1000) * sr), total_len)
    if end_idx <= start_idx:
        return None

    segment = waveform[:, start_idx:end_idx]
    if segment.numel() == 0:
        return None

    # RMS energy of the padded window; the epsilon keeps sqrt well-defined
    # for an all-zero window.
    rms = torch.sqrt(torch.mean(segment ** 2) + 1e-9)
    threshold = rms / silence_db_delta

    # Collapse the channel axis so the search runs along the sample axis.
    # Indexing torch.where(...)[0] on the 2-D tensor would return channel
    # (row) indices, which are all zero for mono audio and make every
    # refined span come out as zero samples long.
    per_sample_peak = torch.abs(segment).max(dim=0).values
    active = torch.nonzero(per_sample_peak > threshold).flatten()
    if active.numel() == 0:
        return None

    new_start = start_idx + int(active[0])
    new_end = start_idx + int(active[-1])

    if new_end - new_start < int(min_duration_ms / 1000 * sr):
        return None

    return new_start, new_end
|
| 180 |
+
|
| 181 |
+
|
| 182 |
# =========================
|
| 183 |
# Processor
|
| 184 |
# =========================
|
|
|
|
| 187 |
async def process_audio(
|
| 188 |
cls,
|
| 189 |
audio_path: Path,
|
| 190 |
+
model_name: str = "PhoWhisper VI Finetuned",
|
| 191 |
language: str = "vi",
|
| 192 |
+
|
| 193 |
+
# VAD options
|
| 194 |
+
vad_filter: bool = True,
|
| 195 |
+
vad_min_silence_ms: int = 1000,
|
| 196 |
+
vad_speech_pad_ms: int = 400,
|
| 197 |
+
vad_min_speech_ms: int = 250,
|
| 198 |
+
vad_threshold: float = 0.5,
|
| 199 |
+
# Generation options
|
| 200 |
beam_size: int = 5,
|
| 201 |
temperature: float = 0.0,
|
| 202 |
+
best_of: int = 5,
|
| 203 |
+
initial_prompt: Optional[str] = None,
|
| 204 |
) -> ProcessingResult:
|
| 205 |
|
| 206 |
import asyncio
|
|
|
|
| 213 |
|
| 214 |
# 2: Load audio
|
| 215 |
y, sr = librosa.load(wav_path, sr=16000, mono=True)
|
| 216 |
+
if y.size == 0:
|
| 217 |
+
raise ValueError("Empty audio")
|
| 218 |
+
waveform = torch.from_numpy(y).unsqueeze(0).float()
|
| 219 |
duration = len(y) / sr
|
| 220 |
+
|
| 221 |
# 3: Diarization
|
| 222 |
+
logger.info("Step 3: Running diarization...")
|
| 223 |
+
|
| 224 |
diarization: DiarizationResult = await DiarizationService.diarize_async(wav_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
diarization_segments = diarization.segments or []
|
| 227 |
+
speakers = diarization.speakers or []
|
| 228 |
+
roles = diarization.roles or {}
|
| 229 |
+
|
| 230 |
+
if not diarization_segments:
|
| 231 |
+
diarization_segments = [SpeakerSegment(0.0, duration, "SPEAKER_0")]
|
| 232 |
+
speakers = ["SPEAKER_0"]
|
| 233 |
+
roles = {"SPEAKER_0": "KH"}
|
| 234 |
+
|
| 235 |
+
diarization_segments.sort(key=lambda x: x.start)
|
| 236 |
+
|
| 237 |
+
# 4: Refine segment boundaries
|
| 238 |
+
refined_segments: List[SpeakerSegment] = []
|
| 239 |
+
|
| 240 |
+
for seg in diarization_segments:
|
| 241 |
+
refined = pad_and_refine_tensor(waveform, sr, seg.start, seg.end)
|
| 242 |
+
|
| 243 |
+
if refined:
|
| 244 |
+
s, e = refined
|
| 245 |
+
if e > s:
|
| 246 |
+
refined_segments.append(
|
| 247 |
+
SpeakerSegment(
|
| 248 |
+
start=s / sr,
|
| 249 |
+
end=e / sr,
|
| 250 |
+
speaker=seg.speaker,
|
| 251 |
+
)
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
refined_segments.append(seg)
|
| 255 |
+
else:
|
| 256 |
+
refined_segments.append(seg)
|
| 257 |
|
| 258 |
+
if not refined_segments:
|
| 259 |
+
refined_segments = diarization_segments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
|
| 262 |
+
# 5. Normalize speakers
|
| 263 |
+
raw_speakers = sorted({seg.speaker for seg in refined_segments})
|
| 264 |
|
| 265 |
+
speaker_map = {
|
| 266 |
+
spk: f"Speaker {i+1}"
|
| 267 |
+
for i, spk in enumerate(raw_speakers)
|
| 268 |
+
}
|
| 269 |
|
| 270 |
+
speakers = list(speaker_map.values())
|
| 271 |
+
|
| 272 |
+
# 6. NORMALIZE ROLES
|
| 273 |
+
speaker_duration = defaultdict(float)
|
| 274 |
+
for seg in refined_segments:
|
| 275 |
+
speaker_duration[seg.speaker] += seg.end - seg.start
|
| 276 |
+
|
| 277 |
+
logger.info(f"speaker_duration(raw) = {speaker_duration}")
|
| 278 |
+
|
| 279 |
+
if speaker_duration:
|
| 280 |
+
agent_raw = max(speaker_duration, key=speaker_duration.get)
|
| 281 |
+
|
| 282 |
+
roles = {
|
| 283 |
+
speaker_map[spk]: ("NV" if spk == agent_raw else "KH")
|
| 284 |
+
for spk in speaker_duration
|
| 285 |
+
}
|
| 286 |
+
else:
|
| 287 |
+
roles = {}
|
| 288 |
+
|
| 289 |
+
# Default fallback
|
| 290 |
+
for label in speakers:
|
| 291 |
+
roles.setdefault(label, "KH")
|
| 292 |
+
|
| 293 |
+
logger.info(f"roles(mapped) = {roles}")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# 7: Transcribe
|
| 298 |
+
vad_options = None
|
| 299 |
+
if vad_filter:
|
| 300 |
+
vad_options = {
|
| 301 |
+
"min_silence_duration_ms": vad_min_silence_ms,
|
| 302 |
+
"speech_pad_ms": vad_speech_pad_ms,
|
| 303 |
+
"min_speech_duration_ms": vad_min_speech_ms,
|
| 304 |
+
"threshold": vad_threshold
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
processed_segments: List[TranscriptSegment] = []
|
| 308 |
+
|
| 309 |
+
for seg in refined_segments:
|
| 310 |
+
start = int(seg.start * sr)
|
| 311 |
+
end = int(seg.end * sr)
|
| 312 |
+
|
| 313 |
+
if end <= start:
|
| 314 |
+
continue
|
| 315 |
+
|
| 316 |
+
audio_slice = y[start:end]
|
| 317 |
+
if audio_slice.size < sr * 0.25:
|
| 318 |
+
continue
|
| 319 |
+
|
| 320 |
+
try:
|
| 321 |
+
text = await TranscriptionService.transcribe_with_words_async(
|
| 322 |
+
audio_array=audio_slice,
|
| 323 |
+
model_name=model_name,
|
| 324 |
+
language=language,
|
| 325 |
+
vad_options=vad_options,
|
| 326 |
+
beam_size=beam_size,
|
| 327 |
+
temperature=temperature,
|
| 328 |
+
best_of=best_of,
|
| 329 |
+
initial_prompt=initial_prompt,
|
| 330 |
+
)
|
| 331 |
+
except Exception as e:
|
| 332 |
+
logger.error(f"Transcribe error: {e}")
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
if not text or not text.strip():
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
label = speaker_map.get(seg.speaker, seg.speaker)
|
| 339 |
+
|
| 340 |
+
processed_segments.append(
|
| 341 |
+
TranscriptSegment(
|
| 342 |
+
start=seg.start,
|
| 343 |
+
end=seg.end,
|
| 344 |
+
speaker=label,
|
| 345 |
+
role=roles[label],
|
| 346 |
+
text=text.strip(),
|
| 347 |
+
)
|
| 348 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
+
if not processed_segments:
|
| 351 |
+
processed_segments = [
|
| 352 |
+
TranscriptSegment(
|
| 353 |
+
start=0.0,
|
| 354 |
+
end=duration,
|
| 355 |
+
speaker=speakers[0],
|
| 356 |
+
role=roles[speakers[0]],
|
| 357 |
+
text="(No speech detected)"
|
| 358 |
+
)
|
| 359 |
+
]
|
| 360 |
+
|
| 361 |
+
processed_segments = cls._merge_adjacent_segments(
|
| 362 |
+
processed_segments,
|
| 363 |
max_gap_s=0.6
|
| 364 |
)
|
| 365 |
|
| 366 |
+
processed_segments = cls._filter_segments_with_context(processed_segments)
|
| 367 |
+
processing_time = time.time() - t0
|
| 368 |
|
| 369 |
txt_content = cls._generate_txt(
|
| 370 |
+
processed_segments,
|
| 371 |
+
len(speakers),
|
| 372 |
processing_time,
|
| 373 |
duration,
|
| 374 |
roles
|
| 375 |
)
|
| 376 |
|
| 377 |
+
csv_content = cls._generate_csv(processed_segments)
|
| 378 |
|
| 379 |
return ProcessingResult(
|
| 380 |
+
segments=processed_segments,
|
| 381 |
+
speaker_count=len(speakers),
|
| 382 |
duration=duration,
|
| 383 |
processing_time=processing_time,
|
| 384 |
speakers=speakers,
|
|
|
|
| 392 |
def _is_meaningful_segment(
    seg: TranscriptSegment,
    min_duration_s: float = 0.6,
    eps: float = 0.05,
    min_words: int = 3
) -> bool:
    """Decide whether a segment is worth keeping on its own merits.

    A segment qualifies when any one of these holds:
      * its duration plus ``eps`` reaches ``min_duration_s``;
      * its text has at least ``min_words`` whitespace-separated words;
      * its role is ``"KH"``.

    Args:
        seg: Segment to evaluate (reads ``start``, ``end``, ``text``, ``role``).
        min_duration_s: Duration threshold.
        eps: Tolerance added to the duration before comparing.
        min_words: Word-count threshold.

    Returns:
        ``True`` if the segment qualifies, ``False`` otherwise.
    """
    long_enough = (seg.end - seg.start) + eps >= min_duration_s
    wordy_enough = len(seg.text.split()) >= min_words
    return long_enough or wordy_enough or seg.role == "KH"
|
| 409 |
|
| 410 |
@classmethod
def _filter_segments_with_context(
    cls,
    segments: List[TranscriptSegment]
) -> List[TranscriptSegment]:
    """Drop weak segments unless their neighbours vouch for them.

    Segments are ordered by start time. A segment is kept when
    ``_is_meaningful_segment`` accepts it, or when it sits between two
    neighbours that both carry the same speaker label as it does. The first
    and last segments have only one neighbour, so they are kept only if
    they are meaningful on their own.

    Args:
        segments: Segments to filter; the input list is not reordered.

    Returns:
        A new list of kept segments sorted by start time, or the input
        itself when it is empty.
    """
    if not segments:
        return segments

    ordered = sorted(segments, key=lambda item: item.start)
    last_pos = len(ordered) - 1
    kept = []

    for pos, seg in enumerate(ordered):
        if cls._is_meaningful_segment(seg):
            kept.append(seg)
            continue

        # Edge segments lack one neighbour and cannot be rescued by context.
        if pos == 0 or pos == last_pos:
            continue

        if ordered[pos - 1].speaker == seg.speaker == ordered[pos + 1].speaker:
            kept.append(seg)

    return kept
|
| 435 |
|
|
|
|
| 437 |
@staticmethod
|
| 438 |
def _merge_adjacent_segments(
|
| 439 |
segments: List[TranscriptSegment],
|
| 440 |
+
max_gap_s: float = 0.5
|
| 441 |
) -> List[TranscriptSegment]:
|
| 442 |
+
"""
|
| 443 |
+
Merge adjacent segments if:
|
| 444 |
+
- same speaker
|
| 445 |
+
- gap <= max_gap_s
|
| 446 |
+
"""
|
| 447 |
if not segments:
|
| 448 |
return segments
|
| 449 |
|
| 450 |
+
segments = sorted(segments, key=lambda s: s.start)
|
| 451 |
merged = [segments[0]]
|
| 452 |
|
| 453 |
for seg in segments[1:]:
|
| 454 |
prev = merged[-1]
|
| 455 |
+
|
| 456 |
gap = seg.start - prev.end
|
| 457 |
|
| 458 |
+
if (
|
| 459 |
+
seg.speaker == prev.speaker
|
| 460 |
+
and gap <= max_gap_s
|
| 461 |
+
):
|
| 462 |
+
# MERGE
|
| 463 |
+
prev.text = f"{prev.text} {seg.text}".strip()
|
| 464 |
prev.end = max(prev.end, seg.end)
|
| 465 |
else:
|
| 466 |
merged.append(seg)
|
app/services/transcription.py
CHANGED
|
@@ -3,14 +3,11 @@ Transcription service using faster-whisper.
|
|
| 3 |
Supports multiple Vietnamese Whisper models with caching.
|
| 4 |
"""
|
| 5 |
import logging
|
| 6 |
-
import torch
|
| 7 |
from typing import Dict, Optional, List
|
| 8 |
from dataclasses import dataclass
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
-
from
|
| 12 |
-
from peft import PeftModel
|
| 13 |
-
|
| 14 |
|
| 15 |
from app.core.config import get_settings
|
| 16 |
|
|
@@ -20,9 +17,7 @@ settings = get_settings()
|
|
| 20 |
|
| 21 |
# Available Whisper models for Vietnamese
|
| 22 |
AVAILABLE_MODELS = {
|
| 23 |
-
|
| 24 |
-
"Whisper-LoRA": settings.whisper_lora_model_dir
|
| 25 |
-
|
| 26 |
}
|
| 27 |
|
| 28 |
|
|
@@ -40,88 +35,134 @@ class TranscriptionService:
|
|
| 40 |
Supports multiple models with caching.
|
| 41 |
"""
|
| 42 |
|
| 43 |
-
|
| 44 |
-
_processor = None
|
| 45 |
-
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
|
| 47 |
@classmethod
|
| 48 |
-
def get_model(cls):
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
@classmethod
|
| 72 |
-
def is_loaded(cls) -> bool:
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
@classmethod
|
| 76 |
-
def preload_model(cls) -> None:
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
@classmethod
|
| 80 |
def transcribe_with_words(
|
| 81 |
cls,
|
| 82 |
audio_array: np.ndarray,
|
|
|
|
| 83 |
language: str = "vi",
|
|
|
|
| 84 |
beam_size: int = 5,
|
| 85 |
temperature: float = 0.0,
|
|
|
|
|
|
|
| 86 |
) -> Dict:
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
if
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
audio_array,
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
max_new_tokens=settings.whisper_max_new_tokens,
|
| 110 |
-
)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
return {
|
| 118 |
-
"text":
|
| 119 |
-
"words":
|
| 120 |
-
"info":
|
| 121 |
-
"engine": "transformers-whisper-lora",
|
| 122 |
-
"language": language,
|
| 123 |
-
"beam_size": beam_size,
|
| 124 |
-
},
|
| 125 |
}
|
| 126 |
|
| 127 |
|
|
@@ -129,15 +170,35 @@ class TranscriptionService:
|
|
| 129 |
async def transcribe_with_words_async(
|
| 130 |
cls,
|
| 131 |
audio_array: np.ndarray,
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
import asyncio
|
|
|
|
| 135 |
loop = asyncio.get_event_loop()
|
| 136 |
return await loop.run_in_executor(
|
| 137 |
None,
|
| 138 |
-
lambda: cls.transcribe_with_words(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
)
|
| 140 |
-
|
| 141 |
@classmethod
|
| 142 |
def get_available_models(cls) -> Dict[str, str]:
|
|
|
|
| 143 |
return AVAILABLE_MODELS.copy()
|
|
|
|
| 3 |
Supports multiple Vietnamese Whisper models with caching.
|
| 4 |
"""
|
| 5 |
import logging
|
|
|
|
| 6 |
from typing import Dict, Optional, List
|
| 7 |
from dataclasses import dataclass
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
+
from faster_whisper import WhisperModel
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from app.core.config import get_settings
|
| 13 |
|
|
|
|
| 17 |
|
| 18 |
# Available Whisper models for Vietnamese
|
| 19 |
AVAILABLE_MODELS = {
|
| 20 |
+
"PhoWhisper VI Finetuned": settings.default_whisper_model
|
|
|
|
|
|
|
| 21 |
}
|
| 22 |
|
| 23 |
|
|
|
|
| 35 |
Supports multiple models with caching.
|
| 36 |
"""
|
| 37 |
|
| 38 |
+
_models: Dict[str, WhisperModel] = {}
|
|
|
|
|
|
|
| 39 |
|
| 40 |
@classmethod
def get_model(cls, model_name: Optional[str] = None) -> WhisperModel:
    """Get or load a Whisper model (lazy loading with caching).

    Names that are not keys of ``AVAILABLE_MODELS`` fall back to the first
    entry of that mapping. The loaded instance is cached under both the
    requested name and the resolved name, so a fallback request and a direct
    request for the same model share one instance instead of loading the
    weights twice.

    Args:
        model_name: Name of the model from ``AVAILABLE_MODELS``. ``None``
            means ``settings.default_whisper_model``.

    Returns:
        Loaded ``WhisperModel`` instance.

    Raises:
        IndexError: If ``AVAILABLE_MODELS`` is empty and a fallback is needed.
    """
    if model_name is None:
        model_name = settings.default_whisper_model

    compute_type = settings.resolved_compute_type

    requested_key = f"{model_name}_{compute_type}"
    if requested_key in cls._models:
        return cls._models[requested_key]

    # Resolve the name BEFORE deciding whether a load is needed; the
    # resolved key is what identifies the underlying weights.
    if model_name not in AVAILABLE_MODELS:
        fallback_name = list(AVAILABLE_MODELS)[0]
        logger.warning(
            f"Unknown Whisper model '{model_name}', falling back to '{fallback_name}'"
        )
        model_name = fallback_name
    model_path = AVAILABLE_MODELS[model_name]

    resolved_key = f"{model_name}_{compute_type}"
    model = cls._models.get(resolved_key)

    if model is None:
        logger.info(f"Loading Whisper model: {model_name} ({model_path})")
        logger.debug(f"Device: {settings.resolved_device}, Compute type: {settings.resolved_compute_type}")

        model = WhisperModel(
            model_path,
            device=settings.resolved_device,
            compute_type=compute_type,
        )
        cls._models[resolved_key] = model
        logger.info(f"Whisper model loaded: {model_name}")

    # Alias the requested name too, so later lookups by that name
    # (here and in is_loaded) hit the cache.
    cls._models[requested_key] = model

    return model
|
| 81 |
|
| 82 |
@classmethod
def is_loaded(cls, model_name: Optional[str] = None) -> bool:
    """Check if a model is loaded.

    Args:
        model_name: Model name to look up. ``None`` means
            ``settings.default_whisper_model``.

    Returns:
        ``True`` if the model cache holds an entry for this name and the
        current ``settings.resolved_compute_type``.
    """
    if model_name is None:
        model_name = settings.default_whisper_model
    cache_key = f"{model_name}_{settings.resolved_compute_type}"
    return cache_key in cls._models
|
| 89 |
|
| 90 |
@classmethod
def preload_model(cls, model_name: str = None) -> None:
    """Load a model ahead of first use, e.g. during startup.

    Args:
        model_name: Model to load. ``None`` means
            ``settings.default_whisper_model``.

    Raises:
        Exception: Whatever ``get_model`` raises; the failure is logged and
            then re-raised unchanged.
    """
    target = settings.default_whisper_model if model_name is None else model_name
    try:
        cls.get_model(target)
    except Exception as e:
        logger.error(f"Failed to preload Whisper model: {e}")
        raise
|
| 100 |
|
| 101 |
@classmethod
def transcribe_with_words(
    cls,
    audio_array: np.ndarray,
    model_name: Optional[str] = None,
    language: str = "vi",
    vad_options: Optional[dict] = None,
    beam_size: int = 5,
    temperature: float = 0.0,
    best_of: int = 5,
    initial_prompt: Optional[str] = None,
) -> Dict:
    """Transcribe audio and return word-level timestamps.

    Args:
        audio_array: Audio samples passed straight to the model.
        model_name: Model to use; resolved by ``get_model``.
        language: Language code, or ``"auto"`` to let the model detect it.
        vad_options: VAD parameter overrides. A non-empty dict enables VAD
            filtering and its entries override the defaults taken from
            settings; ``None`` or an empty dict disables VAD filtering.
            NOTE(review): keys are forwarded to the model as VAD
            parameters, so they must be names it accepts.
        beam_size: Beam size forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        best_of: Number of candidates forwarded to the model.
        initial_prompt: Optional prompt; blank strings are treated as absent.

    Returns:
        Dict with ``"text"`` (segment texts joined by single spaces),
        ``"words"`` (list of ``{"word", "start", "end"}`` dicts) and
        ``"info"`` (the info object returned by the model).
    """
    model = cls.get_model(model_name)

    # vad_filter is an on/off switch; the option values belong in
    # vad_parameters. Start from the configured defaults and let the
    # caller's values win, instead of discarding them.
    use_vad = bool(vad_options)
    vad_parameters = {
        "threshold": settings.vad_threshold,
        "min_speech_duration_ms": settings.vad_min_speech_duration_ms,
        "min_silence_duration_ms": settings.vad_min_silence_duration_ms,
    }
    if vad_options:
        vad_parameters.update(vad_options)

    prompt = (initial_prompt or "").strip() or None

    segments_gen, info = model.transcribe(
        audio_array,
        language=language if language != "auto" else None,
        beam_size=beam_size,
        temperature=temperature,
        best_of=best_of,

        # QA / Stability
        condition_on_previous_text=False,
        no_speech_threshold=0.6,

        word_timestamps=True,

        # VAD
        vad_filter=use_vad,
        vad_parameters=vad_parameters,

        initial_prompt=prompt,
    )

    words = []
    full_text = []

    for seg in segments_gen:
        if seg.text:
            full_text.append(seg.text.strip())

        for w in getattr(seg, "words", None) or ():
            token = w.word.strip()
            if not token:
                continue
            words.append({
                "word": token,
                "start": float(w.start),
                "end": float(w.end),
            })

    return {
        "text": " ".join(full_text).strip(),
        "words": words,
        "info": info,
    }
|
| 167 |
|
| 168 |
|
|
|
|
| 170 |
async def transcribe_with_words_async(
    cls,
    audio_array: np.ndarray,
    model_name: Optional[str] = None,
    language: str = "vi",
    vad_options: Optional[dict] = None,
    beam_size: int = 5,
    temperature: float = 0.0,
    best_of: int = 5,
    initial_prompt: Optional[str] = None,
) -> Dict:
    """Async wrapper for transcription (runs in the default executor).

    All arguments are forwarded unchanged to ``transcribe_with_words``.

    Returns:
        The dict produced by ``transcribe_with_words``. This is NOT a plain
        string: callers that only need the transcript must read the
        ``"text"`` entry of the result.
    """
    import asyncio

    # Inside a coroutine a loop is always running; get_running_loop() is
    # the documented accessor for it.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        lambda: cls.transcribe_with_words(
            audio_array,
            model_name=model_name,
            language=language,
            vad_options=vad_options,
            beam_size=beam_size,
            temperature=temperature,
            best_of=best_of,
            initial_prompt=initial_prompt
        )
    )
|
| 200 |
+
|
| 201 |
@classmethod
def get_available_models(cls) -> Dict[str, str]:
    """Return a shallow copy of the ``AVAILABLE_MODELS`` name-to-location mapping."""
    return dict(AVAILABLE_MODELS)
|