File size: 15,876 Bytes
0b3a48c
 
 
 
 
 
 
 
 
 
 
c710b8a
 
0b3a48c
 
 
 
c710b8a
0b3a48c
 
 
 
c710b8a
0b3a48c
 
 
 
 
 
 
 
 
 
c710b8a
 
 
 
 
 
 
 
 
45e0229
c710b8a
0b3a48c
 
 
 
 
 
 
 
 
c710b8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b3a48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c710b8a
 
0b3a48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c710b8a
0b3a48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c710b8a
0b3a48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
import os
import json
import torch
import tempfile
import torchaudio
import threading
import numpy as np
import soundfile as sf
import noisereduce as nr
from scipy import signal
from numpy.linalg import norm
from speechbrain.pretrained import SpeakerRecognition, EncoderClassifier
from speechbrain.pretrained import SpectralMaskEnhancement
from transformers import pipeline as hf_pipeline
from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip

class ASR_Diarization:
    def __init__(self, HF_TOKEN, asr_model="openai/whisper-medium"):
        """Set up the speaker-embedding models and the ASR pipeline.

        Args:
            HF_TOKEN: Stored on the instance as ``self.HF_TOKEN``; this
                constructor does not pass it to any of the loaders.
            asr_model: Model identifier handed to the Hugging Face
                ``automatic-speech-recognition`` pipeline.

        A SpeechBrain model that fails to load is replaced by ``None`` and
        the error is printed. A failure while building the ASR pipeline is
        not caught here.
        """
        cuda_available = torch.cuda.is_available()
        ecapa_source = "speechbrain/spkrec-ecapa-voxceleb"

        self.HF_TOKEN = HF_TOKEN
        self.device = torch.device("cuda" if cuda_available else "cpu")
        # Serialises access to the persisted unknown-speaker store.
        self._unknown_lock = threading.Lock()

        # ECAPA encoder used to turn audio chunks into speaker embeddings.
        try:
            self.embedding_model = EncoderClassifier.from_hparams(
                source=ecapa_source,
                run_opts={"device": str(self.device)},
            )
            print("[ECAPA] Model loaded successfully.")
        except Exception as e:
            self.embedding_model = None
            print(f"[ERROR] Failed to load ECAPA: {e}")

        # Speaker-verification wrapper around the same checkpoint.
        try:
            self.speaker_diarization = SpeakerRecognition.from_hparams(
                source=ecapa_source,
                savedir="pretrained_models/spkrec-ecapa-voxceleb",
            )
            print("[Speaker Recognition] Model loaded successfully.")
        except Exception as e:
            self.speaker_diarization = None
            print(f"[ERROR] Failed to load Speaker Recognition: {e}")

        # The pipeline takes a device index: 0 for the first GPU, -1 for CPU.
        pipeline_device = 0 if cuda_available else -1
        self.asr_pipeline = hf_pipeline(
            "automatic-speech-recognition",
            model=asr_model,
            device=pipeline_device,
            return_timestamps=True,
        )

    def run_diarization(self, audio_path):
        """Split the audio into fixed 2-second chunks and label each by speaker.

        Every chunk is embedded with ``self.embedding_model`` and given a
        speaker label by ``_assign_speaker``, which compares it with the
        embeddings of previously accepted chunks.

        Chunks are resampled to 16 kHz before embedding, matching what
        ``run_transcription`` does before calling the same encoder.

        Args:
            audio_path: Path of the audio file, read with ``torchaudio.load``.

        Returns:
            A list of dicts with keys ``start`` and ``end`` (seconds),
            ``speaker`` (label such as ``"speaker_1"``) and ``embedding``
            (numpy array). The list is empty when no embedding model is loaded.
        """
        audio, sr = torchaudio.load(audio_path)
        # Down-mix multi-channel audio to mono by averaging the channels.
        audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()

        segments = []
        if not self.embedding_model:
            print("[WARN] Embedding model unavailable; diarization produced no segments")
            return segments

        target_sr = 16000
        chunk_duration = 2.0  # seconds per chunk
        min_duration = 0.5    # seconds; equals the former 8000-sample cutoff at 16 kHz
        chunk_size = int(chunk_duration * sr)
        min_samples = int(min_duration * sr)
        total_duration = len(audio_np) / sr

        for i in range(0, len(audio_np), chunk_size):
            chunk = audio_np[i:i + chunk_size]
            if len(chunk) < min_samples:  # Skip very short chunks
                continue

            start_time = i / sr
            end_time = min((i + chunk_size) / sr, total_duration)

            try:
                if sr != target_sr:
                    chunk = signal.resample(
                        chunk, int(len(chunk) * target_sr / sr)
                    ).astype(np.float32)

                chunk_tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    embedding = (
                        self.embedding_model.encode_batch(chunk_tensor)
                        .squeeze()
                        .cpu()
                        .numpy()
                    )

                # Simple speaker assignment based on embedding similarity
                speaker_id = self._assign_speaker(embedding, segments)

                segments.append({
                    "start": start_time,
                    "end": end_time,
                    "speaker": speaker_id,
                    "embedding": embedding,
                })
            except Exception as e:
                # Best effort: one bad chunk must not abort the whole file.
                print(f"Error processing chunk: {e}")
                continue

        return segments

    def _assign_speaker(self, embedding, existing_segments, threshold=0.7):
        """Assign speaker based on embedding similarity"""
        if not existing_segments:
            return "speaker_1"
        
        # Calculate similarity with existing speakers
        similarities = []
        for seg in existing_segments[-10:]:  # Check last 10 segments
            if "embedding" in seg:
                sim = np.dot(embedding.flatten(), seg["embedding"].flatten()) / (
                    norm(embedding.flatten()) * norm(seg["embedding"].flatten())
                )
                similarities.append((seg["speaker"], sim))
        
        if similarities:
            best_speaker, best_sim = max(similarities, key=lambda x: x[1])
            if best_sim > threshold:
                return best_speaker
        
        # Create new speaker
        existing_speakers = set(seg["speaker"] for seg in existing_segments)
        speaker_num = 1
        while f"speaker_{speaker_num}" in existing_speakers:
            speaker_num += 1
        return f"speaker_{speaker_num}"

    def load_unknown_speakers(self, unknown_speakers_path):
        if os.path.exists(unknown_speakers_path):
            try:
                with open(unknown_speakers_path, "r") as f:
                    content = f.read().strip()
                    if content:
                        return json.loads(content)
            except Exception as e:
                print(f"[WARN] Failed to load unknown speakers ({e}), starting fresh")
        return {}

    def save_unknown_speakers(self, unknown_speakers, unknown_speakers_path):
        try:
            os.makedirs(os.path.dirname(unknown_speakers_path), exist_ok=True)
            tmp = unknown_speakers_path + ".tmp"
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(unknown_speakers, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, unknown_speakers_path)
            return True
        except Exception as e:
            print(f"[ERROR] Failed to save unknown speakers: {e}")
            return False

    def get_next_unknown_id(self, unknown_speakers):
        if not unknown_speakers:
            return "unknown_1"
        max_id = 0
        for speaker_id in unknown_speakers.keys():
            if speaker_id.startswith("unknown_"):
                try:
                    num = int(speaker_id.split("_")[1])
                    max_id = max(max_id, num)
                except (IndexError, ValueError):
                    continue
        return f"unknown_{max_id + 1}"

    def match_speaker_embedding(self, cluster_embedding, enrolled_speakers_np, unknown_speakers, threshold=0.5):
        cluster_embedding = cluster_embedding / norm(cluster_embedding)
        best_name, best_score, is_enrolled = None, -1.0, False

        # Log all similarities
        sim_log = []

        # Check enrolled speakers
        for name, e_emb in enrolled_speakers_np.items():
            sim = float(np.dot(cluster_embedding, e_emb / norm(e_emb)))
            sim_log.append((name, sim, True))
            if sim > best_score:
                best_name, best_score, is_enrolled = name, sim, True

        # Check unknown speakers
        for u_id, u_emb in unknown_speakers.items():
            sim = float(np.dot(cluster_embedding, np.array(u_emb) / norm(u_emb)))
            sim_log.append((u_id, sim, False))
            if sim > best_score:
                best_name, best_score, is_enrolled = u_id, sim, False

        # Log before creating new unknown
        print("[MATCH LOG] Cluster embedding compared:", sim_log)
        print(f"[MATCH LOG] Best match: {best_name}, score: {best_score}, enrolled: {is_enrolled}")

        return best_name, best_score, is_enrolled

    def run_transcription(self, audio_path, diar_json, enrolled_speakers=None, unknown_speakers_path=None):
        """Name the diarized speakers and transcribe every segment.

        Steps:
          1. Group ``diar_json`` segments by diarization label and compute
             one mean, L2-normalised embedding per label.
          2. Match each label against enrolled and stored unknown speakers.
             A match at or above 0.5 reuses that name (and blends the
             stored unknown embedding with the new one); otherwise a new
             ``unknown_<n>`` entry is created. The store is loaded, modified
             and saved inside a single lock acquisition so concurrent calls
             on this instance cannot overwrite each other's updates.
          3. Run noise reduction and ASR on each segment.

        Args:
            audio_path: Audio file, read with ``torchaudio.load``.
            diar_json: Iterable of segment dicts with ``start``, ``end``
                (seconds) and ``speaker`` keys.
            enrolled_speakers: Optional mapping of name -> embedding array.
            unknown_speakers_path: JSON store for unknown speakers; defaults
                to ``unknown_speakers.json`` next to ``audio_path``.

        Returns:
            ``(merged_segments, speakers)`` where ``merged_segments`` is a
            list of dicts with keys ``speaker``, ``start``, ``end``, ``text``
            and ``tokens``, and ``speakers`` lists the distinct final speaker
            names in order of first appearance.
        """
        unknown_speakers_path = unknown_speakers_path or os.path.join(
            os.path.dirname(audio_path), "unknown_speakers.json"
        )

        audio, sr = torchaudio.load(audio_path)
        # Down-mix multi-channel audio to mono by averaging the channels.
        audio_np = audio[0].numpy() if audio.shape[0] == 1 else audio.mean(dim=0).numpy()
        merged_segments, speaker_segments = [], {}
        enrolled_speakers_np = {
            n: v / norm(v) for n, v in (enrolled_speakers or {}).items() if norm(v) > 0
        }

        target_sr = 16000

        # Group segments by speaker for clustering
        clusters = {}
        for seg in diar_json:
            clusters.setdefault(seg["speaker"], []).append(seg)

        # Compute cluster embeddings
        cluster_embeddings = {}
        for cluster_label, segs in clusters.items():
            seg_embs = []
            for seg in segs:
                start, end = seg["start"], seg["end"]
                start_sample, end_sample = int(start * sr), int(end * sr)
                chunk = audio_np[start_sample:end_sample]
                if chunk.size < 8000:
                    # Zero-pad short chunks up to 8000 samples before embedding.
                    chunk = np.pad(chunk, (0, 8000 - chunk.size), mode='constant')
                if sr != target_sr:
                    chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                if self.embedding_model:
                    tensor = torch.from_numpy(chunk).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        emb = np.ravel(self.embedding_model.encode_batch(tensor).squeeze().cpu().numpy())
                        if norm(emb) > 0:
                            seg_embs.append(emb / norm(emb))
            if seg_embs:
                cluster_emb = np.mean(np.stack(seg_embs), axis=0)
                cluster_embeddings[cluster_label] = cluster_emb / norm(cluster_emb)

        speaker_map = {}
        speakers_updated = False
        threshold = 0.5

        # Thread-safe unknown speaker update: the load happens inside the
        # same critical section as the save, so the read-modify-write cycle
        # is not interleaved with another call on this instance.
        with self._unknown_lock:
            unknown_speakers = self.load_unknown_speakers(unknown_speakers_path)

            for cluster_label, c_emb in cluster_embeddings.items():
                matched_name, best_score, is_enrolled = self.match_speaker_embedding(
                    c_emb, enrolled_speakers_np, unknown_speakers, threshold
                )

                if best_score >= threshold:
                    speaker_map[cluster_label] = matched_name
                    # Update unknown embedding if matched_name is an unknown
                    if not is_enrolled:
                        old_emb = np.array(unknown_speakers[matched_name])
                        new_emb = (old_emb + c_emb) / 2.0
                        unknown_speakers[matched_name] = (new_emb / norm(new_emb)).tolist()
                        speakers_updated = True
                else:
                    # No sufficient match found, create new unknown
                    new_id = self.get_next_unknown_id(unknown_speakers)
                    unknown_speakers[new_id] = c_emb.tolist()
                    speaker_map[cluster_label] = new_id
                    speakers_updated = True

            if speakers_updated:
                self.save_unknown_speakers(unknown_speakers, unknown_speakers_path)

        # ASR transcription
        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            start_sample, end_sample = int(start * sr), int(end * sr)
            chunk = audio_np[start_sample:end_sample]
            if chunk.size == 0:
                continue
            if sr != target_sr:
                chunk = signal.resample(chunk, int(len(chunk) * target_sr / sr)).astype(np.float32)
                sr_chunk = target_sr
            else:
                sr_chunk = sr

            # Noise reduction is best effort; fall back to the raw chunk.
            try:
                reduced = nr.reduce_noise(chunk, sr=sr_chunk)
            except Exception:
                reduced = chunk

            try:
                result = self.asr_pipeline({"array": reduced, "sampling_rate": sr_chunk})
            except Exception:
                # Fall back to file input. The temp file is closed before it
                # is reused and is always deleted afterwards.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpf:
                    tmp_wav = tmpf.name
                try:
                    sf.write(tmp_wav, reduced, sr_chunk, subtype="PCM_16")
                    result = self.asr_pipeline(tmp_wav)
                finally:
                    try:
                        os.remove(tmp_wav)
                    except OSError:
                        pass

            tokens, transcript_text = [], ""
            if isinstance(result, dict) and "chunks" in result:
                for w in result["chunks"]:
                    # Prefer explicit start/end keys, else the (start, end)
                    # pair stored under "timestamp".
                    start_ts = w.get("start") or (w.get("timestamp") and w["timestamp"][0])
                    end_ts = w.get("end") or (w.get("timestamp") and w["timestamp"][1])
                    word_text = w.get("text", "").strip()
                    tokens.append({"start": start_ts, "end": end_ts, "text": word_text, "tag": "w"})
                    transcript_text += word_text + " "
            else:
                text = result.get("text") if isinstance(result, dict) else str(result)
                transcript_text = text or ""
                tokens.append({"start": None, "end": None, "text": transcript_text, "tag": "w"})

            # Labels without a cluster embedding fall back to "unknown".
            final_speaker = speaker_map.get(spk, "unknown")
            seg_dict = {
                "speaker": final_speaker,
                "start": start,
                "end": end,
                "text": transcript_text.strip(),
                "tokens": tokens,
            }
            merged_segments.append(seg_dict)
            speaker_segments.setdefault(final_speaker, []).append(seg_dict)

        return merged_segments, list(speaker_segments.keys())

    def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                 ref_rttm=None, ref_json=None, enrolled_speakers=None, unknown_speakers_path=None):
        diar_json = self.run_diarization(audio_path)
        merged_segments, speakers = self.run_transcription(
            audio_path, diar_json, enrolled_speakers=enrolled_speakers,
            unknown_speakers_path=unknown_speakers_path
        )

        if output_dir and base_name:
            os.makedirs(output_dir, exist_ok=True)

            # Save RTTM
            rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
            with open(rttm_path, "w") as f:
                for seg in diar_json:
                    f.write(
                        f"SPEAKER {base_name} 1 {seg['start']:.6f} "
                        f"{seg['end']-seg['start']:.6f} <NA> <NA> "
                        f"{seg['speaker']} <NA>\n"
                    )

            # Save transcription
            merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
            with open(merged_path, "w") as f:
                json.dump(merged_segments, f, indent=2)

        # Evaluation
        eval_results = None
        if ref_rttm or ref_json:
            eval_results = self.evaluate(output_dir, base_name,
                                         ref_rttm=ref_rttm, ref_json=ref_json)

        return {
            "speakers": speakers,
            "segments": merged_segments,
            "evaluation": eval_results
        }

    def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
        results = {}
        hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
        hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")

        if ref_json:
            def load_words(path):
                data = json.load(open(path))
                return " ".join([tok["text"] for seg in data for tok in seg["tokens"]])

            ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
            transform = Compose([ToLowerCase(), RemovePunctuation(),
                                 RemoveMultipleSpaces(), Strip()])
            results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
            results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)

        return results if results else None

    def __call__(self, inputs):
        if isinstance(inputs, dict):
            if "audio_bytes" in inputs:
                audio_bytes = inputs["audio_bytes"]
            elif "audio" in inputs:
                audio_bytes = inputs["audio"]
            else:
                raise ValueError("No audio found in inputs")
        else:
            audio_bytes = inputs

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        return self.run_pipeline(tmp_path)