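"""Speaker diarization utilities built on pyannote.audio.

Provides a SpeakerDiarizer class that detects who speaks when in an audio
file, assigns speaker labels to transcript segments by timing overlap, and
extracts short reference clips for each detected speaker.
"""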
from pyannote.audio import Pipeline
import torch
import os
import time

class SpeakerDiarizer:
    def __init__(self, hf_token, device=None):
        """Initialize speaker diarization with HuggingFace token"""
        self.diarization_pipeline = None
        try:
            print("Loading diarization pipeline...")
            # Check available devices
            if device is None:
                device = "cuda:0" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")
            
            # Load the pretrained pipeline (the model is gated on the
            # Hugging Face Hub and requires accepting its terms of use)
            self.diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token
            )
            self.diarization_pipeline.to(torch.device(device))
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Error loading diarization model: {e}")
    
    def diarize(self, audio_path, min_speakers=1, max_speakers=None, device=None):
        """Identify speakers in audio file"""
        if not self.diarization_pipeline:
            print("Diarization pipeline not available")
            return []
        
        try:
            print("Starting speaker diarization (this may take several minutes for longer files)...")
            start_time = time.time()
            
            # Set parameters for diarization
            params = {}
            if min_speakers is not None:
                params["min_speakers"] = min_speakers
            if max_speakers is not None:
                params["max_speakers"] = max_speakers
            
            # Set device if specified (cuda:0, cpu, etc.)
            if device:
                print(f"Using device: {device}")
                self.diarization_pipeline.to(torch.device(device))
            
            # Add progress updates
            print("Running diarization model...")
            print("This process may take several minutes with no visible progress...")
            print("Consider using a smaller audio segment for testing")
            
            # Use the diarization pipeline
            diarization = self.diarization_pipeline(audio_path, **params)
            
            print("Processing diarization results...")
            speakers = []
            for turn, _, speaker in diarization.itertracks(yield_label=True):
                speakers.append({'start': turn.start, 'end': turn.end, 'speaker': speaker})
            
            duration = time.time() - start_time
            print(f"Diarization completed in {duration:.1f} seconds")
            print(f"Detected {len(set(s['speaker'] for s in speakers))} unique speakers")
            
            return speakers
        except Exception as e:
            print(f"Error during diarization: {e}")
            return []
    
    def assign_speakers_to_segments(self, segments, speakers):
        """
        Assign speaker labels to transcript segments based on timing overlap
        
        Args:
            segments: List of transcript segments with start/end times
            speakers: List of speaker segments from diarization
            
        Returns:
            Updated segments with speaker information
        """
        # If no speakers found, assign everything to SPEAKER_0
        if not speakers:
            for segment in segments:
                segment["speaker"] = "SPEAKER_0"
            return segments
        
        # For single speaker, optimize by assigning all to same speaker
        if len(set(s["speaker"] for s in speakers)) == 1:
            speaker_id = speakers[0]["speaker"]
            for segment in segments:
                segment["speaker"] = speaker_id
            return segments
        
        # Process each segment
        for segment in segments:
            segment_start = segment.get("start", 0)
            segment_end = segment.get("end", 0)
            segment_duration = segment_end - segment_start
            
            # Find overlapping speakers
            speaker_overlaps = []
            
            for speaker_turn in speakers:
                # Fast check for any overlap
                if not (speaker_turn["end"] <= segment_start or speaker_turn["start"] >= segment_end):
                    # Calculate overlap duration
                    overlap_start = max(speaker_turn["start"], segment_start)
                    overlap_end = min(speaker_turn["end"], segment_end)
                    overlap_duration = overlap_end - overlap_start
                    
                    # Calculate overlap percentage relative to segment duration
                    overlap_percentage = overlap_duration / segment_duration if segment_duration > 0 else 0
                    
                    speaker_overlaps.append((speaker_turn["speaker"], overlap_duration, overlap_percentage))
            
            # Assign speaker with the most overlap
            if speaker_overlaps:
                # Sort by overlap duration (descending)
                speaker_overlaps.sort(key=lambda x: x[1], reverse=True)
                segment["speaker"] = speaker_overlaps[0][0]
                
                # Add confidence score if desired
                # segment["speaker_confidence"] = speaker_overlaps[0][2]
            else:
                # Find nearest speaker if no overlap
                segment_mid = (segment_start + segment_end) / 2
                
                closest_speaker = min(
                    speakers,
                    key=lambda s: min(
                        abs(s["start"] - segment_mid),
                        abs(s["end"] - segment_mid)
                    )
                )
                segment["speaker"] = closest_speaker["speaker"]
                
                # You can log this if logging is set up
                # print(f"No speaker overlap found for segment at {segment_start:.2f}s, using nearest speaker")
        
        return segments
    
    def extract_speaker_references(self, audio_path, speakers, output_dir="reference_audio", min_duration=3.0, max_duration=10.0):
        """
        Extract reference audio clips for each unique speaker.
        
        Args:
            audio_path: Path to the original audio file
            speakers: List of speaker segments from diarization
            output_dir: Directory to save reference audio clips
            min_duration: Minimum duration for a reference clip (seconds)
            max_duration: Maximum duration for a reference clip (seconds)
            
        Returns:
            Dictionary mapping speaker IDs to reference audio file paths
        """
        # pydub is imported here so the rest of the class works without it
        from pydub import AudioSegment
        
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)
        
        # Load the original audio file
        try:
            full_audio = AudioSegment.from_file(audio_path)
        except Exception as e:
            print(f"Error loading audio file: {e}")
            return {}
        
        # Get unique speaker IDs
        unique_speakers = set(segment["speaker"] for segment in speakers)
        reference_files = {}
        
        print(f"Extracting reference audio for {len(unique_speakers)} speakers...")
        
        for speaker in unique_speakers:
            # Find all segments for this speaker
            speaker_segments = [s for s in speakers if s["speaker"] == speaker]
            
            # Sort segments by duration (descending)
            speaker_segments.sort(key=lambda s: s["end"] - s["start"], reverse=True)
            
            # Find a segment with suitable duration
            selected_segment = None
            for segment in speaker_segments:
                duration = segment["end"] - segment["start"]
                if duration >= min_duration:
                    # If longer than max_duration, trim it
                    if duration > max_duration:
                        mid_point = (segment["start"] + segment["end"]) / 2
                        half_max = max_duration / 2
                        segment = {
                            "start": mid_point - half_max,
                            "end": mid_point + half_max,
                            "speaker": speaker
                        }
                    selected_segment = segment
                    break
            
            # If no segment is long enough, take the longest one
            if selected_segment is None and speaker_segments:
                selected_segment = speaker_segments[0]
                
            # Extract the audio segment
            if selected_segment:
                start_ms = int(selected_segment["start"] * 1000)
                end_ms = int(selected_segment["end"] * 1000)
                
                # Extract audio segment
                speaker_audio = full_audio[start_ms:end_ms]
                
                # Save to file
                speaker_id = speaker.replace("SPEAKER_", "")
                output_path = os.path.join(output_dir, f"speaker_{speaker_id}_reference.wav")
                speaker_audio.export(output_path, format="wav")
                
                reference_files[speaker] = output_path
                
                print(f"  Extracted {selected_segment['end'] - selected_segment['start']:.2f}s reference audio for {speaker}")
            else:
                print(f"  No suitable audio segment found for {speaker}")
        
        return reference_files
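

# Example usage: a minimal sketch, not part of the class above. The audio
# path, environment variable name, and transcript segments are placeholder
# assumptions; real segments would typically come from an ASR tool such as
# Whisper (dicts with "start"/"end" times in seconds).
if __name__ == "__main__":
    hf_token = os.environ.get("HF_TOKEN")  # assumes the token is exported in the environment
    diarizer = SpeakerDiarizer(hf_token)

    # Identify speakers in a hypothetical recording
    speakers = diarizer.diarize("meeting.wav", min_speakers=1, max_speakers=4)

    # Placeholder transcript segments to label
    segments = [
        {"start": 0.0, "end": 4.2, "text": "Hello everyone."},
        {"start": 4.2, "end": 9.8, "text": "Thanks for joining."},
    ]
    for seg in diarizer.assign_speakers_to_segments(segments, speakers):
        print(f"[{seg['speaker']}] {seg['text']}")

    # Save one short reference clip per detected speaker
    print(diarizer.extract_speaker_references("meeting.wav", speakers))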