|
|
import argparse
import gc
import logging
import os
import sys
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
|
|
|
|
|
# Module-wide verbosity switch read by xprint(); extract_dual_audio() overwrites it.
verbose_output = True

# Silence known-noisy UserWarnings. Each entry holds the keyword arguments that
# narrow one filter; the filters are registered in the listed order.
for _filter_spec in (
    {"module": "pyannote.audio.models.blocks.pooling"},
    {"message": ".*TensorFloat-32.*"},
    {"message": ".*std\\(\\): degrees of freedom.*"},
    {"message": ".*speechbrain.pretrained.*was deprecated.*"},
    {"message": ".*Module 'speechbrain.pretrained'.*"},
):
    warnings.filterwarnings("ignore", category=UserWarning, **_filter_spec)
del _filter_spec

# The log level is exported before speechbrain is imported so that it is
# already present in the environment when the package initialises.
os.environ["SB_LOG_LEVEL"] = "WARNING"
import speechbrain
|
|
|
|
|
def xprint(t=None):
    """Print *t* when verbose output is enabled.

    Called without an argument (or with ``None``) it emits an empty line
    instead of the literal text ``None``, which is how the callers in this
    module use the bare ``xprint()`` form.

    Args:
        t: The object to print, or ``None`` for a blank line.
    """
    if not verbose_output:
        return
    if t is None:
        print()
    else:
        print(t)
|
|
|
|
|
|
|
|
# Allow TensorFloat-32 kernels for matmul and cuDNN whenever a CUDA device exists.
if torch.cuda.is_available():
    for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        _tf32_backend.allow_tf32 = True
    del _tf32_backend

# Probe for pyannote.audio once at import time and record the outcome.
try:
    from pyannote.audio import Pipeline
except ImportError:
    PYANNOTE_AVAILABLE = False
    print("Install: pip install pyannote.audio")
else:
    PYANNOTE_AVAILABLE = True
|
|
|
|
|
|
|
|
class OptimizedPyannote31SpeakerSeparator:
    """Split a recording into one masked audio file per speaker.

    The workflow implemented by :meth:`separate_audio` is:

    1. load the audio and down-mix it to mono,
    2. diarize it with a pyannote ``SpeakerDiarization`` pipeline built from
       two local checkpoints,
    3. turn the diarization into per-sample speaker masks,
    4. resolve samples that belong to nobody and enforce a minimum duration
       for every speaker turn,
    5. multiply the waveform by each mask and save the two results.
    """

    # The dedicated overlapped-speech model is disabled by default; overlap is
    # derived from the diarization itself. Set to True to try the model first.
    USE_OVERLAP_MODEL = False

    def __init__(self, hf_token: Optional[str] = None, local_model_path: Optional[str] = None,
                 vad_onset: float = 0.2, vad_offset: float = 0.8):
        """Load the local checkpoints and build the diarization pipeline.

        Args:
            hf_token: Hugging Face token. Stored on the instance; the local
                checkpoints loaded here do not use it.
            local_model_path: Stored on the instance; the checkpoint paths are
                currently fixed and this value is not consulted.
            vad_onset: Stored on the instance; applied only if
                :meth:`_configure_vad` is called explicitly.
            vad_offset: Stored on the instance; see ``vad_onset``.

        Raises:
            Exception: Whatever the model loading or pipeline construction
                raises is logged and re-raised unchanged.
        """
        embedding_path = "ckpts/pyannote/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin"
        segmentation_path = "ckpts/pyannote/pytorch_model_segmentation-3.0.bin"

        # Per-instance state is declared up front so that later stages never
        # depend on attributes created as a side effect of an earlier call.
        self.hf_token = hf_token
        self.local_model_path = local_model_path
        self.vad_onset = vad_onset
        self.vad_offset = vad_offset
        self._overlap_pipeline = None
        self._current_audio_path = None
        self._both_speaking_regions = None

        xprint(f"Loading segmentation model from: {segmentation_path}")
        xprint(f"Loading embedding model from: {embedding_path}")

        try:
            from pyannote.audio import Model
            from pyannote.audio.pipelines import SpeakerDiarization

            segmentation_model = Model.from_pretrained(segmentation_path)
            embedding_model = Model.from_pretrained(embedding_path)
            xprint("Models loaded successfully!")

            # Kept so that _try_aggressive_clustering() can build a second
            # pipeline from exactly the objects this constructor accepted.
            self._segmentation_model = segmentation_model
            self._embedding_model = embedding_model

            self.pipeline = SpeakerDiarization(
                segmentation=segmentation_model,
                embedding=embedding_model,
                clustering='AgglomerativeClustering'
            )
            self.pipeline.instantiate({
                'clustering': {
                    'method': 'centroid',
                    'min_cluster_size': 12,
                    'threshold': 0.7045654963945799
                },
                'segmentation': {
                    'min_duration_off': 0.0
                }
            })
            xprint("Pipeline instantiated successfully!")

            if torch.cuda.is_available():
                xprint("CUDA available, moving pipeline to GPU...")
                self.pipeline.to(torch.device("cuda"))
            else:
                xprint("CUDA not available, using CPU...")
        except Exception as e:
            xprint(f"Error loading pipeline: {e}")
            xprint(f"Error type: {type(e)}")
            traceback.print_exc()
            raise

    def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None) -> Dict[str, str]:
        """Diarize ``audio_path`` and write one masked file per speaker.

        Args:
            audio_path: Audio that is diarized.
            output1: Destination for the first speaker's audio.
            output2: Destination for the second speaker's audio.
            audio_original_path: Optional audio to which the masks are applied
                instead of ``audio_path``. Its sample rate and length may
                differ; the masks are rescaled and padded/trimmed to fit.

        Returns:
            Mapping of speaker label to the path that was written.
        """
        xprint("Starting optimized audio separation...")
        self._current_audio_path = os.path.abspath(audio_path)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)

            waveform, sample_rate = self.load_audio(audio_path)
            num_samples = waveform.shape[1]

            diarization = self.perform_optimized_diarization(audio_path)
            masks = self.create_optimized_speaker_masks(diarization, num_samples, sample_rate)
            # The real sample rate is forwarded so the minimum-duration
            # smoothing is measured in milliseconds of *this* audio.
            final_masks = self.apply_optimized_background_preservation(masks, num_samples, sample_rate)

            # Release the intermediate masks before the output audio is loaded.
            del masks
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            if audio_original_path is None:
                output_waveform, output_rate = waveform, sample_rate
            else:
                output_waveform, output_rate = self.load_audio(audio_original_path)
                # Masks were computed on the diarized audio's sample grid;
                # align them with the grid of the audio they are applied to.
                final_masks = {
                    speaker: self._fit_mask_to_audio(mask, sample_rate, output_rate, output_waveform.shape[1])
                    for speaker, mask in final_masks.items()
                }

            output_paths = self._save_outputs_optimized(
                output_waveform, final_masks, output_rate, audio_path, output1, output2
            )

        return output_paths

    @staticmethod
    def _fit_mask_to_audio(mask: np.ndarray, source_rate: int, target_rate: int,
                           target_length: int) -> np.ndarray:
        """Map a per-sample mask onto audio with another sample rate/length.

        A rate change is handled by nearest-earlier-sample lookup; a remaining
        length difference is handled by trimming or zero-padding at the end.
        The input is returned untouched when it already fits.
        """
        if source_rate != target_rate and len(mask) > 0:
            resampled_length = int(round(len(mask) * target_rate / source_rate))
            source_index = (np.arange(resampled_length, dtype=np.int64) * source_rate) // target_rate
            mask = mask[np.minimum(source_index, len(mask) - 1)]

        if len(mask) == target_length:
            return mask

        fitted = np.zeros(target_length, dtype=mask.dtype)
        usable = min(len(mask), target_length)
        fitted[:usable] = mask[:usable]
        return fitted

    def _extract_both_speaking_regions(
        self,
        diarization,
        audio_length: int,
        sample_rate: int
    ) -> np.ndarray:
        """Return a boolean per-sample mask of regions where ≥2 speakers talk.

        When ``USE_OVERLAP_MODEL`` is enabled the overlapped-speech pipeline
        from :meth:`_get_overlap_pipeline` is tried first. Otherwise, or if
        the model is unavailable or fails, overlap is derived from the
        diarization: a sample counts as overlap when tracks of two different
        labels both cover it.
        """
        xprint("Extracting overlap with dedicated pipeline…")
        both_speaking_mask = np.zeros(audio_length, dtype=bool)

        overlap_pipeline = self._get_overlap_pipeline() if self.USE_OVERLAP_MODEL else None
        audio_uri = getattr(self, "_current_audio_path", None) \
            or getattr(diarization, "uri", None)

        if overlap_pipeline is not None and audio_uri:
            try:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    overlap_annotation = overlap_pipeline(audio_uri)

                for seg in overlap_annotation.get_timeline().support():
                    s = max(0, int(seg.start * sample_rate))
                    e = min(audio_length, int(seg.end * sample_rate))
                    if s < e:
                        both_speaking_mask[s:e] = True
                t = np.sum(both_speaking_mask) / sample_rate
                xprint(f" Found {t:.1f}s of overlapped speech (model) ")
                return both_speaking_mask
            except Exception as e:
                xprint(f" ⚠ Overlap model failed: {e}")
                # Do not mix a partially filled model result into the fallback.
                both_speaking_mask[:] = False

        xprint(" Falling back to manual overlap detection…")

        # Collect every track's sample range under its speaker label.
        ranges_by_speaker = {}
        for turn, _, label in diarization.itertracks(yield_label=True):
            begin = max(0, int(turn.start * sample_rate))
            end = min(audio_length, int(turn.end * sample_rate))
            if begin < end:
                ranges_by_speaker.setdefault(label, []).append((begin, end))

        # Sweep the speakers once: a sample is overlap as soon as a second,
        # different speaker covers it. This yields the same samples as testing
        # every pair of tracks, without the quadratic loop.
        covered_by_someone = np.zeros(audio_length, dtype=bool)
        for sample_ranges in ranges_by_speaker.values():
            speaker_cover = np.zeros(audio_length, dtype=bool)
            for begin, end in sample_ranges:
                speaker_cover[begin:end] = True
            both_speaking_mask |= covered_by_someone & speaker_cover
            covered_by_someone |= speaker_cover

        t = np.sum(both_speaking_mask) / sample_rate
        xprint(f" Found {t:.1f}s of overlapped speech (manual) ")
        return both_speaking_mask

    def _configure_vad(self, vad_onset: float, vad_offset: float):
        """Push VAD thresholds into the pipeline's ``_vad`` component, if any.

        Best effort: when the pipeline has no ``_vad`` attribute, or the
        update raises, a message is logged and nothing else happens.
        """
        xprint("Applying more sensitive VAD parameters...")
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                if not hasattr(self.pipeline, '_vad'):
                    xprint("⚠ Could not access VAD component directly")
                    return

                self.pipeline._vad.instantiate({
                    "onset": vad_onset,
                    "offset": vad_offset,
                    "min_duration_on": 0.1,
                    "min_duration_off": 0.1,
                    "pad_onset": 0.1,
                    "pad_offset": 0.1,
                })
                xprint(f"✓ VAD parameters updated: onset={vad_onset}, offset={vad_offset}")
        except Exception as e:
            xprint(f"⚠ Could not modify VAD parameters: {e}")

    def _get_overlap_pipeline(self):
        """Build (once) and return an ``OverlappedSpeechDetection`` pipeline.

        • uses the `pyannote/segmentation-3.0` checkpoint
        • only `min_duration_on/off` are set here

        The result is cached on the instance: the pipeline object on success,
        ``False`` after a failed build so the build is not retried.

        Returns:
            The pipeline, or ``None`` when it could not be built.
        """
        cached = getattr(self, "_overlap_pipeline", None)
        if cached is not None:
            return None if cached is False else cached

        try:
            from pyannote.audio.pipelines import OverlappedSpeechDetection

            xprint("Building OverlappedSpeechDetection with segmentation-3.0…")

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                ods = OverlappedSpeechDetection(
                    segmentation="pyannote/segmentation-3.0"
                )
                ods.instantiate({
                    "min_duration_on": 0.06,
                    "min_duration_off": 0.10,
                })

                if torch.cuda.is_available():
                    ods.to(torch.device("cuda"))

            self._overlap_pipeline = ods
            xprint("✓ Overlap pipeline ready (segmentation-3.0)")
            return ods

        except Exception as e:
            xprint(f"⚠ Could not build overlap pipeline ({e}). "
                   "Falling back to manual pair-wise detection.")
            self._overlap_pipeline = False
            return None

    def _xprint_setup_instructions(self):
        """Log the steps needed to obtain access to the hosted model."""
        xprint("\nTo use Pyannote 3.1:")
        xprint("1. Get token: https://huggingface.co/settings/tokens")
        xprint("2. Accept terms: https://huggingface.co/pyannote/speaker-diarization-3.1")
        xprint("3. Run with: --token YOUR_TOKEN")

    def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """Load an audio file and down-mix it to a single channel.

        Returns:
            ``(waveform, sample_rate)`` where ``waveform`` has one row.
        """
        xprint(f"Loading audio: {audio_path}")

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            waveform, sample_rate = torchaudio.load(audio_path)

        # Multi-channel input is averaged into one channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        xprint(f"Audio: {waveform.shape[1]} samples at {sample_rate}Hz")
        return waveform, sample_rate

    def perform_optimized_diarization(self, audio_path: str) -> object:
        """Run the pipeline under progressively looser speaker constraints.

        The first strategy that yields at least two speakers wins. If the
        first strategy found exactly one speaker, aggressive re-clustering is
        attempted and, failing that, the single-speaker result is returned.
        Otherwise the pipeline is run once more without constraints and any
        error from that run propagates.
        """
        xprint("Running optimized Pyannote 3.1 diarization...")

        strategies = [
            {"min_speakers": 2, "max_speakers": 2},
            {"num_speakers": 2},
            {"min_speakers": 2, "max_speakers": 3},
            {"min_speakers": 1, "max_speakers": 2},
            {"min_speakers": 2, "max_speakers": 4},
            {}
        ]

        fallback_diarization = None

        for i, params in enumerate(strategies):
            try:
                xprint(f"Strategy {i+1}: {params}")

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    diarization = self.pipeline(audio_path, **params)

                speakers = list(diarization.labels())
                speaker_count = len(speakers)
                xprint(f" → Detected {speaker_count} speakers: {speakers}")

                if speaker_count >= 2:
                    xprint(f"✓ Success with strategy {i+1}! Using {speaker_count} speakers")
                    return diarization
                if speaker_count == 1 and i == 0:
                    # Only the first (strictest) strategy's single-speaker
                    # result is kept as the fallback.
                    fallback_diarization = diarization

            except Exception as e:
                xprint(f" Strategy {i+1} failed: {e}")
                continue

        if fallback_diarization is not None:
            xprint("Attempting aggressive clustering for single speaker...")
            try:
                aggressive_diarization = self._try_aggressive_clustering(audio_path)
                if aggressive_diarization and len(list(aggressive_diarization.labels())) >= 2:
                    return aggressive_diarization
            except Exception as e:
                xprint(f"Aggressive clustering failed: {e}")

            xprint("Using single speaker result")
            return fallback_diarization

        xprint("Last resort: running without constraints...")
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            return self.pipeline(audio_path)

    def _try_aggressive_clustering(self, audio_path: str) -> object:
        """Re-diarize with a very low clustering threshold.

        Returns:
            The new diarization, or ``None`` if the pipeline could not be
            built or run.
        """
        try:
            from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)

                # Reuse the models loaded in __init__, i.e. the same objects
                # the main pipeline was constructed from.
                # NOTE(review): the previous code passed
                # self.pipeline.segmentation / .embedding here; on the
                # pipeline those look like hyper-parameter/config attributes
                # rather than models — confirm against pyannote.audio.
                temp_pipeline = SpeakerDiarization(
                    segmentation=self._segmentation_model,
                    embedding=self._embedding_model,
                    clustering="AgglomerativeClustering"
                )

                # NOTE(review): whether "min_duration_on" is accepted for this
                # segmentation model is not visible here; a rejection is
                # caught below and reported.
                temp_pipeline.instantiate({
                    "clustering": {
                        "method": "centroid",
                        "min_cluster_size": 1,
                        "threshold": 0.1,
                    },
                    "segmentation": {
                        "min_duration_off": 0.0,
                        "min_duration_on": 0.1,
                    }
                })

                # Same device placement as the main pipeline in __init__.
                if torch.cuda.is_available():
                    temp_pipeline.to(torch.device("cuda"))

                return temp_pipeline(audio_path, min_speakers=2)

        except Exception as e:
            xprint(f"Aggressive clustering setup failed: {e}")
            return None

    def create_optimized_speaker_masks(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Convert a diarization into one float mask per speaker label.

        With no speakers two constant 0.5 masks are returned; with a single
        speaker the speech is split heuristically between two labels. With
        two or more speakers the overlap regions are also computed and stored
        on the instance for the smoothing stage.
        """
        xprint("Creating optimized speaker masks...")

        # Drop the overlap map of any previously processed file so the
        # early-return paths below cannot leave a stale one behind.
        self._both_speaking_regions = None

        speakers = list(diarization.labels())
        xprint(f"Processing speakers: {speakers}")

        if not speakers:
            xprint("⚠ No speakers detected, creating dummy masks")
            return self._create_dummy_masks(audio_length)

        if len(speakers) == 1:
            xprint("⚠ Only 1 speaker detected, creating temporal split")
            return self._create_optimized_temporal_split(diarization, audio_length, sample_rate)

        both_speaking_regions = self._extract_both_speaking_regions(diarization, audio_length, sample_rate)

        masks = {}
        for speaker in speakers:
            segments = []
            for segment in diarization.label_timeline(speaker):
                start_sample = max(0, int(segment.start * sample_rate))
                end_sample = min(audio_length, int(segment.end * sample_rate))
                if start_sample < end_sample:
                    segments.append((start_sample, end_sample))

            mask = self._create_mask_vectorized(segments, audio_length)
            masks[speaker] = mask
            if segments:
                speaking_time = np.sum(mask) / sample_rate
                xprint(f" {speaker}: {speaking_time:.1f}s speaking time")

        self._both_speaking_regions = both_speaking_regions
        return masks

    def _create_mask_vectorized(self, segments: List[Tuple[int, int]], audio_length: int) -> np.ndarray:
        """Return a float32 mask that is 1.0 inside every ``[start, end)`` range."""
        mask = np.zeros(audio_length, dtype=np.float32)
        for start, end in segments:
            mask[start:end] = 1.0
        return mask

    def _create_dummy_masks(self, audio_length: int) -> Dict[str, np.ndarray]:
        """Return two constant 0.5 masks, used when no speaker was found."""
        return {
            "SPEAKER_00": np.full(audio_length, 0.5, dtype=np.float32),
            "SPEAKER_01": np.full(audio_length, 0.5, dtype=np.float32)
        }

    def _create_optimized_temporal_split(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Distribute a single speaker's segments over two speaker labels.

        With at most one segment the audio is cut in half. Otherwise, if the
        longest pause between consecutive segments exceeds one second, the
        segments before it go to the first label and the rest to the second;
        failing that, segments alternate between the labels.
        """
        xprint("Creating optimized temporal split...")

        segments = sorted(
            (turn.start, turn.end)
            for turn, _, _ in diarization.itertracks(yield_label=True)
        )
        xprint(f"Found {len(segments)} speech segments")

        if len(segments) <= 1:
            return self._create_simple_split(audio_length)

        segment_array = np.array(segments)
        # Pause between the end of one segment and the start of the next.
        gaps = segment_array[1:, 0] - segment_array[:-1, 1]

        if len(gaps) > 0:
            longest_gap_idx = int(np.argmax(gaps))
            longest_gap_duration = gaps[longest_gap_idx]
            xprint(f"Longest gap: {longest_gap_duration:.1f}s after segment {longest_gap_idx+1}")

            if longest_gap_duration > 1.0:
                split_point = longest_gap_idx + 1
                xprint(f"Splitting at natural break: segments 1-{split_point} vs {split_point+1}-{len(segments)}")
                return self._create_split_masks(segments, split_point, audio_length, sample_rate)

        xprint("Using alternating assignment...")
        return self._create_alternating_masks(segments, audio_length, sample_rate)

    def _create_simple_split(self, audio_length: int) -> Dict[str, np.ndarray]:
        """Give the first half of the audio to one label, the rest to the other."""
        mid_point = audio_length // 2
        first = np.zeros(audio_length, dtype=np.float32)
        second = np.zeros(audio_length, dtype=np.float32)
        first[:mid_point] = 1.0
        second[mid_point:] = 1.0
        return {"SPEAKER_00": first, "SPEAKER_01": second}

    @staticmethod
    def _paint_segments(segments, label_for_index, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Fill two-speaker masks, choosing each segment's label by its index."""
        masks = {
            "SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
            "SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
        }
        for i, (start_time, end_time) in enumerate(segments):
            start_sample = max(0, int(start_time * sample_rate))
            end_sample = min(audio_length, int(end_time * sample_rate))
            if start_sample < end_sample:
                masks[label_for_index(i)][start_sample:end_sample] = 1.0
        return masks

    def _create_split_masks(self, segments: List[Tuple[float, float]], split_point: int,
                            audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Assign segments before ``split_point`` to one label, the rest to the other."""
        return self._paint_segments(
            segments,
            lambda i: "SPEAKER_00" if i < split_point else "SPEAKER_01",
            audio_length,
            sample_rate,
        )

    def _create_alternating_masks(self, segments: List[Tuple[float, float]],
                                  audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
        """Assign segments to the two labels in turn (even index → first label)."""
        return self._paint_segments(
            segments,
            lambda i: f"SPEAKER_0{i % 2}",
            audio_length,
            sample_rate,
        )

    def apply_optimized_background_preservation(self, masks: Dict[str, np.ndarray],
                                                audio_length: int,
                                                sample_rate: int = 16000) -> Dict[str, np.ndarray]:
        """Reduce the masks to two speakers and resolve unclaimed samples.

        Samples claimed by one or both of the two kept speakers get weight
        1.0 in the corresponding mask(s). Samples claimed by neither are
        given, at weight 0.5, to the speaker chosen by
        :meth:`_compute_ambiguous_assignments_vectorized`. The result is then
        passed through the minimum-duration smoothing.

        Args:
            masks: Per-speaker masks; values above 0.5 count as active.
            audio_length: Number of samples in each mask.
            sample_rate: Sample rate of the masks, used for the reported
                durations and for the smoothing window. Defaults to 16000,
                the value that used to be hard-coded.
        """
        xprint("Applying optimized voice separation logic...")

        speaker_keys = self._get_top_speakers(masks, audio_length)
        first_key, second_key = speaker_keys[0], speaker_keys[1]

        final_masks = {
            speaker: np.zeros(audio_length, dtype=np.float32)
            for speaker in speaker_keys
        }

        active_0 = masks.get(first_key, np.zeros(audio_length)) > 0.5
        active_1 = masks.get(second_key, np.zeros(audio_length)) > 0.5

        both_active = active_0 & active_1
        only_0 = active_0 & ~active_1
        only_1 = ~active_0 & active_1
        neither = ~active_0 & ~active_1

        # Masks start at zero, so only the 1.0 entries need to be written.
        final_masks[first_key][active_0] = 1.0
        final_masks[second_key][active_1] = 1.0

        if np.any(neither):
            ambiguous_assignments = self._compute_ambiguous_assignments_vectorized(
                masks, speaker_keys, neither, audio_length
            )
            final_masks[first_key][neither] = (ambiguous_assignments == 0).astype(np.float32) * 0.5
            final_masks[second_key][neither] = (ambiguous_assignments == 1).astype(np.float32) * 0.5

        xprint(f" Both speaking clearly: {np.sum(both_active)/sample_rate:.1f}s")
        xprint(f" {first_key} only: {np.sum(only_0)/sample_rate:.1f}s")
        xprint(f" {second_key} only: {np.sum(only_1)/sample_rate:.1f}s")
        xprint(f" Ambiguous (assigned): {np.sum(neither)/sample_rate:.1f}s")

        return self._apply_minimum_duration_smoothing(final_masks, sample_rate)

    def _get_top_speakers(self, masks: Dict[str, np.ndarray], audio_length: int) -> List[str]:
        """Return two speaker labels: the two with the largest mask sums.

        A lone speaker is paired with the placeholder ``"SPEAKER_SILENT"``.
        """
        speaker_keys = list(masks.keys())

        if len(speaker_keys) > 2:
            speaking_times = {k: np.sum(v) for k, v in masks.items()}
            speaker_keys = sorted(speaking_times, key=speaking_times.get, reverse=True)[:2]
            xprint(f"Keeping top 2 speakers: {speaker_keys}")
        elif len(speaker_keys) == 1:
            speaker_keys.append("SPEAKER_SILENT")

        return speaker_keys

    def _compute_ambiguous_assignments_vectorized(self, masks: Dict[str, np.ndarray],
                                                  speaker_keys: List[str],
                                                  ambiguous_mask: np.ndarray,
                                                  audio_length: int) -> np.ndarray:
        """Choose a speaker (0 or 1) for every sample flagged in ``ambiguous_mask``.

        Each speaker's active runs are extracted from its mask, the distance
        from every ambiguous sample to the nearest run is computed, and the
        decision is delegated to :meth:`_assign_based_on_distance`.

        Returns:
            One entry per ambiguous sample, in ascending sample order; an
            empty array when nothing is ambiguous.
        """
        ambiguous_indices = np.where(ambiguous_mask)[0]
        if len(ambiguous_indices) == 0:
            return np.array([])

        distances = {}
        for speaker in speaker_keys:
            if speaker in masks and speaker != "SPEAKER_SILENT":
                active = masks[speaker] > 0.5
                # Padding with False on both sides makes every run produce
                # exactly one rising (+1) and one falling (-1) edge.
                edges = np.diff(np.concatenate(([False], active, [False])).astype(int))
                runs = np.column_stack([np.where(edges == 1)[0], np.where(edges == -1)[0]])
            else:
                runs = np.array([]).reshape(0, 2)

            if len(runs) == 0:
                distances[speaker] = np.full(len(ambiguous_indices), np.inf)
            else:
                distances[speaker] = self._compute_distances_to_segments(ambiguous_indices, runs)

        return self._assign_based_on_distance(
            distances, speaker_keys, ambiguous_indices, audio_length
        )

    @staticmethod
    def _enforce_min_duration(preference_signal: np.ndarray, min_samples: int) -> Tuple[np.ndarray, int]:
        """Apply the minimum-duration state machine to a preference signal.

        Per sample the rule is: once the state switches to a new non-negative
        preference it is held for ``min_samples`` samples; while held, the
        incoming preference is ignored (and counted as a correction if it is
        non-negative and different); after the hold, negative preferences
        keep the current state and a different non-negative preference
        triggers the next switch.

        The signal is processed run by run instead of sample by sample, which
        yields the same output without a Python-level loop over every sample.

        Returns:
            ``(assignment, corrections)`` — the per-sample state (``-1``
            before the first switch) and the number of overridden samples.
        """
        total = len(preference_signal)
        assignment = np.full(total, -1, dtype=int)
        corrections = 0
        if total == 0:
            return assignment, corrections

        change_points = np.flatnonzero(np.diff(preference_signal)) + 1
        run_starts = np.concatenate(([0], change_points)).tolist()
        run_ends = np.concatenate((change_points, [total])).tolist()

        state = -1
        hold = 0
        for run_start, run_end in zip(run_starts, run_ends):
            preference = int(preference_signal[run_start])
            position = run_start
            while position < run_end:
                if hold > 0:
                    # The hold timer is running: keep the state for as much
                    # of this run as the timer covers.
                    held = min(hold, run_end - position)
                    assignment[position:position + held] = state
                    hold -= held
                    if preference >= 0 and preference != state:
                        corrections += held
                    position += held
                elif preference >= 0 and preference != state:
                    # Timer expired and a new speaker is preferred: switch.
                    # The switching sample itself counts towards the minimum.
                    state = preference
                    hold = min_samples - 1
                    assignment[position] = state
                    position += 1
                else:
                    # Nothing can change until the preference changes.
                    assignment[position:run_end] = state
                    position = run_end

        return assignment, corrections

    def _apply_minimum_duration_smoothing(self, masks: Dict[str, np.ndarray],
                                          sample_rate: int, min_duration_ms: int = 600) -> Dict[str, np.ndarray]:
        """Rebuild two-speaker masks so every turn lasts at least a minimum time.

        A per-sample preference (speaker 0, speaker 1, both, or none) is
        derived from the masks and the stored overlap regions, passed through
        :meth:`_enforce_min_duration`, and converted back into masks. Masks
        for anything other than exactly two speakers are returned unchanged.
        """
        xprint(f"Applying STRICT minimum duration smoothing ({min_duration_ms}ms)...")

        min_samples = int(min_duration_ms * sample_rate / 1000)
        speaker_keys = list(masks.keys())
        if len(speaker_keys) != 2:
            return masks

        mask0 = masks[speaker_keys[0]]
        mask1 = masks[speaker_keys[1]]

        # Overlap regions recorded by create_optimized_speaker_masks(); an
        # absent entry, or one of a different length, is treated as
        # "no overlap" rather than being applied to the wrong audio.
        both_speaking_original = getattr(self, '_both_speaking_regions', None)
        if both_speaking_original is None or len(both_speaking_original) != len(mask0):
            both_speaking_original = np.zeros(len(mask0), dtype=bool)

        ambiguous_original = (mask0 < 0.3) & (mask1 < 0.3) & ~both_speaking_original
        remaining_mask = ~both_speaking_original & ~ambiguous_original
        speaker0_dominant = (mask0 > mask1) & remaining_mask
        speaker1_dominant = (mask1 > mask0) & remaining_mask

        # -1 = no preference, 0 / 1 = that speaker, 2 = both speaking.
        preference_signal = np.full(len(mask0), -1, dtype=int)
        preference_signal[speaker0_dominant] = 0
        preference_signal[speaker1_dominant] = 1
        preference_signal[both_speaking_original] = 2

        smoothed_assignment, corrections = self._enforce_min_duration(preference_signal, min_samples)

        both_regions = smoothed_assignment == 2
        # Samples that never received a state keep their original weight,
        # but only where they were ambiguous to begin with.
        untouched_ambiguous = ambiguous_original & (smoothed_assignment == -1)

        smoothed_masks = {}
        for i, speaker in enumerate(speaker_keys):
            new_mask = np.zeros_like(mask0)
            new_mask[smoothed_assignment == i] = 1.0
            new_mask[both_regions] = 1.0
            new_mask[untouched_ambiguous] = masks[speaker][untouched_ambiguous]
            smoothed_masks[speaker] = new_mask

        both_speaking_time = np.sum(both_regions) / sample_rate
        speaker0_time = np.sum(smoothed_assignment == 0) / sample_rate
        speaker1_time = np.sum(smoothed_assignment == 1) / sample_rate
        ambiguous_time = np.sum(smoothed_assignment == -1) / sample_rate

        xprint(f" Both speaking clearly: {both_speaking_time:.1f}s")
        xprint(f" {speaker_keys[0]} only: {speaker0_time:.1f}s")
        xprint(f" {speaker_keys[1]} only: {speaker1_time:.1f}s")
        xprint(f" Ambiguous (assigned): {ambiguous_time:.1f}s")
        xprint(f" Enforced minimum duration on {corrections} samples ({corrections/sample_rate:.2f}s)")

        return smoothed_masks

    def _compute_distances_to_segments(self, indices: np.ndarray, segments: np.ndarray) -> np.ndarray:
        """Return, per index, the distance in samples to the nearest segment.

        Args:
            indices: Sample positions.
            segments: Rows of ``[start, end)`` sample ranges (end exclusive).

        Returns:
            Float array, one value per index: 0 inside a segment, otherwise
            the number of samples to the closest covered sample; ``inf``
            everywhere when there are no segments.
        """
        indices = np.asarray(indices)
        if len(segments) == 0:
            return np.full(len(indices), np.inf)

        segments = np.asarray(segments)
        order = np.argsort(segments[:, 0], kind="stable")
        starts = segments[order, 0]
        # Last covered sample of each segment. The running maximum makes the
        # lookup correct even if segments overlap or nest.
        furthest_covered = np.maximum.accumulate(segments[order, 1] - 1)

        # Number of segments that start at or before each index.
        started = np.searchsorted(starts, indices, side="right")

        distances = np.full(len(indices), np.inf)

        has_earlier = started > 0
        distances[has_earlier] = np.maximum(
            0, indices[has_earlier] - furthest_covered[started[has_earlier] - 1]
        )

        has_later = started < len(starts)
        distances[has_later] = np.minimum(
            distances[has_later], starts[started[has_later]] - indices[has_later]
        )

        return distances

    def _assign_based_on_distance(self, distances: Dict[str, np.ndarray],
                                  speaker_keys: List[str],
                                  ambiguous_indices: np.ndarray,
                                  audio_length: int) -> np.ndarray:
        """Pick the nearer speaker per sample, with a late-audio bias.

        A sample goes to speaker 1 only when speaker 1 is strictly nearer
        (ties go to speaker 0). Every ambiguous sample located after 60% of
        the audio is then given to speaker 1 regardless of distance.
        """
        speaker_0_distances = distances[speaker_keys[0]]
        speaker_1_distances = distances[speaker_keys[1]]

        assignments = (speaker_1_distances < speaker_0_distances).astype(int)

        late_threshold = int(audio_length * 0.6)
        assignments[ambiguous_indices > late_threshold] = 1

        return assignments

    def _save_outputs_optimized(self, waveform: torch.Tensor, masks: Dict[str, np.ndarray],
                                sample_rate: int, audio_path: str, output1, output2) -> Dict[str, str]:
        """Write ``waveform`` multiplied by each mask, using two threads.

        Masks are paired with ``output1`` and ``output2`` in dictionary
        order. ``audio_path`` is accepted for interface compatibility and is
        not used.

        Returns:
            Mapping of speaker label to the destination that was written.
        """
        def save_speaker_audio(speaker_mask_pair, output):
            speaker, mask = speaker_mask_pair
            # Add a leading channel axis so the mask lines up with the
            # single-channel waveform.
            mask_tensor = torch.from_numpy(mask).unsqueeze(0)
            masked_audio = waveform * mask_tensor

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                torchaudio.save(output, masked_audio, sample_rate)

            xprint(f"✓ Saved {speaker}: {output}")
            return speaker, output

        with ThreadPoolExecutor(max_workers=2) as executor:
            results = list(executor.map(save_speaker_audio, masks.items(), [output1, output2]))

        return dict(results)

    def print_summary(self, audio_path: str):
        """Diarize ``audio_path`` and log every speaker turn with its times."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            diarization = self.perform_optimized_diarization(audio_path)

        xprint("\n=== Diarization Summary ===")
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
|
|
|
|
|
def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None):
    """Separate ``audio`` into two per-speaker files with default settings.

    Args:
        audio: Path of the audio to diarize.
        output1: Destination for the first speaker's audio.
        output2: Destination for the second speaker's audio.
        verbose: Sets the module-wide ``verbose_output`` flag; the flag stays
            at this value after the call returns.
        audio_original: Optional path of the audio the masks are applied to
            instead of ``audio``.

    Returns:
        Mapping of speaker label to the path written for that speaker.
    """
    global verbose_output
    verbose_output = verbose

    separator = OptimizedPyannote31SpeakerSeparator(
        None,
        None,
        vad_onset=0.2,
        vad_offset=0.8
    )

    start_time = time.time()
    outputs = separator.separate_audio(audio, output1, output2, audio_original)
    elapsed_time = time.time() - start_time

    xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
    for speaker, path in outputs.items():
        xprint(f"{speaker}: {path}")

    # Hand the paths back so callers do not have to re-derive them.
    return outputs
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses the arguments, optionally prints a diarization summary, and
    writes ``<stem>_speaker0.wav`` / ``<stem>_speaker1.wav`` into the output
    directory, creating that directory when it does not exist.

    Returns:
        ``0`` on success, ``1`` when the input file is missing or the
        separation fails.
    """
    parser = argparse.ArgumentParser(description="Optimized Pyannote 3.1 Speaker Separator")
    parser.add_argument("--audio", required=True, help="Input audio file")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--token", help="Hugging Face token")
    parser.add_argument("--local-model", help="Path to local 3.1 model")
    parser.add_argument("--summary", action="store_true", help="xprint summary")
    parser.add_argument("--vad-onset", type=float, default=0.2,
                        help="VAD onset threshold (lower = more sensitive to speech start, default: 0.2)")
    parser.add_argument("--vad-offset", type=float, default=0.8,
                        help="VAD offset threshold (higher = keeps speech longer, default: 0.8)")

    args = parser.parse_args()

    xprint("=== Optimized Pyannote 3.1 Speaker Separator ===")
    xprint("Performance optimizations: vectorized operations, memory management, parallel processing")
    xprint(f"Audio: {args.audio}")
    xprint(f"Output: {args.output}")
    xprint(f"VAD onset: {args.vad_onset}")
    xprint(f"VAD offset: {args.vad_offset}")
    # An explicit empty string gives a blank separator line.
    xprint("")

    if not os.path.exists(args.audio):
        xprint(f"ERROR: Audio file not found: {args.audio}")
        # A missing input is a failure; report it through the exit status.
        return 1

    try:
        separator = OptimizedPyannote31SpeakerSeparator(
            args.token,
            args.local_model,
            vad_onset=args.vad_onset,
            vad_offset=args.vad_offset
        )

        if args.summary:
            separator.print_summary(args.audio)

        start_time = time.time()

        # The output files are written straight into this directory, so it
        # has to exist before saving starts.
        if args.output:
            os.makedirs(args.output, exist_ok=True)

        audio_name = Path(args.audio).stem
        output_path = os.path.join(args.output, f"{audio_name}_speaker0.wav")
        output_path1 = os.path.join(args.output, f"{audio_name}_speaker1.wav")

        outputs = separator.separate_audio(args.audio, output_path, output_path1)

        elapsed_time = time.time() - start_time
        xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
        for speaker, path in outputs.items():
            xprint(f"{speaker}: {path}")

    except Exception as e:
        xprint(f"ERROR: {e}")
        return 1

    return 0
|
|
|
|
|
|
|
|
# Script entry point. sys.exit is used instead of the bare exit() helper,
# which the `site` module injects for interactive sessions and which is not
# guaranteed to exist in every interpreter configuration.
if __name__ == "__main__":
    sys.exit(main())