Spaces:
Runtime error
Runtime error
| from pyannote.audio import Pipeline | |
| import torch | |
| import os | |
| import time | |
class SpeakerDiarizer:
    """Speaker diarization helper built around a pyannote.audio pipeline.

    The instance loads ``pyannote/speaker-diarization-3.1`` once and then
    offers three operations: running diarization on an audio file, mapping
    the resulting speaker turns onto transcript segments, and cutting one
    reference clip per speaker out of the original audio.

    Throughout the class a "speaker turn" is a dict of the form
    ``{"start": <seconds>, "end": <seconds>, "speaker": <label>}``.
    """

    def __init__(self, hf_token, device=None):
        """Load the pretrained diarization pipeline.

        Loading problems are printed instead of raised; in that case
        ``self.diarization_pipeline`` stays ``None`` and ``diarize`` returns
        an empty list.

        Args:
            hf_token: HuggingFace access token, forwarded to
                ``Pipeline.from_pretrained`` as ``use_auth_token``.
            device: Torch device string such as ``"cuda:0"`` or ``"cpu"``.
                When ``None``, ``"cuda:0"`` is used if CUDA is available,
                otherwise ``"cpu"``.
        """
        self.diarization_pipeline = None
        try:
            print("Loading diarization pipeline...")
            if device is None:
                device = "cuda:0" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")

            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token,
            )
            if pipeline is None:
                # NOTE(review): pyannote.audio 3.x appears to return None
                # (rather than raise) when the model cannot be fetched, e.g.
                # the gated-model terms were not accepted for this token.
                # Report that directly instead of failing on `.to()` below.
                print(
                    "Error loading diarization model: pipeline could not be "
                    "loaded (check the HuggingFace token and model access)"
                )
                return

            # Keep the pipeline even if moving it to the device fails, so it
            # remains usable on whatever device it was loaded on.
            self.diarization_pipeline = pipeline
            self.diarization_pipeline.to(torch.device(device))
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Error loading diarization model: {e}")

    def diarize(self, audio_path, min_speakers=1, max_speakers=None, device=None):
        """Identify speakers in an audio file.

        Args:
            audio_path: Path of the audio file handed to the pipeline.
            min_speakers: Lower bound on the number of speakers; omitted from
                the pipeline call when ``None``.
            max_speakers: Upper bound on the number of speakers; omitted from
                the pipeline call when ``None``.
            device: Optional torch device string. When given, the pipeline is
                moved to that device before running (and stays there).

        Returns:
            List of speaker turns in the order yielded by the pipeline, or an
            empty list when the pipeline is unavailable or diarization fails
            (the error is printed, not raised).
        """
        if self.diarization_pipeline is None:
            print("Diarization pipeline not available")
            return []

        try:
            print("Starting speaker diarization (this may take several minutes for longer files)...")
            start_time = time.time()

            # Only forward the speaker-count bounds that were actually set.
            params = {}
            if min_speakers is not None:
                params["min_speakers"] = min_speakers
            if max_speakers is not None:
                params["max_speakers"] = max_speakers

            if device:
                print(f"Using device: {device}")
                self.diarization_pipeline.to(torch.device(device))

            print("Running diarization model...")
            print("This process may take several minutes with no visible progress...")
            print("Consider using a smaller audio segment for testing")

            diarization = self.diarization_pipeline(audio_path, **params)

            print("Processing diarization results...")
            speakers = [
                {"start": turn.start, "end": turn.end, "speaker": speaker}
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]

            duration = time.time() - start_time
            print(f"Diarization completed in {duration:.1f} seconds")
            print(f"Detected {len(set(s['speaker'] for s in speakers))} unique speakers")
            return speakers
        except Exception as e:
            print(f"Error during diarization: {e}")
            return []

    def assign_speakers_to_segments(self, segments, speakers):
        """Assign a speaker label to each transcript segment by time overlap.

        Each segment receives the speaker whose turns overlap it for the
        longest *total* time. Overlap is summed per speaker, so a speaker
        whose speech is split into several turns inside one segment is not
        out-ranked by another speaker with a single, slightly longer turn.
        Ties go to the speaker encountered first in ``speakers``. A segment
        that overlaps no turn gets the speaker of the turn whose start or end
        lies closest to the segment's midpoint.

        Args:
            segments: List of transcript segment dicts with ``"start"`` and
                ``"end"`` times (missing keys are treated as 0). The dicts
                are modified in place.
            speakers: List of speaker turns from diarization.

        Returns:
            The same ``segments`` list, each entry now carrying a
            ``"speaker"`` key. With no speaker turns at all, every segment
            is labelled ``"SPEAKER_0"``.
        """
        if not speakers:
            for segment in segments:
                segment["speaker"] = "SPEAKER_0"
            return segments

        # With a single speaker there is nothing to decide.
        if len({turn["speaker"] for turn in speakers}) == 1:
            only_speaker = speakers[0]["speaker"]
            for segment in segments:
                segment["speaker"] = only_speaker
            return segments

        for segment in segments:
            seg_start = segment.get("start", 0)
            seg_end = segment.get("end", 0)

            # Total overlap per speaker. Dict insertion order records which
            # speaker was seen first, which is what breaks ties in max().
            overlap_by_speaker = {}
            for turn in speakers:
                if turn["end"] <= seg_start or turn["start"] >= seg_end:
                    continue
                overlap = min(turn["end"], seg_end) - max(turn["start"], seg_start)
                label = turn["speaker"]
                overlap_by_speaker[label] = overlap_by_speaker.get(label, 0) + overlap

            if overlap_by_speaker:
                segment["speaker"] = max(overlap_by_speaker, key=overlap_by_speaker.get)
                continue

            # No overlap: fall back to the turn nearest the segment midpoint.
            seg_mid = (seg_start + seg_end) / 2
            nearest_turn = min(
                speakers,
                key=lambda turn: min(
                    abs(turn["start"] - seg_mid),
                    abs(turn["end"] - seg_mid),
                ),
            )
            segment["speaker"] = nearest_turn["speaker"]

        return segments

    def extract_speaker_references(self, audio_path, speakers, output_dir="reference_audio", min_duration=3.0, max_duration=10.0):
        """Extract one reference audio clip for each unique speaker.

        For every speaker the longest turn is used. A turn longer than
        ``max_duration`` is trimmed symmetrically around its midpoint. A
        speaker whose longest turn is shorter than ``min_duration`` still
        gets that (short) turn as a best-effort reference. Speakers are
        processed in the order they first appear in ``speakers``.

        Args:
            audio_path: Path to the original audio file.
            speakers: List of speaker turns from diarization.
            output_dir: Directory for the reference clips; created if
                missing.
            min_duration: Preferred minimum clip length in seconds. It does
                not exclude speakers whose turns are all shorter.
            max_duration: Maximum clip length in seconds.

        Returns:
            Dict mapping each speaker label to the path of its WAV file, or
            an empty dict when the audio file cannot be loaded.
        """
        # Imported lazily so the rest of the class works without pydub.
        from pydub import AudioSegment

        os.makedirs(output_dir, exist_ok=True)

        try:
            full_audio = AudioSegment.from_file(audio_path)
        except Exception as e:
            print(f"Error loading audio file: {e}")
            return {}

        # Group turns per speaker in one pass, keeping first-seen order so
        # processing order and log output are deterministic.
        turns_by_speaker = {}
        for turn in speakers:
            turns_by_speaker.setdefault(turn["speaker"], []).append(turn)

        reference_files = {}
        print(f"Extracting reference audio for {len(turns_by_speaker)} speakers...")

        for speaker, turns in turns_by_speaker.items():
            # The longest turn is the best candidate whether or not it
            # reaches min_duration; max() keeps the first one among ties.
            longest = max(turns, key=lambda turn: turn["end"] - turn["start"])
            clip_start, clip_end = longest["start"], longest["end"]

            if clip_end - clip_start > max_duration:
                # Keep the central max_duration seconds of the turn.
                mid_point = (clip_start + clip_end) / 2
                half_max = max_duration / 2
                clip_start, clip_end = mid_point - half_max, mid_point + half_max

            # pydub slices by milliseconds.
            speaker_audio = full_audio[int(clip_start * 1000):int(clip_end * 1000)]

            speaker_id = speaker.replace("SPEAKER_", "")
            output_path = os.path.join(output_dir, f"speaker_{speaker_id}_reference.wav")
            speaker_audio.export(output_path, format="wav")
            reference_files[speaker] = output_path
            print(f" Extracted {clip_end - clip_start:.2f}s reference audio for {speaker}")

        return reference_files