vyluong commited on
Commit
f14c6b9
·
verified ·
1 Parent(s): b34d849

Update app/services/processor.py

Browse files
Files changed (1) hide show
  1. app/services/processor.py +57 -188
app/services/processor.py CHANGED
@@ -72,32 +72,17 @@ class ProcessingResult:
72
  emotion_timeline: List[EmotionPoint] = None
73
  emotion_changes: List[EmotionChange] = None
74
 
75
-
76
- def pad_and_refine_tensor(
77
- waveform: torch.Tensor,
78
- sr: int,
79
- start_s: float,
80
- end_s: float,
81
- pad_ms: int = 250,
82
- ) -> Tuple[float, float]:
83
-
84
- total_len = waveform.shape[1]
85
- s = max(int((start_s - pad_ms / 1000) * sr), 0)
86
- e = min(int((end_s + pad_ms / 1000) * sr), total_len)
87
-
88
- if e <= s:
89
- return start_s, end_s
90
-
91
- return s / sr, e / sr
92
-
93
-
94
  def normalize_asr_result(result: dict):
95
 
96
  words = []
97
 
98
  for w in result.get("words", []):
99
 
100
- word = w.get("word", "").strip()
 
 
 
 
101
  if not word:
102
  continue
103
 
@@ -107,12 +92,15 @@ def normalize_asr_result(result: dict):
107
  "start": float(w["start"]),
108
  "end": float(w["end"]),
109
  "speaker": w.get("speaker"),
 
 
 
110
  }
111
  )
112
 
113
  text = result.get("text", "").strip()
114
- return text, words
115
 
 
116
 
117
  def guess_speaker_by_overlap(start, end, diar_segments):
118
 
@@ -163,65 +151,44 @@ def format_timestamp(seconds: float) -> str:
163
  return f"{m:02d}:{s:06.3f}"
164
 
165
 
166
- def extract_mfcc_segment(
167
- audio: np.ndarray,
168
- sr: int,
169
- start: float,
170
- end: float,
171
- duration=5,
172
- ):
173
-
174
- start_sample = int(start * sr)
175
- end_sample = int(end * sr)
176
-
177
- segment = audio[start_sample:end_sample]
178
-
179
- if len(segment) == 0:
180
- return None
181
-
182
- target_len = int(sr * duration)
183
-
184
- if len(segment) < target_len:
185
- segment = np.pad(segment, (0, target_len - len(segment)), mode="symmetric")
186
- else:
187
- segment = segment[:target_len]
188
-
189
- mfcc = librosa.feature.mfcc(
190
- y=segment, sr=sr, n_mfcc=128, n_fft=2048, hop_length=512
191
- )
192
-
193
- return mfcc
194
-
195
-
196
  def merge_consecutive_segments(
197
  segments: List[SpeakerSegment],
198
- max_gap: float = 0.8,
199
- min_duration: float = 0.15,
200
  ) -> List[SpeakerSegment]:
201
- """Merge consecutive segments from same speaker."""
202
  if not segments:
203
  return []
204
 
205
- merged = []
206
- current = SpeakerSegment(
207
- start=segments[0].start, end=segments[0].end, speaker=segments[0].speaker
208
  )
209
 
 
 
210
  for seg in segments[1:]:
211
- seg_dur = seg.end - seg.start
 
 
 
 
 
 
212
  if (
213
- seg.speaker == current.speaker
214
- and (seg.start - current.end) <= max_gap
215
- or seg_dur < min_duration
216
  ):
217
- # Merge: extend current segment
218
- current.end = seg.end
 
 
 
 
219
  else:
220
- # New speaker or gap too large
221
- merged.append(current)
222
- current = SpeakerSegment(start=seg.start, end=seg.end, speaker=seg.speaker)
223
 
224
- merged.append(current)
225
  return merged
226
 
227
 
@@ -242,8 +209,7 @@ class Processor:
242
  audio_path: Path,
243
  model_name: str = "PhoWhisper Lora Finetuned",
244
  language="vi",
245
- merge_segments: bool = True,
246
- backend: str = "whisper",
247
  ) -> ProcessingResult:
248
 
249
  import asyncio
@@ -282,8 +248,10 @@ class Processor:
282
 
283
  diarization_segments = [
284
  SpeakerSegment(
285
- *pad_and_refine_tensor(waveform, sr, s.start, s.end),
 
286
  speaker=s.speaker,
 
287
  )
288
  for s in diarization_segments
289
  ]
@@ -301,74 +269,29 @@ class Processor:
301
 
302
  speakers = list(speaker_map.values())
303
 
304
- # 5. NORMALIZE ROLES
305
- speakers = diarization.speakers or []
306
- roles = diarization.roles or {}
307
 
308
- # Default fallback
309
- for label in speakers:
310
- roles.setdefault(label, "KH")
311
 
312
  logger.info(f"roles(mapped) = {roles}")
313
 
314
  # 7: Transcribe segments after diarization
315
  logger.info("Step 7: Running ASR with external VAD batch...")
316
 
317
- # asr_result = await TranscriptionService.transcribe_with_words_async(
318
- # audio_array=y,
319
- # model_name=model_name,
320
- # language=language,
321
- # vad_options=True
322
- # )
323
-
324
- # text, raw_words = normalize_asr_result(asr_result)
325
-
326
- logger.info(f"ASR backend = {backend}")
327
-
328
- if backend == "whisper":
329
- asr_result = await TranscriptionService.transcribe_with_words_async(
330
- audio_array=y,
331
- model_name=model_name,
332
- language=language,
333
- vad_options=True,
334
- )
335
- text, raw_words = normalize_asr_result(asr_result)
336
-
337
- elif backend == "w2v":
338
-
339
- logger.info("Running W2V per diarization segment...")
340
-
341
- raw_words = []
342
-
343
- for seg in diarization_segments:
344
-
345
- start_sample = int(seg.start * sr)
346
- end_sample = int(seg.end * sr)
347
-
348
- chunk = y[start_sample:end_sample]
349
-
350
- if len(chunk) == 0:
351
- continue
352
-
353
- text = TranscriptionService.infer_w2v(chunk)
354
-
355
- if not text:
356
- continue
357
-
358
- raw_words.append(
359
- {
360
- "word": text,
361
- "start": seg.start,
362
- "end": seg.end,
363
- "speaker": seg.speaker,
364
- }
365
- )
366
-
367
- else:
368
- raise ValueError(f"Unsupported backend: {backend}")
369
 
370
- processed_segments: List[TranscriptSegment] = []
371
 
 
372
  if not raw_words:
373
  processed_segments = [
374
  TranscriptSegment(
@@ -400,33 +323,17 @@ class Processor:
400
  start=w["start"],
401
  end=w["end"],
402
  speaker=spk,
 
403
  )
404
  )
405
 
406
  word_objs.sort(key=lambda x: x.start)
407
 
408
  # ===== ALIGNMENT =====
409
- # aligned_segments = AlignmentService.align_precision(
410
- # word_objs,
411
- # diarization_segments
412
- # )
413
- if backend == "w2v":
414
- aligned_segments = [
415
- TranscriptSegment(
416
- start=w["start"],
417
- end=w["end"],
418
- speaker=speaker_map.get(w["speaker"], "Speaker 1"),
419
- role=roles.get(
420
- speaker_map.get(w["speaker"], "Speaker 1"), "KH"
421
- ),
422
- text=w["word"],
423
- )
424
- for w in raw_words
425
- ]
426
- else:
427
- aligned_segments = AlignmentService.align_precision(
428
- word_objs, diarization_segments
429
- )
430
 
431
  processed_segments = []
432
 
@@ -463,7 +370,6 @@ class Processor:
463
  )
464
  )
465
 
466
- processed_segments = cls._merge_adjacent_segments(processed_segments)
467
  processed_segments.sort(key=lambda x: x.start)
468
 
469
  # 8 : Predict emotion segments
@@ -497,43 +403,6 @@ class Processor:
497
  emotion_changes=emotion_changes,
498
  )
499
 
500
- @staticmethod
501
- def _merge_adjacent_segments(
502
- segments: List[TranscriptSegment],
503
- max_gap_s: float = 0.8,
504
- max_segment_duration: float = 9.0,
505
- ) -> List[TranscriptSegment]:
506
- """
507
- Merge adjacent segments if:
508
- - same speaker
509
- - gap <= max_gap_s
510
- """
511
- if not segments:
512
- return segments
513
-
514
- segments = sorted(segments, key=lambda s: s.start)
515
- merged = [segments[0]]
516
-
517
- for seg in segments[1:]:
518
- prev = merged[-1]
519
-
520
- gap = seg.start - prev.end
521
- combined_duration = seg.end - prev.start
522
-
523
- if (
524
- seg.speaker == prev.speaker
525
- and seg.role == prev.role
526
- and gap <= max_gap_s
527
- and combined_duration <= max_segment_duration
528
- and not overlap_prefix(seg.text, prev.text)
529
- ):
530
- # MERGE
531
- prev.text = f"{prev.text} {seg.text}".strip()
532
- prev.end = max(prev.end, seg.end)
533
- else:
534
- merged.append(seg)
535
-
536
- return merged
537
 
538
  @staticmethod
539
  def _predict_emotion_segments(
 
72
  emotion_timeline: List[EmotionPoint] = None
73
  emotion_changes: List[EmotionChange] = None
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def normalize_asr_result(result: dict):
76
 
77
  words = []
78
 
79
  for w in result.get("words", []):
80
 
81
+ word = (
82
+ w.get("word", "")
83
+ .strip()
84
+ )
85
+
86
  if not word:
87
  continue
88
 
 
92
  "start": float(w["start"]),
93
  "end": float(w["end"]),
94
  "speaker": w.get("speaker"),
95
+ "confidence": float(
96
+ w.get("confidence", 1.0)
97
+ ),
98
  }
99
  )
100
 
101
  text = result.get("text", "").strip()
 
102
 
103
+ return text, words
104
 
105
  def guess_speaker_by_overlap(start, end, diar_segments):
106
 
 
151
  return f"{m:02d}:{s:06.3f}"
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def merge_consecutive_segments(
155
  segments: List[SpeakerSegment],
156
+ max_gap: float = 0.80,
157
+ max_overlap: float = 0.15,
158
  ) -> List[SpeakerSegment]:
159
+
160
  if not segments:
161
  return []
162
 
163
+ segments = sorted(
164
+ segments,
165
+ key=lambda x: x.start
166
  )
167
 
168
+ merged = [segments[0]]
169
+
170
  for seg in segments[1:]:
171
+
172
+ prev = merged[-1]
173
+
174
+ gap = seg.start - prev.end
175
+
176
+ overlap = prev.end - seg.start
177
+
178
  if (
179
+ seg.speaker == prev.speaker
180
+ and gap <= max_gap
181
+ and overlap <= max_overlap
182
  ):
183
+
184
+ prev.end = max(
185
+ prev.end,
186
+ seg.end
187
+ )
188
+
189
  else:
190
+ merged.append(seg)
 
 
191
 
 
192
  return merged
193
 
194
 
 
209
  audio_path: Path,
210
  model_name: str = "PhoWhisper Lora Finetuned",
211
  language="vi",
212
+ merge_segments: bool = True
 
213
  ) -> ProcessingResult:
214
 
215
  import asyncio
 
248
 
249
  diarization_segments = [
250
  SpeakerSegment(
251
+ start=new_start,
252
+ end=new_end,
253
  speaker=s.speaker,
254
+ confidence=getattr(s, "confidence", 1.0),
255
  )
256
  for s in diarization_segments
257
  ]
 
269
 
270
  speakers = list(speaker_map.values())
271
 
272
+ raw_roles = diarization.roles or {}
273
+
274
+ roles = {}
275
 
276
+ for raw_spk, label in speaker_map.items():
277
+ roles[label] = raw_roles.get(raw_spk, "KH")
278
+
279
 
280
  logger.info(f"roles(mapped) = {roles}")
281
 
282
  # 7: Transcribe segments after diarization
283
  logger.info("Step 7: Running ASR with external VAD batch...")
284
 
285
+ asr_result = await TranscriptionService.transcribe_with_words_async(
286
+ audio_array=y,
287
+ model_name=model_name,
288
+ language=language,
289
+ vad_options=False
290
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
+ text, raw_words = normalize_asr_result(asr_result)
293
 
294
+
295
  if not raw_words:
296
  processed_segments = [
297
  TranscriptSegment(
 
323
  start=w["start"],
324
  end=w["end"],
325
  speaker=spk,
326
+ confidence=w.get("confidence", 1.0)
327
  )
328
  )
329
 
330
  word_objs.sort(key=lambda x: x.start)
331
 
332
  # ===== ALIGNMENT =====
333
+ aligned_segments = AlignmentService.align_precision(
334
+ word_objs,
335
+ diarization_segments
336
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  processed_segments = []
339
 
 
370
  )
371
  )
372
 
 
373
  processed_segments.sort(key=lambda x: x.start)
374
 
375
  # 8 : Predict emotion segments
 
403
  emotion_changes=emotion_changes,
404
  )
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  @staticmethod
408
  def _predict_emotion_segments(