Luigi committed on
Commit
de5da7a
·
1 Parent(s): 9453a6f

Add vendored improved_diarization into src for Spaces importability

Browse files
Files changed (1) hide show
  1. src/improved_diarization.py +319 -0
src/improved_diarization.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diarisation Améliorée avec Clustering Adaptatif et Validation de Qualité
3
+ Vendored copy so the module is importable when running Streamlit from `src/`.
4
+ """
5
+
6
+ import numpy as np
7
+ from sklearn.cluster import AgglomerativeClustering
8
+ from sklearn.metrics import silhouette_score
9
+ from typing import List, Dict, Tuple, Any
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class ImprovedDiarization:
    """Speaker diarization with adaptive clustering and quality validation.

    The class clusters per-segment embeddings into speakers, scores the
    resulting partition, reassigns segments of speakers that talk too little
    and merges consecutive segments that belong to the same speaker.
    """

    def __init__(self):
        # Minimum total speaking time (seconds) a speaker needs to be kept.
        self.min_speaker_duration = 3.0
        # Upper bound on the number of clusters tried by adaptive_clustering.
        self.max_speakers = 10
        # Minimum quality threshold (not read by the methods of this class).
        self.quality_threshold = 0.3

    def adaptive_clustering(self, embeddings: np.ndarray) -> Tuple[int, float, np.ndarray]:
        """Pick the number of speakers automatically.

        Every cluster count from 2 up to ``min(max_speakers, n - 1)`` is tried
        with several metric/linkage combinations; the candidate with the best
        balance-adjusted silhouette score wins.

        Args:
            embeddings: One embedding per segment (first axis = segments).

        Returns:
            ``(optimal_n_speakers, best_score, best_labels)``. When no
            partition can be scored (fewer than 3 segments, or every
            clustering attempt failed) a single-speaker labelling is returned
            instead of ``None`` labels; the score is then ``1.0`` for fewer
            than 2 segments and ``-1.0`` otherwise.
        """
        n_samples = len(embeddings)
        if n_samples < 2:
            return 1, 1.0, np.zeros(n_samples, dtype=int)

        # silhouette_score is only defined for 2 <= n_clusters <= n_samples - 1.
        max_clusters = min(self.max_speakers, n_samples - 1)
        if max_clusters < 2:
            # Exactly two segments: no cluster count can be scored, so fall
            # back to one speaker rather than returning labels=None.
            return 1, -1.0, np.zeros(n_samples, dtype=int)

        best_score = -1.0
        best_n_speakers = 1
        best_labels = None

        # (metric, linkage) combinations to evaluate.
        configurations = [
            ('euclidean', 'ward'),
            ('cosine', 'average'),
            ('cosine', 'complete'),
            ('euclidean', 'complete'),
        ]

        for n_speakers in range(2, max_clusters + 1):
            for metric, linkage in configurations:
                try:
                    clustering = AgglomerativeClustering(
                        n_clusters=n_speakers,
                        metric=metric,
                        linkage=linkage
                    )
                    labels = clustering.fit_predict(embeddings)

                    # Silhouette score, computed with the clustering metric.
                    score = silhouette_score(embeddings, labels, metric=metric)

                    # Bonus for balanced cluster sizes: the score is scaled
                    # by a factor between 0.7 (very unbalanced) and 1.0.
                    _, counts = np.unique(labels, return_counts=True)
                    balance_ratio = counts.min() / counts.max()
                    adjusted_score = score * (0.7 + 0.3 * balance_ratio)

                    logger.debug(f"n_speakers={n_speakers}, metric={metric}, linkage={linkage}: "
                                 f"score={score:.3f}, balance={balance_ratio:.3f}, "
                                 f"adjusted={adjusted_score:.3f}")

                    if adjusted_score > best_score:
                        best_score = adjusted_score
                        best_n_speakers = n_speakers
                        best_labels = labels.copy()

                except Exception as e:
                    # Best effort: one failing configuration must not abort
                    # the search over the remaining ones.
                    logger.warning(f"Clustering failed for n={n_speakers}, "
                                   f"metric={metric}, linkage={linkage}: {e}")
                    continue

        if best_labels is None:
            # Every configuration failed or scored <= -1: fall back to a
            # single speaker so callers always receive usable labels.
            logger.warning("No valid clustering found; falling back to a single speaker")
            return 1, -1.0, np.zeros(n_samples, dtype=int)

        return best_n_speakers, best_score, best_labels

    @staticmethod
    def _separation_ratio(embeddings: np.ndarray, labels: np.ndarray) -> float:
        """Return mean inter-cluster / mean intra-cluster Euclidean distance.

        Returns ``1.0`` when there is no intra-cluster pair (every cluster is
        a singleton), and ``inf`` when all intra-cluster distances are zero
        while clusters are still separated.
        """
        points = np.asarray(embeddings, dtype=float)
        points = points.reshape(len(points), -1)
        labels = np.asarray(labels)

        intra_sum = inter_sum = 0.0
        intra_count = inter_count = 0
        # One vectorised pass per row covers each unordered pair (i, j > i) once.
        for i in range(len(points) - 1):
            dists = np.linalg.norm(points[i + 1:] - points[i], axis=1)
            same = labels[i + 1:] == labels[i]
            intra_sum += float(dists[same].sum())
            intra_count += int(same.sum())
            inter_sum += float(dists[~same].sum())
            inter_count += int((~same).sum())

        if intra_count == 0:
            return 1.0

        mean_intra = intra_sum / intra_count
        mean_inter = inter_sum / inter_count if inter_count else 0.0
        if mean_intra == 0.0:
            # Avoid a division by zero when clusters consist of duplicates.
            return float('inf') if mean_inter > 0.0 else 1.0
        return mean_inter / mean_intra

    def validate_clustering_quality(self, embeddings: np.ndarray, labels: np.ndarray) -> Dict[str, Any]:
        """Score a clustering and grade it.

        Args:
            embeddings: One embedding per segment.
            labels: Cluster label of each segment.

        Returns:
            A dict that always contains ``silhouette_score``,
            ``cluster_balance``, ``quality`` and ``reason``. On success it
            also contains ``separation_ratio`` and ``cluster_distribution``.
            ``quality`` is one of ``'excellent'``, ``'good'``, ``'fair'``,
            ``'poor'`` or ``'error'`` (when scoring raised an exception).
        """
        if len(np.unique(labels)) == 1:
            # A silhouette score is undefined for a single cluster.
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 1.0,
                'quality': 'poor',
                'reason': 'single_cluster'
            }

        try:
            sil_score = float(silhouette_score(embeddings, labels))

            # Cluster size distribution: smallest / largest cluster.
            unique, counts = np.unique(labels, return_counts=True)
            cluster_balance = float(counts.min() / counts.max())

            # Intra- vs inter-cluster distance.
            separation_ratio = self._separation_ratio(embeddings, labels)

            # Overall grade.
            if sil_score > 0.7 and cluster_balance > 0.5:
                quality = 'excellent'
            elif sil_score > 0.5 and cluster_balance > 0.3:
                quality = 'good'
            elif sil_score > 0.3:
                quality = 'fair'
            else:
                quality = 'poor'

            return {
                'silhouette_score': sil_score,
                'cluster_balance': cluster_balance,
                'separation_ratio': separation_ratio,
                # tolist() yields native Python keys/values (serialisable).
                'cluster_distribution': dict(zip(unique.tolist(), counts.tolist())),
                'quality': quality,
                'reason': f"sil_score={sil_score:.3f}, balance={cluster_balance:.3f}"
            }

        except Exception as e:
            # Best effort: report the failure through the metrics instead of
            # propagating it to the caller.
            logger.error(f"Quality validation failed: {e}")
            return {
                'silhouette_score': -1.0,
                'cluster_balance': 0.0,
                'quality': 'error',
                'reason': str(e)
            }

    def refine_speaker_assignments(self, utterances: List[Dict],
                                   min_duration: float = None) -> List[Dict]:
        """Reassign the segments of speakers that talk for too little time.

        Args:
            utterances: Dicts with ``'speaker'``, ``'start'`` and ``'end'``.
            min_duration: Minimum total duration (seconds) for a speaker to
                be kept; defaults to ``self.min_speaker_duration``.

        Returns:
            The input list itself when no speaker is below the threshold,
            otherwise a new list holding the same dicts. The ``'speaker'``
            entry of reassigned dicts is modified in place.
        """
        if min_duration is None:
            min_duration = self.min_speaker_duration

        # Total speaking time per speaker.
        speaker_durations = {}
        for utt in utterances:
            speaker = utt['speaker']
            duration = utt['end'] - utt['start']
            speaker_durations[speaker] = speaker_durations.get(speaker, 0) + duration

        logger.info(f"Speaker durations: {speaker_durations}")

        # Speakers whose total duration is insufficient.
        weak_speakers = {s for s, d in speaker_durations.items() if d < min_duration}

        if not weak_speakers:
            return utterances

        logger.info(f"Weak speakers to reassign: {weak_speakers}")

        # Reassign the segments of weak speakers to the nearest kept speaker.
        refined_utterances = []
        for utt in utterances:
            if utt['speaker'] in weak_speakers:
                new_speaker = self._find_dominant_adjacent_speaker(utt, utterances, weak_speakers)
                utt['speaker'] = new_speaker
                logger.debug(f"Reassigned segment [{utt['start']:.1f}-{utt['end']:.1f}s] "
                             f"to speaker {new_speaker}")

            refined_utterances.append(utt)

        return refined_utterances

    def _find_dominant_adjacent_speaker(self, target_utt: Dict,
                                        all_utterances: List[Dict],
                                        exclude_speakers: set) -> int:
        """Return the non-excluded speaker temporally closest to ``target_utt``.

        The distance is the gap between the segments (0 when they overlap);
        ties go to the earliest utterance in ``all_utterances``. Returns
        ``0`` when every utterance belongs to an excluded speaker.
        """
        target_start = target_utt['start']
        target_end = target_utt['end']

        candidates = []
        for utt in all_utterances:
            if utt['speaker'] in exclude_speakers:
                continue

            # Temporal distance between the two segments.
            if utt['end'] <= target_start:
                distance = target_start - utt['end']
            elif utt['start'] >= target_end:
                distance = utt['start'] - target_end
            else:
                distance = 0  # overlap

            candidates.append((utt['speaker'], distance))

        if not candidates:
            # No utterance has a non-excluded speaker, so there is nothing
            # to search for: use the ultimate fallback.
            return 0

        return min(candidates, key=lambda x: x[1])[0]

    def merge_consecutive_same_speaker(self, utterances: List[Dict],
                                       max_gap: float = 1.0) -> List[Dict]:
        """Merge consecutive segments spoken by the same speaker.

        Two neighbouring segments are merged when they share a speaker and
        the gap between them is at most ``max_gap`` seconds. Merged segments
        are copies, so the input dicts are left untouched.

        Args:
            utterances: Dicts with ``'speaker'``, ``'start'``, ``'end'`` and
                ``'text'``, in playback order.
            max_gap: Largest gap (seconds) that still allows a merge.

        Returns:
            The merged list (the input itself when it is empty).
        """
        if not utterances:
            return utterances

        merged = []
        current = utterances[0].copy()

        for next_utt in utterances[1:]:
            same_speaker = current['speaker'] == next_utt['speaker']
            if same_speaker and next_utt['start'] - current['end'] <= max_gap:
                # Concatenate the transcripts.
                current['text'] = current['text'].strip() + ' ' + next_utt['text'].strip()
                # max() keeps the end from moving backwards when the next
                # segment is nested inside the current one.
                current['end'] = max(current['end'], next_utt['end'])

                logger.debug(f"Merged segments: [{current['start']:.1f}-{current['end']:.1f}s] "
                             f"Speaker {current['speaker']}")
            else:
                # Close the current segment and start a new one.
                merged.append(current)
                current = next_utt.copy()

        # Append the last open segment.
        merged.append(current)

        return merged

    def diarize_with_quality_control(self, embeddings: np.ndarray,
                                     utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
        """Run the full diarization with quality control.

        ``utterances[i]`` receives the label computed for ``embeddings[i]``;
        the ``'speaker'`` key is written into the given dicts in place.

        Args:
            embeddings: One embedding per utterance.
            utterances: Segment dicts (``'start'``, ``'end'``, ``'text'``).

        Returns:
            ``(utterances_with_speakers, quality_metrics)``. With fewer than
            two embeddings the metrics are just
            ``{'quality': 'trivial', 'n_speakers': 1}``. Otherwise
            ``n_speakers`` is the cluster count chosen before refinement.
        """
        if len(embeddings) < 2:
            # Trivial case: at most one segment to cluster.
            for utt in utterances:
                utt['speaker'] = 0
            return utterances, {'quality': 'trivial', 'n_speakers': 1}

        # Adaptive clustering.
        n_speakers, clustering_score, labels = self.adaptive_clustering(embeddings)

        # Quality validation.
        quality_metrics = self.validate_clustering_quality(embeddings, labels)
        quality_metrics['n_speakers'] = n_speakers
        quality_metrics['clustering_score'] = clustering_score

        logger.info(f"Adaptive clustering: {n_speakers} speakers, "
                    f"score={clustering_score:.3f}, quality={quality_metrics['quality']}")

        # Apply the labels to the utterances.
        for i, utt in enumerate(utterances):
            utt['speaker'] = int(labels[i])

        # Refinement is skipped when quality validation itself failed.
        if quality_metrics['quality'] != 'error':
            utterances = self.refine_speaker_assignments(utterances)
            utterances = self.merge_consecutive_same_speaker(utterances)

        return utterances, quality_metrics
275
+
276
+
277
def enhance_diarization_pipeline(embeddings: np.ndarray,
                                 utterances: List[Dict]) -> Tuple[List[Dict], Dict[str, Any]]:
    """Improved diarization pipeline - main entry point.

    Args:
        embeddings: Embeddings of the audio segments (n_segments, 512).
        utterances: List of segments with their transcription.

    Returns:
        ``(utterances_with_speakers, quality_report)`` where the report has
        the keys ``success`` (quality is neither ``'error'`` nor ``'poor'``),
        ``confidence`` (``'high'`` for ``'excellent'``/``'good'``, otherwise
        ``'low'``), ``metrics`` (the raw quality metrics) and
        ``recommendations`` (a list with at most one message).
    """
    improved_diarizer = ImprovedDiarization()

    # Diarization with quality control.
    diarized_utterances, quality_metrics = improved_diarizer.diarize_with_quality_control(
        embeddings, utterances
    )

    quality = quality_metrics['quality']

    # The trivial result (fewer than two embeddings) carries neither
    # 'silhouette_score' nor 'cluster_balance', so read both defensively
    # instead of indexing, which would raise KeyError.
    silhouette = quality_metrics.get('silhouette_score')
    balance = quality_metrics.get('cluster_balance')

    # Quality-based recommendation (first matching rule only).
    recommendations = []
    if quality == 'poor':
        recommendations.append(
            "Consider using single-speaker mode - clustering quality too low"
        )
    elif silhouette is not None and silhouette < 0.3:
        recommendations.append(
            "Low speaker differentiation - verify audio quality"
        )
    elif balance is not None and balance < 0.2:
        recommendations.append(
            "Unbalanced speaker distribution - check audio content"
        )

    # Detailed quality report.
    quality_report = {
        'success': quality not in ['error', 'poor'],
        'confidence': 'high' if quality in ['excellent', 'good'] else 'low',
        'metrics': quality_metrics,
        'recommendations': recommendations
    }

    return diarized_utterances, quality_report