liuyang committed on
Commit 7bde45c · 1 Parent(s): 36812ab

Refactor speaker assignment logic in transcription: enhance the `assign_speakers_to_transcription` method to detect unmatched diarization segments, introduce a second pass that splits segments containing speaker changes, improve handling of speaker transitions, and re-transcribe unmatched diarization-only regions so they can be merged back into the results.
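The core detection step is easiest to see in isolation. Below is a minimal, self-contained sketch of how "unmatched" diarization turns are identified (an illustration, not the app's code: `find_unmatched` is a hypothetical name, and `interval_overlap` is assumed to behave like the helper the diff calls):

    # Sketch: a diarization turn with no overlapping transcription segment
    # is queued for re-transcription (this mirrors the check the commit adds).
    def interval_overlap(a0, a1, b0, b1):
        return max(0.0, min(a1, b1) - max(a0, b0))

    def find_unmatched(diarization_segments, transcription_results):
        unmatched = []
        for dseg in diarization_segments:
            if not any(
                interval_overlap(dseg["start"], dseg["end"], seg["start"], seg["end"]) > 1e-6
                for seg in transcription_results
            ):
                unmatched.append(dseg)
        return unmatched

    # Example: the 5.0-7.0 turn has no transcription, so it gets re-processed.
    trans = [{"start": 0.0, "end": 2.0, "text": "hello there"}]
    diar = [{"start": 0.0, "end": 2.0, "speaker": "SPEAKER_00"},
            {"start": 5.0, "end": 7.0, "speaker": "SPEAKER_01"}]
    print(find_unmatched(diar, trans))  # [{'start': 5.0, 'end': 7.0, 'speaker': 'SPEAKER_01'}]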

Files changed (1)
  1. app.py  +210 −50
app.py CHANGED
@@ -565,9 +565,13 @@ class WhisperTranscriber:
        return _embedder

    def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
-        """Assign speakers to words and segments based on overlap with diarization segments."""
+        """Assign speakers to words and segments based on overlap with diarization segments.
+
+        Also detects diarization segments that do not overlap any transcription segment and
+        returns them so they can be re-processed (e.g., re-transcribed) later.
+        """
        if not diarization_segments:
-            return transcription_results
+            return transcription_results, []
        # Helper: find the diarization speaker active at time t, or closest
        def speaker_at(t: float):
            for dseg in diarization_segments:
@@ -607,8 +611,8 @@ class WhisperTranscriber:
            mid = (float(start_t) + float(end_t)) / 2.0
            return speaker_at(mid)

+        # First pass: assign speakers to words and apply smoothing
        for seg in transcription_results:
-            # Assign per-word speakers using overlap, then smooth and stabilize boundaries
            if seg.get("words"):
                words = seg["words"]
                # 1) Initial assignment by overlap
@@ -628,55 +632,165 @@ class WhisperTranscriber:
                    smoothed[i] = prev_spk
                for i in range(len(words)):
                    words[i]["speaker"] = smoothed[i]
-
-                # 3) Determine dominant speaker by summed word durations
-                speaker_dur = {}
-                total_word_dur = 0.0
-                for w in words:
-                    dur = max(0.0, float(w["end"]) - float(w["start"]))
-                    total_word_dur += dur
-                    spk = w.get("speaker", "SPEAKER_00")
-                    speaker_dur[spk] = speaker_dur.get(spk, 0.0) + dur
-                if speaker_dur:
-                    dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0]
-                else:
-                    dominant_speaker = speaker_at((float(seg["start"]) + float(seg["end"])) / 2.0)
-
-                # 4) Boundary stabilization: relabel tiny prefix/suffix runs to dominant
-                seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"]))
-                max_boundary_sec = 0.5  # hard cap for how much to relabel at edges
-                max_boundary_frac = 0.2  # or up to 20% of the segment duration
-
-                # prefix
-                prefix_dur = 0.0
-                prefix_count = 0
-                for w in words:
-                    if w.get("speaker") == dominant_speaker:
-                        break
-                    prefix_dur += max(0.0, float(w["end"]) - float(w["start"]))
-                    prefix_count += 1
-                if prefix_count > 0 and prefix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
-                    for i in range(prefix_count):
-                        words[i]["speaker"] = dominant_speaker
-
-                # suffix
-                suffix_dur = 0.0
-                suffix_count = 0
-                for w in reversed(words):
-                    if w.get("speaker") == dominant_speaker:
-                        break
-                    suffix_dur += max(0.0, float(w["end"]) - float(w["start"]))
-                    suffix_count += 1
-                if suffix_count > 0 and suffix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
-                    for i in range(len(words) - suffix_count, len(words)):
-                        words[i]["speaker"] = dominant_speaker
-
-                # 5) Final segment speaker
-                seg["speaker"] = dominant_speaker
            else:
                # No word timings: choose by overlap with diarization over the whole segment
                seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
-        return transcription_results
+
+        # Second pass: split segments that have speaker changes within them
+        split_segments = []
+        for seg in transcription_results:
+            words = seg.get("words", [])
+            if not words or len(words) <= 1:
+                # No words or single word - can't split, assign speaker directly
+                if not words:
+                    seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
+                else:
+                    seg["speaker"] = words[0].get("speaker", "SPEAKER_00")
+                split_segments.append(seg)
+                continue
+
+            # Find speaker transition points with minimum duration filter
+            current_speaker = words[0].get("speaker", "SPEAKER_00")
+            split_points = [0]  # Always start with first word
+            min_segment_duration = 0.5  # Minimum 0.5 seconds per segment
+
+            for i in range(1, len(words)):
+                word_speaker = words[i].get("speaker", "SPEAKER_00")
+                if word_speaker != current_speaker:
+                    # Check if this would create a segment that's too short
+                    if split_points:
+                        last_split = split_points[-1]
+                        segment_start_time = float(words[last_split]["start"])
+                        current_word_time = float(words[i-1]["end"])
+                        segment_duration = current_word_time - segment_start_time
+
+                        # Only split if the previous segment would be long enough
+                        if segment_duration >= min_segment_duration:
+                            split_points.append(i)
+                            current_speaker = word_speaker
+                        # If too short, continue without splitting (speaker will be resolved by dominant speaker logic)
+                    else:
+                        split_points.append(i)
+                        current_speaker = word_speaker
+
+            split_points.append(len(words))  # End point
+
+            # Create sub-segments if we found speaker changes
+            if len(split_points) <= 2:
+                # No splits needed - process as single segment
+                self._assign_dominant_speaker_to_segment(seg, speaker_at, best_speaker_for_interval)
+                split_segments.append(seg)
+            else:
+                # Split into multiple segments
+                for i in range(len(split_points) - 1):
+                    start_idx = split_points[i]
+                    end_idx = split_points[i + 1]
+
+                    if end_idx <= start_idx:
+                        continue
+
+                    subseg_words = words[start_idx:end_idx]
+                    if not subseg_words:
+                        continue
+
+                    # Calculate segment timing and text from words
+                    subseg_start = float(subseg_words[0]["start"])
+                    subseg_end = float(subseg_words[-1]["end"])
+                    subseg_text = " ".join(w.get("word", "").strip() for w in subseg_words if w.get("word", "").strip())
+
+                    # Create new sub-segment
+                    new_seg = {
+                        "start": subseg_start,
+                        "end": subseg_end,
+                        "text": subseg_text,
+                        "words": subseg_words,
+                        "duration": subseg_end - subseg_start,
+                    }
+
+                    # Copy over other fields from original segment if they exist
+                    for key in ["avg_logprob"]:
+                        if key in seg:
+                            new_seg[key] = seg[key]
+
+                    # Assign dominant speaker to this sub-segment
+                    self._assign_dominant_speaker_to_segment(new_seg, speaker_at, best_speaker_for_interval)
+                    split_segments.append(new_seg)
+
+        # Update transcription_results with split segments
+        transcription_results = split_segments
+
+        # Identify diarization segments that have no overlapping transcription segments
+        unmatched_diarization_segments = []
+        for dseg in diarization_segments:
+            d_start = float(dseg["start"])
+            d_end = float(dseg["end"])
+            has_overlap = False
+            for seg in transcription_results:
+                if interval_overlap(d_start, d_end, float(seg["start"]), float(seg["end"])) > 1e-6:
+                    has_overlap = True
+                    break
+            if not has_overlap:
+                unmatched_diarization_segments.append({
+                    "start": d_start,
+                    "end": d_end,
+                    "speaker": dseg["speaker"],
+                })
+
+        return transcription_results, unmatched_diarization_segments
+
+    def _assign_dominant_speaker_to_segment(self, seg, speaker_at_func, best_speaker_for_interval_func):
+        """Assign dominant speaker to a segment based on word durations and boundary stabilization."""
+        words = seg.get("words", [])
+        if not words:
+            # No words: use segment-level overlap
+            seg["speaker"] = best_speaker_for_interval_func(float(seg["start"]), float(seg["end"]))
+            return
+
+        # 1) Determine dominant speaker by summed word durations
+        speaker_dur = {}
+        total_word_dur = 0.0
+        for w in words:
+            dur = max(0.0, float(w["end"]) - float(w["start"]))
+            total_word_dur += dur
+            spk = w.get("speaker", "SPEAKER_00")
+            speaker_dur[spk] = speaker_dur.get(spk, 0.0) + dur
+
+        if speaker_dur:
+            dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0]
+        else:
+            dominant_speaker = speaker_at_func((float(seg["start"]) + float(seg["end"])) / 2.0)
+
+        # 2) Boundary stabilization: relabel tiny prefix/suffix runs to dominant
+        seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"]))
+        max_boundary_sec = 0.5  # hard cap for how much to relabel at edges
+        max_boundary_frac = 0.2  # or up to 20% of the segment duration
+
+        # prefix
+        prefix_dur = 0.0
+        prefix_count = 0
+        for w in words:
+            if w.get("speaker") == dominant_speaker:
+                break
+            prefix_dur += max(0.0, float(w["end"]) - float(w["start"]))
+            prefix_count += 1
+        if prefix_count > 0 and prefix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
+            for i in range(prefix_count):
+                words[i]["speaker"] = dominant_speaker
+
+        # suffix
+        suffix_dur = 0.0
+        suffix_count = 0
+        for w in reversed(words):
+            if w.get("speaker") == dominant_speaker:
+                break
+            suffix_dur += max(0.0, float(w["end"]) - float(w["start"]))
+            suffix_count += 1
+        if suffix_count > 0 and suffix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
+            for i in range(len(words) - suffix_count, len(words)):
+                words[i]["speaker"] = dominant_speaker
+
+        # 3) Final segment speaker
+        seg["speaker"] = dominant_speaker

    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Group consecutive segments from the same speaker"""
@@ -801,7 +915,53 @@ class WhisperTranscriber:
        )

        # Step 4: Merge diarization into transcription (assign speakers)
-        transcription_results = self.assign_speakers_to_transcription(transcription_results, diarization_segments)
+        transcription_results, unmatched_diarization_segments = self.assign_speakers_to_transcription(
+            transcription_results, diarization_segments
+        )
+
+        # Step 4.1: Transcribe diarization-only regions and merge
+        if unmatched_diarization_segments:
+            waveform, sample_rate = torchaudio.load(wav_path)
+            extra_segments = []
+            for dseg in unmatched_diarization_segments:
+                d_start = float(dseg["start"])  # global seconds
+                d_end = float(dseg["end"])  # global seconds
+                if d_end <= d_start:
+                    continue
+                # Map global time to local file time
+                local_start = max(0.0, d_start - float(base_offset_s))
+                local_end = max(local_start, d_end - float(base_offset_s))
+                start_sample = max(0, int(local_start * sample_rate))
+                end_sample = min(waveform.shape[1], int(local_end * sample_rate))
+                if end_sample <= start_sample:
+                    continue
+                seg_wav = waveform[:, start_sample:end_sample].contiguous()
+                tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+                tmp_path = tmp_f.name
+                tmp_f.close()
+                try:
+                    torchaudio.save(tmp_path, seg_wav.cpu(), sample_rate)
+                    seg_transcription, _ = self.transcribe_full_audio(
+                        tmp_path,
+                        language=language if language is not None else None,
+                        translate=translate,
+                        prompt=prompt,
+                        batch_size=batch_size,
+                        base_offset_s=d_start,
+                    )
+                    extra_segments.extend(seg_transcription)
+                finally:
+                    try:
+                        os.unlink(tmp_path)
+                    except Exception:
+                        pass
+            if extra_segments:
+                transcription_results.extend(extra_segments)
+                transcription_results.sort(key=lambda s: float(s.get("start", 0.0)))
+                # Re-assign speakers on the combined set
+                transcription_results, _ = self.assign_speakers_to_transcription(
+                    transcription_results, diarization_segments
+                )

        # Step 5: Group segments if requested
        if group_segments:
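For intuition about the new second pass, here is a minimal, self-contained sketch of its split rule (a re-implementation for illustration only, not the app's code; `find_split_points` is a hypothetical name, word dicts follow the keys used in the diff, and the 0.5 s threshold mirrors `min_segment_duration`):

    # Sketch of the second-pass rule: a speaker change only opens a new
    # sub-segment when the run before it is at least 0.5 s long; shorter
    # runs are absorbed and left to the dominant-speaker logic.
    def find_split_points(words, min_segment_duration=0.5):
        # Assumes a non-empty word list, as the diff guards len(words) <= 1 earlier.
        split_points = [0]
        current_speaker = words[0].get("speaker", "SPEAKER_00")
        for i in range(1, len(words)):
            spk = words[i].get("speaker", "SPEAKER_00")
            if spk != current_speaker:
                run = float(words[i - 1]["end"]) - float(words[split_points[-1]]["start"])
                if run >= min_segment_duration:
                    split_points.append(i)
                    current_speaker = spk
        split_points.append(len(words))
        return split_points

    # A 0.2 s interjection by SPEAKER_01 is too short to split on:
    words = [
        {"word": "so", "start": 0.0, "end": 0.3, "speaker": "SPEAKER_00"},
        {"word": "yeah", "start": 0.3, "end": 0.5, "speaker": "SPEAKER_01"},
        {"word": "anyway", "start": 0.5, "end": 1.2, "speaker": "SPEAKER_00"},
    ]
    print(find_split_points(words))  # [0, 3] -> segment is kept whole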