File size: 5,677 Bytes
4051511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import torch
from typing import Dict, List, Optional
from pyannote.audio import Pipeline
import logging

# Import-time side effect: request INFO-level logging on the root logger.
# Per the stdlib, basicConfig() does nothing if the root logger already has
# handlers, so an application that configured logging earlier keeps its setup.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by SpeakerDiarizer for status and error messages.
logger = logging.getLogger(__name__)


class SpeakerDiarizer:
    """Speaker diarization wrapper around a pyannote.audio ``Pipeline``.

    The pipeline is loaded lazily: ``diarize()`` calls ``load_pipeline()`` the
    first time it is needed. ``align_with_transcription()`` is pure Python and
    works without a loaded pipeline.
    """

    # Hugging Face model identifier handed to Pipeline.from_pretrained().
    MODEL_ID = "pyannote/speaker-diarization-3.1"

    # Duration (seconds) assumed for a transcription chunk whose end
    # timestamp is missing.
    OPEN_CHUNK_SECONDS = 1.0

    def __init__(self, hf_token: Optional[str] = None):
        """
        Initialize speaker diarization pipeline

        Args:
            hf_token: Hugging Face access token (required for pyannote models).
                Falls back to the ``HF_TOKEN`` environment variable when
                omitted or empty.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        if not self.hf_token:
            logger.warning("No HF_TOKEN provided. Diarization may fail.")

        # Populated by load_pipeline(); None means "not loaded yet".
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_pipeline(self, progress_callback=None):
        """
        Load the diarization pipeline and move it to the GPU when available

        Args:
            progress_callback: Optional callable receiving status strings

        Raises:
            Exception: If the model cannot be loaded (the original error is
                attached as ``__cause__``). ``self.pipeline`` is left
                untouched in that case so a later call can retry.
        """
        if progress_callback:
            progress_callback("Loading speaker diarization model...")

        try:
            pipeline = Pipeline.from_pretrained(
                self.MODEL_ID,
                use_auth_token=self.hf_token
            )

            # NOTE: some pyannote.audio releases print a hint and return None
            # (instead of raising) when the gated model cannot be downloaded.
            # Without this check a CPU run would report success and only fail
            # later with "'NoneType' object is not callable".
            if pipeline is None:
                raise RuntimeError(
                    "Pipeline.from_pretrained() returned no pipeline"
                )

            # Move to GPU if available
            if self.device == "cuda":
                pipeline.to(torch.device("cuda"))

        except Exception as e:
            logger.error("Failed to load diarization pipeline: %s", e)
            raise Exception(
                f"Failed to load diarization model. "
                f"Make sure you have accepted the terms at: "
                f"https://huggingface.co/{self.MODEL_ID} "
                f"and provided a valid HF_TOKEN. Error: {str(e)}"
            ) from e

        # Publish only a fully initialised pipeline.
        self.pipeline = pipeline

        if progress_callback:
            progress_callback("Diarization model loaded successfully")

        logger.info("Diarization pipeline loaded on %s", self.device)

    def diarize(self, audio_path: str, progress_callback=None) -> Dict[str, List[Dict]]:
        """
        Perform speaker diarization on audio file

        Args:
            audio_path: Path to audio file
            progress_callback: Optional callback for progress updates

        Returns:
            ``{'segments': [...]}`` where each segment is a dict with the keys
            ``'start'``, ``'end'`` and ``'speaker'``, taken from the
            pipeline's speaker turns.

        Raises:
            Exception: If the pipeline cannot be loaded or diarization fails
                (the original error is attached as ``__cause__``).
        """
        if self.pipeline is None:
            self.load_pipeline(progress_callback)

        if progress_callback:
            progress_callback("Analyzing speakers in audio...")

        try:
            # Run diarization
            diarization = self.pipeline(audio_path)

            # Flatten the speaker turns into plain dictionaries
            segments = [
                {'start': turn.start, 'end': turn.end, 'speaker': speaker}
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]
        except Exception as e:
            logger.error("Diarization failed: %s", e)
            raise Exception(f"Speaker diarization failed: {str(e)}") from e

        if progress_callback:
            num_speakers = len({seg['speaker'] for seg in segments})
            progress_callback(f"Diarization complete. Found {num_speakers} speakers")

        logger.info("Diarization found %d segments", len(segments))
        return {'segments': segments}

    @staticmethod
    def _speaker_for_span(segments, span_start, span_end):
        """Return the speaker label for the span, or None if nothing overlaps.

        The first segment containing the span's midpoint wins outright;
        otherwise the segment with the largest positive overlap is chosen.
        """
        span_mid = (span_start + span_end) / 2
        best_speaker = None
        best_overlap = 0

        for seg in segments:
            seg_start = seg['start']
            seg_end = seg['end']

            # A segment covering the midpoint is taken immediately.
            if seg_start <= span_mid <= seg_end:
                return seg['speaker']

            overlap = min(span_end, seg_end) - max(span_start, seg_start)
            # Strict '>' keeps the earliest segment on ties and ignores
            # segments that do not overlap at all (overlap <= 0).
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = seg['speaker']

        return best_speaker

    def align_with_transcription(
        self,
        diarization_result: Dict,
        transcription_result: Dict,
        progress_callback=None
    ) -> Dict[int, str]:
        """
        Align speaker labels with transcription chunks

        Args:
            diarization_result: Result from diarize()
            transcription_result: Result from transcription; its ``'chunks'``
                entries are read for a ``'timestamp'`` ``(start, end)`` pair
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary mapping chunk index to speaker label. Chunks without a
            usable start timestamp, or without any overlapping speaker
            segment, are omitted.
        """
        if progress_callback:
            progress_callback("Aligning speakers with transcription...")

        # Tolerate None results / None lists so a skipped or empty upstream
        # step yields an empty mapping instead of an AttributeError.
        diarization_segments = (diarization_result or {}).get('segments') or []
        transcription_chunks = (transcription_result or {}).get('chunks') or []

        speaker_labels = {}
        for chunk_idx, chunk in enumerate(transcription_chunks):
            # 'timestamp' may be absent or explicitly None.
            timestamp = chunk.get('timestamp') or (None, None)
            if len(timestamp) < 2 or timestamp[0] is None:
                continue

            chunk_start = timestamp[0]
            # An open-ended chunk gets a nominal duration.
            chunk_end = (
                timestamp[1]
                if timestamp[1] is not None
                else chunk_start + self.OPEN_CHUNK_SECONDS
            )

            best_speaker = self._speaker_for_span(
                diarization_segments, chunk_start, chunk_end
            )
            if best_speaker:
                speaker_labels[chunk_idx] = best_speaker

        if progress_callback:
            progress_callback("Speaker alignment complete")

        logger.info("Aligned %d chunks with speakers", len(speaker_labels))
        return speaker_labels

    @staticmethod
    def is_available() -> bool:
        """Check if diarization is available (non-empty HF_TOKEN set)"""
        # An empty HF_TOKEN counts as missing, matching __init__'s check.
        return bool(os.getenv('HF_TOKEN'))