Spaces:
Runtime error
Runtime error
liuyang
committed on
Commit
·
7bde45c
1
Parent(s):
36812ab
Refactor speaker assignment logic in transcription: Enhanced the `assign_speakers_to_transcription` method to detect unmatched diarization segments and introduced a second pass for splitting segments with speaker changes. Improved handling of speaker transitions and added functionality to re-process unmatched segments.
Browse files
app.py
CHANGED
|
@@ -565,9 +565,13 @@ class WhisperTranscriber:
|
|
| 565 |
return _embedder
|
| 566 |
|
| 567 |
def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
|
| 568 |
-
"""Assign speakers to words and segments based on overlap with diarization segments.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
if not diarization_segments:
|
| 570 |
-
return transcription_results
|
| 571 |
# Helper: find the diarization speaker active at time t, or closest
|
| 572 |
def speaker_at(t: float):
|
| 573 |
for dseg in diarization_segments:
|
|
@@ -607,8 +611,8 @@ class WhisperTranscriber:
|
|
| 607 |
mid = (float(start_t) + float(end_t)) / 2.0
|
| 608 |
return speaker_at(mid)
|
| 609 |
|
|
|
|
| 610 |
for seg in transcription_results:
|
| 611 |
-
# Assign per-word speakers using overlap, then smooth and stabilize boundaries
|
| 612 |
if seg.get("words"):
|
| 613 |
words = seg["words"]
|
| 614 |
# 1) Initial assignment by overlap
|
|
@@ -628,55 +632,165 @@ class WhisperTranscriber:
|
|
| 628 |
smoothed[i] = prev_spk
|
| 629 |
for i in range(len(words)):
|
| 630 |
words[i]["speaker"] = smoothed[i]
|
| 631 |
-
|
| 632 |
-
# 3) Determine dominant speaker by summed word durations
|
| 633 |
-
speaker_dur = {}
|
| 634 |
-
total_word_dur = 0.0
|
| 635 |
-
for w in words:
|
| 636 |
-
dur = max(0.0, float(w["end"]) - float(w["start"]))
|
| 637 |
-
total_word_dur += dur
|
| 638 |
-
spk = w.get("speaker", "SPEAKER_00")
|
| 639 |
-
speaker_dur[spk] = speaker_dur.get(spk, 0.0) + dur
|
| 640 |
-
if speaker_dur:
|
| 641 |
-
dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0]
|
| 642 |
-
else:
|
| 643 |
-
dominant_speaker = speaker_at((float(seg["start"]) + float(seg["end"])) / 2.0)
|
| 644 |
-
|
| 645 |
-
# 4) Boundary stabilization: relabel tiny prefix/suffix runs to dominant
|
| 646 |
-
seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"]))
|
| 647 |
-
max_boundary_sec = 0.5 # hard cap for how much to relabel at edges
|
| 648 |
-
max_boundary_frac = 0.2 # or up to 20% of the segment duration
|
| 649 |
-
|
| 650 |
-
# prefix
|
| 651 |
-
prefix_dur = 0.0
|
| 652 |
-
prefix_count = 0
|
| 653 |
-
for w in words:
|
| 654 |
-
if w.get("speaker") == dominant_speaker:
|
| 655 |
-
break
|
| 656 |
-
prefix_dur += max(0.0, float(w["end"]) - float(w["start"]))
|
| 657 |
-
prefix_count += 1
|
| 658 |
-
if prefix_count > 0 and prefix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
|
| 659 |
-
for i in range(prefix_count):
|
| 660 |
-
words[i]["speaker"] = dominant_speaker
|
| 661 |
-
|
| 662 |
-
# suffix
|
| 663 |
-
suffix_dur = 0.0
|
| 664 |
-
suffix_count = 0
|
| 665 |
-
for w in reversed(words):
|
| 666 |
-
if w.get("speaker") == dominant_speaker:
|
| 667 |
-
break
|
| 668 |
-
suffix_dur += max(0.0, float(w["end"]) - float(w["start"]))
|
| 669 |
-
suffix_count += 1
|
| 670 |
-
if suffix_count > 0 and suffix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
|
| 671 |
-
for i in range(len(words) - suffix_count, len(words)):
|
| 672 |
-
words[i]["speaker"] = dominant_speaker
|
| 673 |
-
|
| 674 |
-
# 5) Final segment speaker
|
| 675 |
-
seg["speaker"] = dominant_speaker
|
| 676 |
else:
|
| 677 |
# No word timings: choose by overlap with diarization over the whole segment
|
| 678 |
seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
|
| 679 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 680 |
|
| 681 |
def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
|
| 682 |
"""Group consecutive segments from the same speaker"""
|
|
@@ -801,7 +915,53 @@ class WhisperTranscriber:
|
|
| 801 |
)
|
| 802 |
|
| 803 |
# Step 4: Merge diarization into transcription (assign speakers)
|
| 804 |
-
transcription_results = self.assign_speakers_to_transcription(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
|
| 806 |
# Step 5: Group segments if requested
|
| 807 |
if group_segments:
|
|
|
|
| 565 |
return _embedder
|
| 566 |
|
| 567 |
def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
|
| 568 |
+
"""Assign speakers to words and segments based on overlap with diarization segments.
|
| 569 |
+
|
| 570 |
+
Also detects diarization segments that do not overlap any transcription segment and
|
| 571 |
+
returns them so they can be re-processed (e.g., re-transcribed) later.
|
| 572 |
+
"""
|
| 573 |
if not diarization_segments:
|
| 574 |
+
return transcription_results, []
|
| 575 |
# Helper: find the diarization speaker active at time t, or closest
|
| 576 |
def speaker_at(t: float):
|
| 577 |
for dseg in diarization_segments:
|
|
|
|
| 611 |
mid = (float(start_t) + float(end_t)) / 2.0
|
| 612 |
return speaker_at(mid)
|
| 613 |
|
| 614 |
+
# First pass: assign speakers to words and apply smoothing
|
| 615 |
for seg in transcription_results:
|
|
|
|
| 616 |
if seg.get("words"):
|
| 617 |
words = seg["words"]
|
| 618 |
# 1) Initial assignment by overlap
|
|
|
|
| 632 |
smoothed[i] = prev_spk
|
| 633 |
for i in range(len(words)):
|
| 634 |
words[i]["speaker"] = smoothed[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
else:
|
| 636 |
# No word timings: choose by overlap with diarization over the whole segment
|
| 637 |
seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
|
| 638 |
+
|
| 639 |
+
# Second pass: split segments that have speaker changes within them
|
| 640 |
+
split_segments = []
|
| 641 |
+
for seg in transcription_results:
|
| 642 |
+
words = seg.get("words", [])
|
| 643 |
+
if not words or len(words) <= 1:
|
| 644 |
+
# No words or single word - can't split, assign speaker directly
|
| 645 |
+
if not words:
|
| 646 |
+
seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
|
| 647 |
+
else:
|
| 648 |
+
seg["speaker"] = words[0].get("speaker", "SPEAKER_00")
|
| 649 |
+
split_segments.append(seg)
|
| 650 |
+
continue
|
| 651 |
+
|
| 652 |
+
# Find speaker transition points with minimum duration filter
|
| 653 |
+
current_speaker = words[0].get("speaker", "SPEAKER_00")
|
| 654 |
+
split_points = [0] # Always start with first word
|
| 655 |
+
min_segment_duration = 0.5 # Minimum 0.5 seconds per segment
|
| 656 |
+
|
| 657 |
+
for i in range(1, len(words)):
|
| 658 |
+
word_speaker = words[i].get("speaker", "SPEAKER_00")
|
| 659 |
+
if word_speaker != current_speaker:
|
| 660 |
+
# Check if this would create a segment that's too short
|
| 661 |
+
if split_points:
|
| 662 |
+
last_split = split_points[-1]
|
| 663 |
+
segment_start_time = float(words[last_split]["start"])
|
| 664 |
+
current_word_time = float(words[i-1]["end"])
|
| 665 |
+
segment_duration = current_word_time - segment_start_time
|
| 666 |
+
|
| 667 |
+
# Only split if the previous segment would be long enough
|
| 668 |
+
if segment_duration >= min_segment_duration:
|
| 669 |
+
split_points.append(i)
|
| 670 |
+
current_speaker = word_speaker
|
| 671 |
+
# If too short, continue without splitting (speaker will be resolved by dominant speaker logic)
|
| 672 |
+
else:
|
| 673 |
+
split_points.append(i)
|
| 674 |
+
current_speaker = word_speaker
|
| 675 |
+
|
| 676 |
+
split_points.append(len(words)) # End point
|
| 677 |
+
|
| 678 |
+
# Create sub-segments if we found speaker changes
|
| 679 |
+
if len(split_points) <= 2:
|
| 680 |
+
# No splits needed - process as single segment
|
| 681 |
+
self._assign_dominant_speaker_to_segment(seg, speaker_at, best_speaker_for_interval)
|
| 682 |
+
split_segments.append(seg)
|
| 683 |
+
else:
|
| 684 |
+
# Split into multiple segments
|
| 685 |
+
for i in range(len(split_points) - 1):
|
| 686 |
+
start_idx = split_points[i]
|
| 687 |
+
end_idx = split_points[i + 1]
|
| 688 |
+
|
| 689 |
+
if end_idx <= start_idx:
|
| 690 |
+
continue
|
| 691 |
+
|
| 692 |
+
subseg_words = words[start_idx:end_idx]
|
| 693 |
+
if not subseg_words:
|
| 694 |
+
continue
|
| 695 |
+
|
| 696 |
+
# Calculate segment timing and text from words
|
| 697 |
+
subseg_start = float(subseg_words[0]["start"])
|
| 698 |
+
subseg_end = float(subseg_words[-1]["end"])
|
| 699 |
+
subseg_text = " ".join(w.get("word", "").strip() for w in subseg_words if w.get("word", "").strip())
|
| 700 |
+
|
| 701 |
+
# Create new sub-segment
|
| 702 |
+
new_seg = {
|
| 703 |
+
"start": subseg_start,
|
| 704 |
+
"end": subseg_end,
|
| 705 |
+
"text": subseg_text,
|
| 706 |
+
"words": subseg_words,
|
| 707 |
+
"duration": subseg_end - subseg_start,
|
| 708 |
+
}
|
| 709 |
+
|
| 710 |
+
# Copy over other fields from original segment if they exist
|
| 711 |
+
for key in ["avg_logprob"]:
|
| 712 |
+
if key in seg:
|
| 713 |
+
new_seg[key] = seg[key]
|
| 714 |
+
|
| 715 |
+
# Assign dominant speaker to this sub-segment
|
| 716 |
+
self._assign_dominant_speaker_to_segment(new_seg, speaker_at, best_speaker_for_interval)
|
| 717 |
+
split_segments.append(new_seg)
|
| 718 |
+
|
| 719 |
+
# Update transcription_results with split segments
|
| 720 |
+
transcription_results = split_segments
|
| 721 |
+
|
| 722 |
+
# Identify diarization segments that have no overlapping transcription segments
|
| 723 |
+
unmatched_diarization_segments = []
|
| 724 |
+
for dseg in diarization_segments:
|
| 725 |
+
d_start = float(dseg["start"])
|
| 726 |
+
d_end = float(dseg["end"])
|
| 727 |
+
has_overlap = False
|
| 728 |
+
for seg in transcription_results:
|
| 729 |
+
if interval_overlap(d_start, d_end, float(seg["start"]), float(seg["end"])) > 1e-6:
|
| 730 |
+
has_overlap = True
|
| 731 |
+
break
|
| 732 |
+
if not has_overlap:
|
| 733 |
+
unmatched_diarization_segments.append({
|
| 734 |
+
"start": d_start,
|
| 735 |
+
"end": d_end,
|
| 736 |
+
"speaker": dseg["speaker"],
|
| 737 |
+
})
|
| 738 |
+
|
| 739 |
+
return transcription_results, unmatched_diarization_segments
|
| 740 |
+
|
| 741 |
+
def _assign_dominant_speaker_to_segment(self, seg, speaker_at_func, best_speaker_for_interval_func):
|
| 742 |
+
"""Assign dominant speaker to a segment based on word durations and boundary stabilization."""
|
| 743 |
+
words = seg.get("words", [])
|
| 744 |
+
if not words:
|
| 745 |
+
# No words: use segment-level overlap
|
| 746 |
+
seg["speaker"] = best_speaker_for_interval_func(float(seg["start"]), float(seg["end"]))
|
| 747 |
+
return
|
| 748 |
+
|
| 749 |
+
# 1) Determine dominant speaker by summed word durations
|
| 750 |
+
speaker_dur = {}
|
| 751 |
+
total_word_dur = 0.0
|
| 752 |
+
for w in words:
|
| 753 |
+
dur = max(0.0, float(w["end"]) - float(w["start"]))
|
| 754 |
+
total_word_dur += dur
|
| 755 |
+
spk = w.get("speaker", "SPEAKER_00")
|
| 756 |
+
speaker_dur[spk] = speaker_dur.get(spk, 0.0) + dur
|
| 757 |
+
|
| 758 |
+
if speaker_dur:
|
| 759 |
+
dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0]
|
| 760 |
+
else:
|
| 761 |
+
dominant_speaker = speaker_at_func((float(seg["start"]) + float(seg["end"])) / 2.0)
|
| 762 |
+
|
| 763 |
+
# 2) Boundary stabilization: relabel tiny prefix/suffix runs to dominant
|
| 764 |
+
seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"]))
|
| 765 |
+
max_boundary_sec = 0.5 # hard cap for how much to relabel at edges
|
| 766 |
+
max_boundary_frac = 0.2 # or up to 20% of the segment duration
|
| 767 |
+
|
| 768 |
+
# prefix
|
| 769 |
+
prefix_dur = 0.0
|
| 770 |
+
prefix_count = 0
|
| 771 |
+
for w in words:
|
| 772 |
+
if w.get("speaker") == dominant_speaker:
|
| 773 |
+
break
|
| 774 |
+
prefix_dur += max(0.0, float(w["end"]) - float(w["start"]))
|
| 775 |
+
prefix_count += 1
|
| 776 |
+
if prefix_count > 0 and prefix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
|
| 777 |
+
for i in range(prefix_count):
|
| 778 |
+
words[i]["speaker"] = dominant_speaker
|
| 779 |
+
|
| 780 |
+
# suffix
|
| 781 |
+
suffix_dur = 0.0
|
| 782 |
+
suffix_count = 0
|
| 783 |
+
for w in reversed(words):
|
| 784 |
+
if w.get("speaker") == dominant_speaker:
|
| 785 |
+
break
|
| 786 |
+
suffix_dur += max(0.0, float(w["end"]) - float(w["start"]))
|
| 787 |
+
suffix_count += 1
|
| 788 |
+
if suffix_count > 0 and suffix_dur <= min(max_boundary_sec, max_boundary_frac * seg_duration):
|
| 789 |
+
for i in range(len(words) - suffix_count, len(words)):
|
| 790 |
+
words[i]["speaker"] = dominant_speaker
|
| 791 |
+
|
| 792 |
+
# 3) Final segment speaker
|
| 793 |
+
seg["speaker"] = dominant_speaker
|
| 794 |
|
| 795 |
def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
|
| 796 |
"""Group consecutive segments from the same speaker"""
|
|
|
|
| 915 |
)
|
| 916 |
|
| 917 |
# Step 4: Merge diarization into transcription (assign speakers)
|
| 918 |
+
transcription_results, unmatched_diarization_segments = self.assign_speakers_to_transcription(
|
| 919 |
+
transcription_results, diarization_segments
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# Step 4.1: Transcribe diarization-only regions and merge
|
| 923 |
+
if unmatched_diarization_segments:
|
| 924 |
+
waveform, sample_rate = torchaudio.load(wav_path)
|
| 925 |
+
extra_segments = []
|
| 926 |
+
for dseg in unmatched_diarization_segments:
|
| 927 |
+
d_start = float(dseg["start"]) # global seconds
|
| 928 |
+
d_end = float(dseg["end"]) # global seconds
|
| 929 |
+
if d_end <= d_start:
|
| 930 |
+
continue
|
| 931 |
+
# Map global time to local file time
|
| 932 |
+
local_start = max(0.0, d_start - float(base_offset_s))
|
| 933 |
+
local_end = max(local_start, d_end - float(base_offset_s))
|
| 934 |
+
start_sample = max(0, int(local_start * sample_rate))
|
| 935 |
+
end_sample = min(waveform.shape[1], int(local_end * sample_rate))
|
| 936 |
+
if end_sample <= start_sample:
|
| 937 |
+
continue
|
| 938 |
+
seg_wav = waveform[:, start_sample:end_sample].contiguous()
|
| 939 |
+
tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
| 940 |
+
tmp_path = tmp_f.name
|
| 941 |
+
tmp_f.close()
|
| 942 |
+
try:
|
| 943 |
+
torchaudio.save(tmp_path, seg_wav.cpu(), sample_rate)
|
| 944 |
+
seg_transcription, _ = self.transcribe_full_audio(
|
| 945 |
+
tmp_path,
|
| 946 |
+
language=language if language is not None else None,
|
| 947 |
+
translate=translate,
|
| 948 |
+
prompt=prompt,
|
| 949 |
+
batch_size=batch_size,
|
| 950 |
+
base_offset_s=d_start,
|
| 951 |
+
)
|
| 952 |
+
extra_segments.extend(seg_transcription)
|
| 953 |
+
finally:
|
| 954 |
+
try:
|
| 955 |
+
os.unlink(tmp_path)
|
| 956 |
+
except Exception:
|
| 957 |
+
pass
|
| 958 |
+
if extra_segments:
|
| 959 |
+
transcription_results.extend(extra_segments)
|
| 960 |
+
transcription_results.sort(key=lambda s: float(s.get("start", 0.0)))
|
| 961 |
+
# Re-assign speakers on the combined set
|
| 962 |
+
transcription_results, _ = self.assign_speakers_to_transcription(
|
| 963 |
+
transcription_results, diarization_segments
|
| 964 |
+
)
|
| 965 |
|
| 966 |
# Step 5: Group segments if requested
|
| 967 |
if group_segments:
|