colab-user committed
Commit bad4ead · 1 Parent(s): 1a0c8e0

fix ASR after diarization

app/services/processor.py CHANGED
@@ -56,6 +56,7 @@ def normalize_asr_result(result: dict):
                 "word": w["word"].strip(),
                 "start": float(w["start"]),
                 "end": float(w["end"]),
+                "speaker": w.get("speaker")
             })
 
     text = result.get("text", "").strip()
@@ -111,21 +112,6 @@ def overlap_prefix(a: str, b: str, n: int = 12) -> bool:
     return a[:n] in b or b[:n] in a
 
 
-MAX_SEGMENT_LEN = 15.0  # seconds
-MERGE_GAP = 0.4
-
-
-def split_long_segment(start, end, max_len=MAX_SEGMENT_LEN):
-    segments = []
-    t = start
-    while t < end:
-        segments.append((t, min(end, t + max_len)))
-        t += max_len
-    return segments
-
-# =========================
-# Processor
-# =========================
 class Processor:
     @classmethod
     async def process_audio(
@@ -206,12 +192,13 @@ class Processor:
 
         logger.info(f"roles(mapped) = {roles}")
 
-        # 7: Transcribe using batch external vad
+        # 7: Transcribe segments after diarization
        logger.info("Step 7: Running ASR with external VAD batch...")
 
-        asr_result = await TranscriptionService.transcribe_batch_with_external_vad(
+        asr_result = await TranscriptionService.transcribe_after_diarization(
            audio_array=y,
            sr=sr,
+            diarization_segments=diarization_segments,
            model_name=model_name,
            vad_options={
                "threshold": vad_threshold,
@@ -247,6 +234,7 @@ class Processor:
                             word=w["word"],
                             start=w["start"],
                             end=w["end"],
+                            speaker=w.get("speaker")
                         )
                     )
             except:
@@ -267,10 +255,9 @@ class Processor:
             else:
                 mid = duration / 2
 
-            label = guess_speaker_by_time(mid, mid, diarization_segments)
-
-            label = speaker_map.get(label, label)
-
+            raw_spk = guess_speaker_by_time(mid, mid, diarization_segments)
+            label = speaker_map.get(raw_spk, "Speaker 1")
+
             processed_segments.append(
                 TranscriptSegment(
                     start=0,
@@ -283,7 +270,13 @@ class Processor:
         else:
             for seg in aligned_segments:
 
-                label = speaker_map.get(seg.speaker, seg.speaker)
+                raw_spk = seg.speaker
+
+                if not raw_spk and hasattr(seg, "words"):
+                    raw_spk = seg.words[0].speaker if seg.words else None
+
+                label = speaker_map.get(raw_spk, "Speaker 1")
+
                 role = roles.get(label, "KH")
 
                 processed_segments.append(
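
For context, the new label-resolution path above boils down to the following standalone logic (a minimal sketch: speaker_map and the "Speaker 1" default come from the diff, while the Word/Segment dataclasses are hypothetical stand-ins for the real aligned-segment objects):

from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical stand-ins for the aligned-segment objects in processor.py.
@dataclass
class Word:
    word: str
    start: float
    end: float
    speaker: Optional[str] = None

@dataclass
class Segment:
    speaker: Optional[str] = None
    words: List[Word] = field(default_factory=list)

def resolve_label(seg: Segment, speaker_map: dict) -> str:
    """Mirror the commit's fallback chain: segment speaker -> first word's
    speaker -> mapped display name, defaulting to "Speaker 1"."""
    raw_spk = seg.speaker
    if not raw_spk and hasattr(seg, "words"):
        raw_spk = seg.words[0].speaker if seg.words else None
    return speaker_map.get(raw_spk, "Speaker 1")

# Example: the segment carries no speaker, but its first word does.
seg = Segment(words=[Word("xin", 0.0, 0.3, speaker="SPEAKER_00")])
print(resolve_label(seg, {"SPEAKER_00": "Speaker 1", "SPEAKER_01": "Speaker 2"}))
# -> Speaker 1

The word-level fallback matters because the ASR result now attaches a speaker to each word (see the normalize_asr_result and WordTimestamp changes), so a segment without its own label can still inherit one.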
app/services/transcription.py CHANGED
@@ -30,6 +30,7 @@ class WordTimestamp:
     word: str
     start: float
     end: float
+    speaker: Optional[str] = None
 
 class TranscriptionService:
     """
@@ -109,52 +110,52 @@ class TranscriptionService:
         if not segments:
             return []
 
-        normalized = []
-        buf_start, buf_end = segments[0]
+        merged = []
+        cur_s, cur_e = segments[0]
 
-        for start, end in segments[1:]:
+        for s, e in segments[1:]:
 
-            gap = start - buf_end
-            new_dur = end - buf_start
+            gap = s - cur_e
+            new_dur = e - cur_s
 
             if (
-                new_dur < settings.MIN_SEGMENT_SEC
+                new_dur < settings.MAX_SEGMENT_SEC
                 and gap < settings.MERGE_GAP_SEC
             ):
-                buf_end = end
-                continue
-
-            normalized.append((buf_start, buf_end))
-            buf_start, buf_end = start, end
-
-        normalized.append((buf_start, buf_end))
+                cur_e = e
+            else:
+                merged.append((cur_s, cur_e))
+                cur_s, cur_e = s, e
 
-        # -------- Split long segment --------
+        merged.append((cur_s, cur_e))
 
-        final_segments = []
+        # split long segments
+        final = []
+        for s, e in merged:
 
-        for start, end in normalized:
-
-            dur = end - start
+            dur = e - s
 
             if dur <= settings.MAX_SEGMENT_SEC:
-                final_segments.append((start, end))
+                final.append((s, e))
                 continue
 
-            cur = start
-            while cur < end:
-                split_end = min(cur + settings.TARGET_SEGMENT_SEC, end)
-                final_segments.append((cur, split_end))
-                cur = split_end
+            cur = s
+            while cur < e:
+                nxt = min(cur + settings.TARGET_SEGMENT_SEC, e)
+                final.append((cur, nxt))
+                cur = nxt
 
-        return final_segments
+        return final
 
 
     @staticmethod
     def _slice_audio(audio: np.ndarray, sr: int, start: float, end: float):
 
-        s = int(start * sr)
-        e = int(end * sr)
+        s = max(0, int(start * sr))
+        e = min(len(audio), int(end * sr))
+
+        if e <= s:
+            return np.zeros(1, dtype=np.float32)
 
         return audio[s:e]
 
@@ -167,9 +168,9 @@ class TranscriptionService:
         vad_options: Optional[dict] = None,
         beam_size: int = 6,
         temperature: float = 0.0,
-        best_of: int = 1,
+        best_of: int = 5,
         patience: float = 1.2,
-        length_penalty: float = 1.0,
+        length_penalty: float = 0.95,
         no_repeat_ngram_size=3,
 
         # Prompting
@@ -300,7 +301,7 @@ class TranscriptionService:
             patience=patience,
             length_penalty=length_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
-            initial_prompt=None,
+            initial_prompt=initial_prompt,
             prefix_text=prefix_text,
             condition_on_previous_text=condition_on_previous_text,
             no_speech_threshold=no_speech_threshold,
@@ -390,91 +391,105 @@ class TranscriptionService:
         return await loop.run_in_executor(None, _run_vad)
 
     @classmethod
-    async def transcribe_batch_with_external_vad(
+    async def transcribe_after_diarization(
         cls,
         audio_array: np.ndarray,
-        sr: int = 16000,
+        sr: int,
+        diarization_segments: List,
         model_name: Optional[str] = None,
         vad_options: Optional[dict] = None,
-    ) -> Dict:
-
-        vad_segments = await cls.get_vad_segments_async(
-            audio_array,
-            sr,
-            vad_options
-        )
-
-        vad_segments = cls._normalize_segments(vad_segments)
+    ):
 
         all_words = []
-        texts = []
-
+        segments_text = []
+
         total_audio_dur = len(audio_array) / sr
 
-        for start, end in vad_segments:
+        for diar in diarization_segments:
 
-            chunk = cls._slice_audio(audio_array, sr, start, end)
+            spk = diar.speaker
+            seg_start = diar.start
+            seg_end = diar.end
 
-            if len(chunk) < sr * 0.5:
-                continue
-
-            energy = float(np.sqrt(np.mean(chunk ** 2)))
+            speaker_audio = cls._slice_audio(
+                audio_array, sr, seg_start, seg_end
+            )
 
-            if energy < 0.008:
+            if len(speaker_audio) < sr * 0.5:
                 continue
-
-            is_tail = (total_audio_dur - end) < 1.5
-
-            beam_size = 6
-            best_of = 5
-            patience = 1.2
-            length_penalty = 1.1
-
-            if is_tail:
-                beam_size = 4
-                best_of = 3
-                patience = 1.0
-
-            res = await cls.transcribe_with_words_async(
-                audio_array=chunk,
-                model_name=model_name,
-                language="vi",
-                beam_size=beam_size,
-                best_of=best_of,
-                temperature=0,
-                patience=patience,
-                length_penalty=length_penalty,
-
-                # reset prompt at tail
-                initial_prompt=None if is_tail else "Cuộc gọi thu hồi nợ.",
 
-                condition_on_previous_text=False,
+            vad_segments = await cls.get_vad_segments_async(
+                speaker_audio, sr, vad_options
+            )
 
-                no_speech_threshold=0.4 if not is_tail else 0.5,
-                log_prob_threshold=-2.0,
-                compression_ratio_threshold=2.2,
+            vad_segments = cls._normalize_segments(vad_segments)
 
-            )
-
-            text = res["text"].strip()
-            if len(text.split()) > 6:
-                uniq_ratio = len(set(text.split())) / len(text.split())
+            for v_start, v_end in vad_segments:
+
+                g_start = seg_start + v_start
+                g_end = seg_start + v_end
+
+                chunk = cls._slice_audio(audio_array, sr, g_start, g_end)
 
-                if uniq_ratio < 0.45:
-                    logger.debug("Drop repetitive hallucination segment")
+                if len(chunk) < sr * 0.5:
                     continue
 
-            if is_tail and len(text) > 80:
-                logger.debug("Drop suspicious tail text")
-                continue
-            for w in res["words"]:
-                w["start"] += start
-                w["end"] += start
+                # energy filter
+                energy = float(np.sqrt(np.mean(chunk ** 2)))
+                if not np.isfinite(energy) or energy < 0.006:
+                    continue
 
-            texts.append(res["text"])
-            all_words.extend(res["words"])
+                is_tail = (total_audio_dur - g_end) < 1.5
+
+                res = await cls.transcribe_with_words_async(
+                    chunk,
+                    model_name=model_name,
+                    beam_size=6 if not is_tail else 4,
+                    best_of=5,
+                    temperature=0,
+                    patience=1.2,
+                    condition_on_previous_text=False,
+                    no_speech_threshold=0.4,
+                    compression_ratio_threshold=2.2,
+                )
+
+                if not res:
+                    continue
+
+                text = res.get("text")
+                if not text:
+                    continue
+
+                # hallucination filter
+                tokens = text.split()
+                if len(tokens) > 6:
+                    uniq_ratio = len(set(tokens)) / len(tokens)
+                    if uniq_ratio < 0.45:
+                        continue
+
+                for w in res.get("words", []):
+                    w["start"] += g_start
+                    w["end"] += g_start
+                    w["speaker"] = spk
+                    all_words.append(w)
+
+                segments_text.append(
+                    {
+                        "speaker": spk,
+                        "start": g_start,
+                        "end": g_end,
+                        "text": text,
+                    }
+                )
+
+        # ===== sort results =====
+        all_words.sort(key=lambda x: x["start"])
+        segments_text.sort(key=lambda x: x["start"])
+
+        full_text = " ".join(seg["text"] for seg in segments_text)
 
         return {
-            "text": " ".join(texts),
-            "words": all_words
-        }
+            "segments": segments_text,
+            "words": all_words,
+            "text": full_text
+        }
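
For a quick sanity check of the rewritten _normalize_segments, the merge-then-split behaviour can be reproduced standalone (a minimal sketch: the MERGE_GAP_SEC, MAX_SEGMENT_SEC, and TARGET_SEGMENT_SEC values below are illustrative placeholders for the real values in settings):

from typing import List, Tuple

# Illustrative placeholders for the values in settings.
MERGE_GAP_SEC = 0.4
MAX_SEGMENT_SEC = 15.0
TARGET_SEGMENT_SEC = 10.0

def normalize_segments(segments: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    if not segments:
        return []

    # Pass 1: merge neighbours when the gap is small and the merged
    # segment would still fit under MAX_SEGMENT_SEC.
    merged = []
    cur_s, cur_e = segments[0]
    for s, e in segments[1:]:
        if e - cur_s < MAX_SEGMENT_SEC and s - cur_e < MERGE_GAP_SEC:
            cur_e = e
        else:
            merged.append((cur_s, cur_e))
            cur_s, cur_e = s, e
    merged.append((cur_s, cur_e))

    # Pass 2: split anything still longer than MAX_SEGMENT_SEC into
    # TARGET_SEGMENT_SEC-sized pieces.
    final = []
    for s, e in merged:
        if e - s <= MAX_SEGMENT_SEC:
            final.append((s, e))
            continue
        cur = s
        while cur < e:
            nxt = min(cur + TARGET_SEGMENT_SEC, e)
            final.append((cur, nxt))
            cur = nxt
    return final

# Two close VAD hits merge; the long stretch is split into 10 s pieces.
print(normalize_segments([(0.0, 3.0), (3.2, 5.0), (6.0, 30.0)]))
# -> [(0.0, 5.0), (6.0, 16.0), (16.0, 26.0), (26.0, 30.0)]

Merging before splitting keeps sub-second VAD fragments from being transcribed in isolation, while the split pass bounds how much audio each ASR call receives.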