Yermia committed on
Commit fda93d9 · verified · 1 Parent(s): bc1c0d3

Upload 13 files

src/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Meeting Transcriber - Automatic Meeting Minutes System
3
+ =====================================================
4
+
5
+ An end-to-end system for converting recorded meeting audio into
6
+ structured meeting-minutes documents using SpeechBrain and BERT.
7
+
8
+ Modules:
9
+ - config: System configuration
10
+ - audio_processor: Audio preprocessing
11
+ - diarization: Speaker diarization
12
+ - transcriber: ASR transcription
13
+ - summarizer: BERT summarization
14
+ - document_generator: Export to .docx
15
+ - evaluator: Evaluation metrics (WER, DER)
16
+ - pipeline: Main orchestrator
17
+ - utils: Utility functions
18
+
19
+ Example:
20
+ >>> from src.pipeline import MeetingTranscriberPipeline
21
+ >>> pipeline = MeetingTranscriberPipeline()
22
+ >>> result = pipeline.process("meeting.wav", title="Team Meeting")
23
+ >>> print(result.document_path)
24
+ """
25
+
26
+ __version__ = "1.0.0"
27
+ __author__ = "Yermia Turangan"
28
+ __email__ = "yermiaturangan026@student.unsrat.ac.id"
29
+
30
+ from src.audio_processor import AudioConfig, AudioProcessor
31
+ from src.config import Config, load_config
32
+ from src.diarization import DiarizationConfig, SpeakerDiarizer, SpeakerSegment
33
+ from src.document_generator import DocumentGenerator, MeetingMetadata
34
+ from src.evaluator import DERResult, Evaluator, WERResult
35
+ from src.pipeline import MeetingTranscriberPipeline, PipelineConfig, PipelineResult
36
+ from src.summarizer import BERTSummarizer, MeetingSummary, SummarizationConfig
37
+ from src.transcriber import ASRConfig, ASRTranscriber, TranscriptSegment
38
+
39
+ __all__ = [
40
+ # Config
41
+ "Config",
42
+ "load_config",
43
+ # Audio
44
+ "AudioProcessor",
45
+ "AudioConfig",
46
+ # Diarization
47
+ "SpeakerDiarizer",
48
+ "DiarizationConfig",
49
+ "SpeakerSegment",
50
+ # ASR
51
+ "ASRTranscriber",
52
+ "ASRConfig",
53
+ "TranscriptSegment",
54
+ # Summarization
55
+ "BERTSummarizer",
56
+ "SummarizationConfig",
57
+ "MeetingSummary",
58
+ # Document
59
+ "DocumentGenerator",
60
+ "MeetingMetadata",
61
+ # Evaluation
62
+ "Evaluator",
63
+ "WERResult",
64
+ "DERResult",
65
+ # Pipeline
66
+ "MeetingTranscriberPipeline",
67
+ "PipelineConfig",
68
+ "PipelineResult",
69
+ ]
src/audio_processor.py ADDED
@@ -0,0 +1,398 @@
1
+ """
2
+ Audio Processor Module
3
+ ======================
4
+ Handles audio loading, preprocessing, and segmentation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import List, Optional, Tuple, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ from torchaudio.transforms import Resample
17
+
18
+ try:
19
+ import librosa
20
+
21
+ LIBROSA_AVAILABLE = True
22
+ except ImportError:
23
+ LIBROSA_AVAILABLE = False
24
+
25
+
26
+ @dataclass
27
+ class AudioConfig:
28
+ """Configuration for audio processing"""
29
+
30
+ sample_rate: int = 16000
31
+ mono: bool = True
32
+ normalize: bool = True
33
+ trim_silence: bool = False
34
+ silence_threshold_db: float = -40.0
35
+ max_duration_seconds: Optional[float] = None
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo:
40
+ """Information about loaded audio"""
41
+
42
+ path: str
43
+ duration_seconds: float
44
+ sample_rate: int
45
+ num_channels: int
46
+ num_samples: int
47
+
48
+
49
+ class AudioProcessor:
50
+ """
51
+ Handles all audio preprocessing operations.
52
+ Converts input audio to standardized format for downstream processing.
53
+
54
+ Attributes:
55
+ config: AudioConfig object with processing settings
56
+
57
+ Example:
58
+ >>> processor = AudioProcessor()
59
+ >>> waveform, sr = processor.load_audio("meeting.wav")
60
+ >>> print(f"Duration: {processor.get_duration(waveform, sr):.2f}s")
61
+ """
62
+
63
+ SUPPORTED_FORMATS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".wma", ".aac"}
64
+
65
+ def __init__(self, config: Optional[AudioConfig] = None):
66
+ """
67
+ Initialize AudioProcessor.
68
+
69
+ Args:
70
+ config: AudioConfig object (uses defaults if None)
71
+ """
72
+ self.config = config or AudioConfig()
73
+ self._resampler_cache: dict = {}
74
+
75
+ def load_audio(
76
+ self,
77
+ audio_path: Union[str, Path],
78
+ start_time: Optional[float] = None,
79
+ end_time: Optional[float] = None,
80
+ ) -> Tuple[torch.Tensor, int]:
81
+ """
82
+ Load and preprocess audio file.
83
+
84
+ Args:
85
+ audio_path: Path to audio file
86
+ start_time: Start time in seconds (optional)
87
+ end_time: End time in seconds (optional)
88
+
89
+ Returns:
90
+ Tuple of (waveform tensor [1, T], sample_rate)
91
+
92
+ Raises:
93
+ FileNotFoundError: If audio file doesn't exist
94
+ ValueError: If audio format is not supported
95
+ """
96
+ audio_path = Path(audio_path)
97
+
98
+ # Validate file exists
99
+ if not audio_path.exists():
100
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
101
+
102
+ # Validate format
103
+ if audio_path.suffix.lower() not in self.SUPPORTED_FORMATS:
104
+ raise ValueError(
105
+ f"Unsupported audio format: {audio_path.suffix}. "
106
+ f"Supported formats: {self.SUPPORTED_FORMATS}"
107
+ )
108
+
109
+ # Load audio
110
+ try:
111
+ waveform, orig_sr = torchaudio.load(str(audio_path))
112
+ except Exception as e:
113
+ # Fallback to librosa if torchaudio fails
114
+ if LIBROSA_AVAILABLE:
115
+ try:
116
+ audio_np, orig_sr = librosa.load(str(audio_path), sr=None, mono=False)
117
+ if audio_np.ndim == 1:
118
+ audio_np = audio_np[np.newaxis, :]
119
+ waveform = torch.from_numpy(audio_np).float()
120
+ except Exception:
121
+ # Try pydub (requires ffmpeg) as a robust fallback
122
+ try:
123
+ from pydub import AudioSegment
124
+
125
+ seg = AudioSegment.from_file(str(audio_path))
126
+ orig_sr = seg.frame_rate
127
+ samples = np.array(seg.get_array_of_samples())
128
+
129
+ if seg.channels > 1:
130
+ samples = samples.reshape((-1, seg.channels)).T
131
+ else:
132
+ samples = samples[np.newaxis, :]
133
+
134
+ # Normalize based on sample width
135
+ max_val = float(1 << (8 * seg.sample_width - 1))
136
+ audio_np = samples.astype(np.float32) / max_val
137
+ waveform = torch.from_numpy(audio_np).float()
138
+ except Exception:
139
+ # Try ffmpeg CLI (system binary) to decode to WAV in-memory (no extra Python packages required)
140
+ try:
141
+ import io
142
+ import subprocess
143
+
144
+ import soundfile as sf
145
+
146
+ proc = subprocess.run(
147
+ [
148
+ "ffmpeg",
149
+ "-i",
150
+ str(audio_path),
151
+ "-f",
152
+ "wav",
153
+ "-ar",
154
+ "16000",
155
+ "-ac",
156
+ "1",
157
+ "pipe:1",
158
+ ],
159
+ stdout=subprocess.PIPE,
160
+ stderr=subprocess.DEVNULL,
161
+ check=True,
162
+ )
163
+ out = proc.stdout
164
+
165
+ audio_np, orig_sr = sf.read(io.BytesIO(out), dtype="float32")
166
+ if audio_np.ndim == 1:
167
+ audio_np = audio_np[np.newaxis, :]
168
+ else:
169
+ audio_np = audio_np.T
170
+ waveform = torch.from_numpy(audio_np).float()
171
+ except Exception:
172
+ # Last resort: use ffmpeg-python to decode into WAV bytes and read via soundfile
173
+ try:
174
+ import io
175
+
176
+ import ffmpeg
177
+ import soundfile as sf
178
+
179
+ out, _ = (
180
+ ffmpeg.input(str(audio_path))
181
+ .output("pipe:", format="wav", acodec="pcm_s16le")
182
+ .run(capture_stdout=True, capture_stderr=True)
183
+ )
184
+
185
+ audio_np, orig_sr = sf.read(io.BytesIO(out), dtype="float32")
186
+ if audio_np.ndim == 1:
187
+ audio_np = audio_np[np.newaxis, :]
188
+ else:
189
+ audio_np = audio_np.T
190
+ waveform = torch.from_numpy(audio_np).float()
191
+ except Exception:
192
+ raise RuntimeError(
193
+ "Unsupported file format or decoding backend (ffmpeg) not available. "
194
+ "Please install ffmpeg (and make sure it is on PATH), or use a supported format such as WAV/MP3."
195
+ )
196
+ else:
197
+ raise RuntimeError(f"Failed to load audio: {e}")
198
+
199
+ # Trim to time range if specified
200
+ if start_time is not None or end_time is not None:
201
+ waveform = self._trim_to_range(waveform, orig_sr, start_time, end_time)
202
+
203
+ # Convert to mono if needed
204
+ if self.config.mono and waveform.shape[0] > 1:
205
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
206
+
207
+ # Resample if needed
208
+ if orig_sr != self.config.sample_rate:
209
+ waveform = self._resample(waveform, orig_sr, self.config.sample_rate)
210
+
211
+ # Normalize amplitude
212
+ if self.config.normalize:
213
+ waveform = self._normalize(waveform)
214
+
215
+ # Trim silence if requested
216
+ if self.config.trim_silence:
217
+ waveform = self._trim_silence(waveform)
218
+
219
+ # Enforce max duration
220
+ if self.config.max_duration_seconds:
221
+ max_samples = int(self.config.max_duration_seconds * self.config.sample_rate)
222
+ if waveform.shape[-1] > max_samples:
223
+ waveform = waveform[:, :max_samples]
224
+
225
+ return waveform, self.config.sample_rate
226
+
227
+ def get_audio_info(self, audio_path: Union[str, Path]) -> AudioInfo:
228
+ """
229
+ Get information about audio file without loading full waveform.
230
+
231
+ Args:
232
+ audio_path: Path to audio file
233
+
234
+ Returns:
235
+ AudioInfo object with file details
236
+ """
237
+ audio_path = Path(audio_path)
238
+
239
+ if not audio_path.exists():
240
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
241
+
242
+ info = torchaudio.info(str(audio_path))
243
+
244
+ return AudioInfo(
245
+ path=str(audio_path),
246
+ duration_seconds=info.num_frames / info.sample_rate,
247
+ sample_rate=info.sample_rate,
248
+ num_channels=info.num_channels,
249
+ num_samples=info.num_frames,
250
+ )
251
+
252
+ def _trim_to_range(
253
+ self,
254
+ waveform: torch.Tensor,
255
+ sample_rate: int,
256
+ start_time: Optional[float],
257
+ end_time: Optional[float],
258
+ ) -> torch.Tensor:
259
+ """Trim waveform to specified time range"""
260
+ start_sample = int((start_time or 0) * sample_rate)
261
+ end_sample = int((end_time or waveform.shape[-1] / sample_rate) * sample_rate)
262
+
263
+ start_sample = max(0, start_sample)
264
+ end_sample = min(waveform.shape[-1], end_sample)
265
+
266
+ return waveform[:, start_sample:end_sample]
267
+
268
+ def _resample(self, waveform: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
269
+ """Resample audio to target sample rate with caching"""
270
+ cache_key = (orig_sr, target_sr)
271
+
272
+ if cache_key not in self._resampler_cache:
273
+ self._resampler_cache[cache_key] = Resample(orig_freq=orig_sr, new_freq=target_sr)
274
+
275
+ return self._resampler_cache[cache_key](waveform)
276
+
277
+ def _normalize(self, waveform: torch.Tensor) -> torch.Tensor:
278
+ """Normalize waveform to [-1, 1] range"""
279
+ max_val = torch.max(torch.abs(waveform))
280
+ if max_val > 0:
281
+ waveform = waveform / max_val
282
+ return waveform
283
+
284
+ def _trim_silence(self, waveform: torch.Tensor) -> torch.Tensor:
285
+ """Remove leading and trailing silence"""
286
+ # Convert threshold from dB to amplitude
287
+ threshold = 10 ** (self.config.silence_threshold_db / 20)
288
+
289
+ # Find non-silent regions
290
+ amplitude = torch.abs(waveform).squeeze()
291
+ non_silent = amplitude > threshold
292
+
293
+ if not non_silent.any():
294
+ return waveform
295
+
296
+ # Find first and last non-silent sample
297
+ non_silent_indices = torch.where(non_silent)[0]
298
+ start_idx = non_silent_indices[0].item()
299
+ end_idx = non_silent_indices[-1].item() + 1
300
+
301
+ return waveform[:, start_idx:end_idx]
302
+
303
+ def get_duration(self, waveform: torch.Tensor, sample_rate: int) -> float:
304
+ """Get duration of waveform in seconds"""
305
+ return waveform.shape[-1] / sample_rate
306
+
307
+ def cut_segment(
308
+ self, waveform: torch.Tensor, start_sec: float, end_sec: float, sample_rate: int
309
+ ) -> torch.Tensor:
310
+ """
311
+ Extract a segment from waveform.
312
+
313
+ Args:
314
+ waveform: Input waveform [C, T]
315
+ start_sec: Start time in seconds
316
+ end_sec: End time in seconds
317
+ sample_rate: Sample rate of waveform
318
+
319
+ Returns:
320
+ Segment waveform [C, t]
321
+ """
322
+ start_sample = int(max(0, start_sec) * sample_rate)
323
+ end_sample = int(min(end_sec * sample_rate, waveform.shape[-1]))
324
+
325
+ return waveform[:, start_sample:end_sample]
326
+
327
+ def split_into_chunks(
328
+ self,
329
+ waveform: torch.Tensor,
330
+ chunk_duration: float,
331
+ overlap: float = 0.0,
332
+ sample_rate: Optional[int] = None,
333
+ ) -> List[Tuple[torch.Tensor, float, float]]:
334
+ """
335
+ Split waveform into overlapping chunks.
336
+
337
+ Args:
338
+ waveform: Input waveform
339
+ chunk_duration: Duration of each chunk in seconds
340
+ overlap: Overlap between chunks in seconds
341
+ sample_rate: Sample rate (uses config if None)
342
+
343
+ Returns:
344
+ List of (chunk_waveform, start_sec, end_sec)
345
+ """
346
+ sample_rate = sample_rate or self.config.sample_rate
+ if overlap >= chunk_duration:
+ raise ValueError("overlap must be smaller than chunk_duration")
347
+ total_duration = self.get_duration(waveform, sample_rate)
348
+
349
+ chunks = []
350
+ start = 0.0
351
+
352
+ while start < total_duration:
353
+ end = min(start + chunk_duration, total_duration)
354
+ chunk = self.cut_segment(waveform, start, end, sample_rate)
355
+ chunks.append((chunk, start, end))
356
+ start += chunk_duration - overlap
357
+
358
+ return chunks
359
+
360
+ def add_noise(
361
+ self, waveform: torch.Tensor, noise_level: float = 0.01, noise_type: str = "gaussian"
362
+ ) -> torch.Tensor:
363
+ """
364
+ Add noise to waveform (for data augmentation).
365
+
366
+ Args:
367
+ waveform: Input waveform
368
+ noise_level: Noise amplitude (0-1)
369
+ noise_type: Type of noise ("gaussian", "uniform")
370
+
371
+ Returns:
372
+ Waveform with added noise
373
+ """
374
+ if noise_type == "gaussian":
375
+ noise = torch.randn_like(waveform) * noise_level
376
+ elif noise_type == "uniform":
377
+ noise = (torch.rand_like(waveform) * 2 - 1) * noise_level
378
+ else:
379
+ raise ValueError(f"Unknown noise type: {noise_type}")
380
+
381
+ return waveform + noise
382
+
383
+ def save_audio(
384
+ self,
385
+ waveform: torch.Tensor,
386
+ output_path: Union[str, Path],
387
+ sample_rate: Optional[int] = None,
388
+ ):
389
+ """
390
+ Save waveform to audio file.
391
+
392
+ Args:
393
+ waveform: Waveform to save
394
+ output_path: Output file path
395
+ sample_rate: Sample rate (uses config if None)
396
+ """
397
+ sample_rate = sample_rate or self.config.sample_rate
398
+ torchaudio.save(str(output_path), waveform, sample_rate)
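
For reference, a minimal usage sketch of the AudioProcessor API defined above; the file name meeting.wav and the 30 s / 5 s chunking values are illustrative assumptions, not part of this upload:

from src.audio_processor import AudioConfig, AudioProcessor

# Load and standardize a recording (16 kHz mono, peak-normalized), matching the defaults above.
processor = AudioProcessor(AudioConfig(sample_rate=16000, mono=True, normalize=True))
waveform, sr = processor.load_audio("meeting.wav")  # hypothetical input file
print(f"Duration: {processor.get_duration(waveform, sr):.2f}s")

# Split into 30-second chunks with 5 seconds of overlap, e.g. for chunked ASR.
for chunk, start_sec, end_sec in processor.split_into_chunks(
    waveform, chunk_duration=30.0, overlap=5.0, sample_rate=sr
):
    print(f"chunk {start_sec:.1f}-{end_sec:.1f}s -> {tuple(chunk.shape)}")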
src/config.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ Configuration Module
3
+ ====================
4
+ Handles loading and managing configuration for the entire system.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from dataclasses import dataclass, field
11
+ from typing import List, Optional
12
+
13
+ import yaml
14
+
15
+
16
+ @dataclass
17
+ class VADConfig:
18
+ """Voice Activity Detection configuration"""
19
+
20
+ threshold: float = 0.5
21
+ min_speech_duration: float = 0.3
22
+ min_silence_duration: float = 0.3
23
+ speech_pad_ms: int = 30
24
+
25
+
26
+ @dataclass
27
+ class SegmentationConfig:
28
+ """Segmentation configuration"""
29
+
30
+ window_duration: float = 1.5
31
+ window_hop: float = 0.75
32
+ min_segment_duration: float = 0.5
33
+
34
+
35
+ @dataclass
36
+ class EmbeddingConfig:
37
+ """Speaker embedding configuration"""
38
+
39
+ model_id: str = "speechbrain/spkrec-ecapa-voxceleb"
40
+ embedding_dim: int = 192
41
+
42
+
43
+ @dataclass
44
+ class ClusteringConfig:
45
+ """Clustering configuration"""
46
+
47
+ method: str = "agglomerative"
48
+ threshold: float = 0.7
49
+ min_cluster_size: int = 2
50
+ linkage: str = "average"
51
+
52
+
53
+ @dataclass
54
+ class AudioConfig:
55
+ """Audio processing configuration"""
56
+
57
+ sample_rate: int = 16000
58
+ mono: bool = True
59
+ normalize: bool = True
60
+ trim_silence: bool = False
61
+ max_duration_minutes: int = 60
62
+
63
+
64
+ @dataclass
65
+ class DiarizationConfig:
66
+ """Speaker diarization configuration"""
67
+
68
+ vad: VADConfig = field(default_factory=VADConfig)
69
+ segmentation: SegmentationConfig = field(default_factory=SegmentationConfig)
70
+ embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
71
+ clustering: ClusteringConfig = field(default_factory=ClusteringConfig)
72
+ merge_gap_threshold: float = 0.5
73
+ min_segment_duration: float = 0.3
74
+ smooth_segments: bool = True
75
+
76
+ # Embedding and collapse options
77
+ use_speechbrain: bool = True
78
+ allow_fallback: bool = False
79
+ collapse_threshold: float = 0.15
80
+ silhouette_collapse_threshold: float = 0.05
81
+
82
+
83
+ @dataclass
84
+ class ASRConfig:
85
+ """ASR configuration"""
86
+
87
+ model_id: str = "indonesian-nlp/wav2vec2-large-xlsr-indonesian"
88
+ chunk_length_s: float = 30.0
89
+ stride_length_s: float = 5.0
90
+ batch_size: int = 4
91
+ return_timestamps: Optional[str] = None
92
+ # Valid values: None (no timestamps), or 'char' / 'word' for CTC timestamp modes
93
+ capitalize_sentences: bool = True
94
+ normalize_whitespace: bool = True
95
+
96
+
97
+ @dataclass
98
+ class SummarizationConfig:
99
+ """Summarization configuration"""
100
+
101
+ model_id: str = "indobenchmark/indobert-base-p1"
102
+ sentence_model_id: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
103
+ num_sentences: int = 5
104
+ min_sentence_length: int = 10
105
+ max_sentence_length: int = 200
106
+ position_weight: float = 0.1
107
+ decision_keywords: List[str] = field(
108
+ default_factory=lambda: [
109
+ "diputuskan",
110
+ "disepakati",
111
+ "kesimpulan",
112
+ "keputusan",
113
+ "jadi",
114
+ "maka",
115
+ "sepakat",
116
+ "setuju",
117
+ "final",
118
+ ]
119
+ )
120
+ action_keywords: List[str] = field(
121
+ default_factory=lambda: [
122
+ "akan",
123
+ "harus",
124
+ "perlu",
125
+ "tolong",
126
+ "mohon",
127
+ "deadline",
128
+ "target",
129
+ "tugas",
130
+ "tanggung jawab",
131
+ "action item",
132
+ "follow up",
133
+ "tindak lanjut",
134
+ ]
135
+ )
136
+
137
+
138
+ @dataclass
139
+ class DocumentConfig:
140
+ """Document generation configuration"""
141
+
142
+ template: str = "default"
143
+ title_font_size: int = 18
144
+ heading_font_size: int = 14
145
+ body_font_size: int = 11
146
+ font_family: str = "Calibri"
147
+ include_timestamps: bool = True
148
+ include_speaker_colors: bool = True
149
+
150
+
151
+ @dataclass
152
+ class EvaluationConfig:
153
+ """Evaluation configuration"""
154
+
155
+ wer_lowercase: bool = True
156
+ wer_remove_punctuation: bool = True
157
+ der_collar: float = 0.25
158
+ der_skip_overlap: bool = False
159
+
160
+
161
+ @dataclass
162
+ class PathsConfig:
163
+ """Paths configuration"""
164
+
165
+ models_dir: str = "./models"
166
+ audio_dir: str = "./data/audio"
167
+ ground_truth_dir: str = "./data/ground_truth"
168
+ output_dir: str = "./data/output"
169
+ cache_dir: str = "./cache"
170
+ logs_dir: str = "./logs"
171
+
172
+
173
+ @dataclass
174
+ class Config:
175
+ """Main configuration class"""
176
+
177
+ audio: AudioConfig = field(default_factory=AudioConfig)
178
+ diarization: DiarizationConfig = field(default_factory=DiarizationConfig)
179
+ asr: ASRConfig = field(default_factory=ASRConfig)
180
+ summarization: SummarizationConfig = field(default_factory=SummarizationConfig)
181
+ document: DocumentConfig = field(default_factory=DocumentConfig)
182
+ evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
183
+ paths: PathsConfig = field(default_factory=PathsConfig)
184
+ device: str = "auto"
185
+ verbose: bool = True
186
+
187
+ def __post_init__(self):
188
+ """Create directories if they don't exist"""
189
+ for path_attr in [
190
+ "models_dir",
191
+ "audio_dir",
192
+ "ground_truth_dir",
193
+ "output_dir",
194
+ "cache_dir",
195
+ "logs_dir",
196
+ ]:
197
+ path = getattr(self.paths, path_attr)
198
+ os.makedirs(path, exist_ok=True)
199
+
200
+
201
+ def load_config(config_path: str = "config.yaml") -> Config:
202
+ """
203
+ Load configuration from YAML file.
204
+
205
+ Args:
206
+ config_path: Path to config.yaml file
207
+
208
+ Returns:
209
+ Config object with loaded settings
210
+ """
211
+ config = Config()
212
+
213
+ if os.path.exists(config_path):
214
+ with open(config_path, "r", encoding="utf-8") as f:
215
+ yaml_config = yaml.safe_load(f)
216
+
217
+ if yaml_config:
218
+ # Update audio config
219
+ if "audio" in yaml_config:
220
+ for key, value in yaml_config["audio"].items():
221
+ if hasattr(config.audio, key):
222
+ setattr(config.audio, key, value)
223
+
224
+ # Update ASR config
225
+ if "asr" in yaml_config:
226
+ for key, value in yaml_config["asr"].items():
227
+ if hasattr(config.asr, key):
228
+ setattr(config.asr, key, value)
229
+
230
+ # Update summarization config
231
+ if "summarization" in yaml_config:
232
+ for key, value in yaml_config["summarization"].items():
233
+ if hasattr(config.summarization, key):
234
+ setattr(config.summarization, key, value)
235
+
236
+ # Update paths config
237
+ if "paths" in yaml_config:
238
+ for key, value in yaml_config["paths"].items():
239
+ if hasattr(config.paths, key):
240
+ setattr(config.paths, key, value)
241
+
242
+ # Update device
243
+ if "hardware" in yaml_config and "device" in yaml_config["hardware"]:
244
+ config.device = yaml_config["hardware"]["device"]
245
+
246
+ return config
247
+
248
+
249
+ def save_config(config: Config, config_path: str = "config.yaml"):
250
+ """
251
+ Save configuration to YAML file.
252
+
253
+ Args:
254
+ config: Config object to save
255
+ config_path: Path to save config.yaml
256
+ """
257
+ # Convert dataclass to dict
258
+ config_dict = {
259
+ "audio": config.audio.__dict__,
260
+ "asr": config.asr.__dict__,
261
+ "summarization": {
262
+ k: v for k, v in config.summarization.__dict__.items() if not k.endswith("_keywords")
263
+ },
264
+ "document": config.document.__dict__,
265
+ "evaluation": config.evaluation.__dict__,
266
+ "paths": config.paths.__dict__,
267
+ "hardware": {"device": config.device},
268
+ }
269
+
270
+ with open(config_path, "w", encoding="utf-8") as f:
271
+ yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)
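
A rough sketch of how load_config consumes a YAML file; it only applies the audio, asr, summarization, paths and hardware.device sections handled above, and the concrete values here are illustrative assumptions:

import yaml

from src.config import load_config

# Write a minimal config.yaml covering only the sections the loader reads.
example = {
    "audio": {"sample_rate": 16000, "trim_silence": True},
    "asr": {"chunk_length_s": 20.0, "batch_size": 2},
    "summarization": {"num_sentences": 7},
    "paths": {"output_dir": "./data/output"},
    "hardware": {"device": "cpu"},
}
with open("config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(example, f, default_flow_style=False, allow_unicode=True)

config = load_config("config.yaml")
print(config.asr.batch_size, config.summarization.num_sentences, config.device)  # -> 2 7 cpu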
src/diarization.py ADDED
@@ -0,0 +1,1504 @@
1
+ """
2
+ Speaker Diarization Module
3
+ ==========================
4
+ Implements VAD + Speaker Embedding + Clustering pipeline for speaker diarization.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from dataclasses import dataclass, field
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import torch
16
+ from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
17
+ from sklearn.metrics import silhouette_score
18
+ from sklearn.preprocessing import StandardScaler
19
+
20
+ from src.utils import setup_logger
21
+
22
+
23
+ @dataclass
24
+ class DiarizationConfig:
25
+ """Configuration for speaker diarization"""
26
+
27
+ # VAD settings
28
+ vad_threshold: float = 0.5
29
+ min_speech_duration: float = 0.3
30
+ min_silence_duration: float = 0.3
31
+
32
+ # Segmentation settings
33
+ segment_window: float = 1.5
34
+ segment_hop: float = 0.75
35
+
36
+ # Clustering settings
37
+ clustering_method: str = "agglomerative"
38
+ clustering_threshold: float = 0.7
39
+ min_cluster_size: int = 2
40
+ max_speakers: Optional[int] = None
41
+
42
+ # Post-processing
43
+ merge_gap_threshold: float = 0.5
44
+ min_segment_duration: float = 0.3
45
+
46
+ # Model settings
47
+ embedding_model_id: str = "speechbrain/spkrec-ecapa-voxceleb"
48
+ use_speechbrain: bool = True # prefer SpeechBrain embeddings
49
+ allow_fallback: bool = False # if False, raise an error when SpeechBrain cannot be loaded
50
+
51
+ # Collapse heuristics
52
+ collapse_threshold: float = 0.15
53
+ # When negative, do not automatically collapse clusters to a single speaker based on silhouette.
54
+ silhouette_collapse_threshold: float = -1.0
55
+
56
+ # Iterative merging (centroid-based)
57
+ iterative_merge_threshold: float = 0.15
58
+ iterative_merge_silhouette_threshold: float = 0.0
59
+ iterative_merge_max_iters: int = 10
60
+
61
+ # Performance tuning
62
+ embedding_batch_size: int = 32
63
+ embedding_cache: bool = True # write/load embedding arrays to cache_dir
64
+ use_fast_embedding: bool = False # use MFCC deterministic embeddings for speed
65
+
66
+ # Optional: target speaker count - if set, clusters will be greedily merged to meet target
67
+ target_num_speakers: Optional[int] = None
68
+ target_force_threshold: float = (
69
+ 1.0 # 1.0 => allow merges regardless of distance; lower = more conservative
70
+ )
71
+
72
+ # Device
73
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
74
+
75
+
76
+ @dataclass
77
+ class SpeakerSegment:
78
+ """Represents a speaker segment with timing and metadata"""
79
+
80
+ speaker_id: str
81
+ start: float
82
+ end: float
83
+ confidence: float = 1.0
84
+ is_overlap: bool = False
85
+ embedding: Optional[np.ndarray] = None
86
+ metadata: Dict[str, Any] = field(default_factory=dict)
87
+
88
+ @property
89
+ def duration(self) -> float:
90
+ """Get segment duration in seconds"""
91
+ return self.end - self.start
92
+
93
+ def to_dict(self) -> Dict[str, Any]:
94
+ """Convert to dictionary"""
95
+ return {
96
+ "speaker_id": self.speaker_id,
97
+ "start": self.start,
98
+ "end": self.end,
99
+ "confidence": self.confidence,
100
+ "is_overlap": self.is_overlap,
101
+ "duration": self.duration,
102
+ }
103
+
104
+
105
+ class SpeakerDiarizer:
106
+ """
107
+ Speaker Diarization using SpeechBrain ECAPA-TDNN embeddings.
108
+
109
+ Pipeline:
110
+ 1. Voice Activity Detection (VAD)
111
+ 2. Audio segmentation into windows
112
+ 3. Speaker embedding extraction (ECAPA-TDNN)
113
+ 4. Clustering to assign speaker labels
114
+ 5. Post-processing (merging, smoothing)
115
+
116
+ Attributes:
117
+ config: DiarizationConfig object
118
+
119
+ Example:
120
+ >>> diarizer = SpeakerDiarizer()
121
+ >>> segments = diarizer.process(waveform, sample_rate=16000, num_speakers=4)
122
+ >>> for seg in segments:
123
+ ... print(f"{seg.speaker_id}: {seg.start:.2f}s - {seg.end:.2f}s")
124
+ """
125
+
126
+ def __init__(self, config: Optional[DiarizationConfig] = None, models_dir: str = "./models"):
127
+ """
128
+ Initialize SpeakerDiarizer.
129
+
130
+ Args:
131
+ config: DiarizationConfig object
132
+ models_dir: Directory to cache downloaded models
133
+ """
134
+ self.config = config or DiarizationConfig()
135
+ self.models_dir = Path(models_dir)
136
+ self.models_dir.mkdir(parents=True, exist_ok=True)
137
+
138
+ self.device = self.config.device
139
+
140
+ # Setup logger
141
+ self.logger = setup_logger("SpeakerDiarizer")
142
+
143
+ # Model placeholders (lazy loading)
144
+ self._embedding_model = None
145
+ self._vad_model = None
146
+ self._embedding_model_is_speechbrain = False
147
+
148
+ def _load_embedding_model(self):
149
+ """Lazy load speaker embedding model
150
+
151
+ This function will attempt to patch missing torchaudio APIs (e.g., list_audio_backends)
152
+ so that SpeechBrain imports cleanly on environments with older torchaudio builds.
153
+ """
154
+ if self._embedding_model is None:
155
+ # Shim torchaudio compatibility if needed (some torchaudio versions lack list_audio_backends)
156
+ try:
157
+ import importlib.util  # importlib.util is not guaranteed to be available via a bare "import importlib"
158
+
159
+ if importlib.util.find_spec("torchaudio"):
160
+ import torchaudio
161
+
162
+ if not hasattr(torchaudio, "list_audio_backends"):
163
+
164
+ def _list_audio_backends():
165
+ # best-effort guess of available backends; not exhaustive
166
+ backends = []
167
+ try:
168
+ # prefer sox_io and soundfile as common options
169
+ backends.append("sox_io")
170
+ except Exception:
171
+ pass
172
+ try:
173
+ backends.append("soundfile")
174
+ except Exception:
175
+ pass
176
+ if not backends:
177
+ backends = ["sox_io"]
178
+ return backends
179
+
180
+ torchaudio.list_audio_backends = _list_audio_backends
181
+
182
+ if not hasattr(torchaudio, "get_audio_backend"):
183
+ torchaudio.get_audio_backend = lambda: torchaudio.list_audio_backends()[0]
184
+ except Exception:
185
+ # best-effort only, don't prevent embedding loading attempt
186
+ pass
187
+
188
+ try:
189
+ from speechbrain.inference.speaker import EncoderClassifier
190
+
191
+ self.logger.info(f"Loading embedding model: {self.config.embedding_model_id}")
192
+
193
+ import os
194
+
195
+ # Prefer to disable HF symlinks up-front on Windows to prevent permission errors
196
+ os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
197
+
198
+ # Try a robust direct download into a local models directory to avoid symlinks entirely
199
+ dest_dir = str(self.models_dir / self.config.embedding_model_id.replace("/", "_"))
200
+ try:
201
+ from huggingface_hub import snapshot_download
202
+
203
+ self.logger.info(
204
+ f"Attempting to snapshot_download model to local dir {dest_dir} (no symlinks)"
205
+ )
206
+ os.makedirs(dest_dir, exist_ok=True)
207
+ snapshot_download(
208
+ repo_id=self.config.embedding_model_id,
209
+ local_dir=dest_dir,
210
+ local_dir_use_symlinks=False,
211
+ )
212
+ # Try to load from the locally downloaded snapshot
213
+ try:
214
+ self._embedding_model = EncoderClassifier.from_hparams(
215
+ source=dest_dir,
216
+ savedir=dest_dir,
217
+ run_opts={"device": self.device},
218
+ )
219
+ self.logger.info("Embedding model loaded successfully from local snapshot")
220
+ # mark that we used speechbrain
221
+ self._embedding_model_is_speechbrain = True
222
+ return
223
+ except Exception as e_local:
224
+ self.logger.warning(f"Local snapshot load failed: {e_local}")
225
+ except Exception:
226
+ # snapshot_download not available or failed; continue with other strategies
227
+ pass
228
+
229
+ try:
230
+ # First try: load directly from hf cache (no savedir) - this typically avoids writing symlinks
231
+ self._embedding_model = EncoderClassifier.from_hparams(
232
+ source=self.config.embedding_model_id,
233
+ run_opts={"device": self.device},
234
+ )
235
+ self.logger.info("Embedding model loaded successfully (from HF cache)")
236
+ self._embedding_model_is_speechbrain = True
237
+ return
238
+ except Exception as e:
239
+ err_msg = str(e)
240
+
241
+ # Detect Windows symlink permission error and retry with savedir + disabled symlink env
242
+ if (
243
+ ("A required privilege" in err_msg)
244
+ or ("symlink" in err_msg.lower())
245
+ or getattr(e, "winerror", None) == 1314
246
+ ):
247
+ try:
248
+ os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
249
+ self.logger.warning(
250
+ "Detected symlink/permission issue; retrying model load with HF_HUB_DISABLE_SYMLINKS=1 and specifying savedir"
251
+ )
252
+ self._embedding_model = EncoderClassifier.from_hparams(
253
+ source=self.config.embedding_model_id,
254
+ savedir=str(self.models_dir / "spkrec-ecapa"),
255
+ run_opts={"device": self.device},
256
+ )
257
+ self.logger.info(
258
+ "Embedding model loaded successfully (after disabling symlinks)"
259
+ )
260
+ self._embedding_model_is_speechbrain = True
261
+ return
262
+ except Exception:
263
+ # Try monkeypatching SB fetch to use COPY
264
+ try:
265
+ import speechbrain.utils.fetching as sbfetch
266
+
267
+ orig_fetch = sbfetch.fetch
268
+
269
+ def _fetch_copy(*args, **kwargs):
270
+ kwargs.setdefault("local_strategy", sbfetch.LocalStrategy.COPY)
271
+ return orig_fetch(*args, **kwargs)
272
+
273
+ sbfetch.fetch = _fetch_copy
274
+ self.logger.info(
275
+ "Retrying model load with SpeechBrain fetch set to COPY strategy"
276
+ )
277
+ self._embedding_model = EncoderClassifier.from_hparams(
278
+ source=self.config.embedding_model_id,
279
+ savedir=str(self.models_dir / "spkrec-ecapa"),
280
+ run_opts={"device": self.device},
281
+ )
282
+ self.logger.info(
283
+ "Embedding model loaded successfully (after switching fetch strategy)"
284
+ )
285
+ self._embedding_model_is_speechbrain = True
286
+ return
287
+ except Exception as e3:
288
+ err_msg = str(e3)
289
+ finally:
290
+ try:
291
+ sbfetch.fetch = orig_fetch
292
+ except Exception:
293
+ pass
294
+
295
+ self.logger.error(f"Failed to load SpeechBrain embedding model: {err_msg}")
296
+
297
+ # Try to salvage by copying an existing cached snapshot or downloading directly into dest_dir
298
+ try:
299
+ import re
300
+ import shutil
301
+
302
+ m = re.search(r"'([^']+)'\s*->\s*'([^']+)'", err_msg)
303
+ if m:
304
+ src_file = m.group(1)
305
+ src_dir = os.path.dirname(src_file)
306
+ self.logger.info(
307
+ f"Attempting to copy cached snapshot from {src_dir} to {dest_dir}"
308
+ )
309
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
310
+
311
+ # Retry loading from the local copied directory
312
+ try:
313
+ self._embedding_model = EncoderClassifier.from_hparams(
314
+ source=dest_dir,
315
+ savedir=dest_dir,
316
+ run_opts={"device": self.device},
317
+ )
318
+ self.logger.info(
319
+ "Embedding model loaded successfully (after copying cached snapshot)"
320
+ )
321
+ self._embedding_model_is_speechbrain = True
322
+ return
323
+ except Exception as e4:
324
+ err_msg = str(e4)
325
+
326
+ # As a last resort, try to download model files directly into dest_dir using huggingface_hub APIs
327
+ from huggingface_hub import hf_hub_download, list_repo_files
328
+
329
+ self.logger.info(
330
+ f"Attempting direct HF download into {dest_dir} to avoid symlinks"
331
+ )
332
+ os.makedirs(dest_dir, exist_ok=True)
333
+ files = list_repo_files(self.config.embedding_model_id)
334
+ for fname in files:
335
+ if fname.endswith("/"):
336
+ continue
337
+ hf_hub_download(
338
+ repo_id=self.config.embedding_model_id,
339
+ filename=fname,
340
+ local_dir=dest_dir,
341
+ local_dir_use_symlinks=False,
342
+ )
343
+
344
+ # Retry loading now that files are locally present
345
+ self._embedding_model = EncoderClassifier.from_hparams(
346
+ source=dest_dir,
347
+ savedir=dest_dir,
348
+ run_opts={"device": self.device},
349
+ )
350
+ self.logger.info(
351
+ "Embedding model loaded successfully (after direct HF download)"
352
+ )
353
+ self._embedding_model_is_speechbrain = True
354
+ return
355
+ except Exception as e5:
356
+ err_msg = str(e5)
357
+
358
+ self.logger.warning(
359
+ "Common fixes: install a compatible torchaudio (matching your PyTorch), and install 'soundfile' or enable 'sox_io' backend."
360
+ )
361
+
362
+ # If user allows fallback, provide MFCC fallback; otherwise raise an error to enforce SpeechBrain usage
363
+ if getattr(self.config, "allow_fallback", False):
364
+ self.logger.warning(
365
+ "Falling back to MFCC-based deterministic embeddings (will be less accurate)."
366
+ )
367
+ self._embedding_model = "FALLBACK"
368
+ self._fallback_extractor = self._mfcc_embedding
369
+ return
370
+ else:
371
+ raise RuntimeError(
372
+ "Failed to load SpeechBrain embedding model and 'allow_fallback' is False. "
373
+ "Ensure torchaudio and speechbrain are installed, or set 'allow_fallback=True' in DiarizationConfig."
374
+ )
375
+ except Exception:
376
+ # Import of SpeechBrain failed entirely; honor allow_fallback setting
377
+ self.logger.warning(
378
+ "Could not import SpeechBrain; checking 'allow_fallback' setting"
379
+ )
380
+ if getattr(self.config, "allow_fallback", False):
381
+ self.logger.warning(
382
+ "Falling back to MFCC-based deterministic embeddings (allow_fallback=True)"
383
+ )
384
+ self._embedding_model = "FALLBACK"
385
+ self._fallback_extractor = self._mfcc_embedding
386
+ else:
387
+ raise RuntimeError(
388
+ "Failed to import or initialize SpeechBrain embedding model and 'allow_fallback' is False. "
389
+ "Install SpeechBrain or set 'allow_fallback=True' in DiarizationConfig to allow deterministic fallback."
390
+ )
391
+
392
+ def _mfcc_embedding(
393
+ self, segment_np: np.ndarray, sample_rate: int, target_dim: int = 192
394
+ ) -> np.ndarray:
395
+ """Compute a deterministic embedding from audio segment using MFCCs.
396
+
397
+ Falls back to simple waveform statistics if librosa is not available.
398
+ Returns a fixed-size vector of length `target_dim`.
399
+ """
400
+ try:
401
+ import librosa
402
+
403
+ mfcc = librosa.feature.mfcc(y=segment_np, sr=sample_rate, n_mfcc=40)
404
+ mfcc_mean = mfcc.mean(axis=1)
405
+ mfcc_std = mfcc.std(axis=1)
406
+ vec = np.concatenate([mfcc_mean, mfcc_std])
407
+ except Exception:
408
+ # Minimal deterministic fallback: use downsampled waveform statistics + spectral centroid approximation
409
+ vec = []
410
+ vec.append(np.mean(segment_np))
411
+ vec.append(np.std(segment_np))
412
+ # simple spectral centroid proxy
413
+ freqs = np.fft.rfftfreq(len(segment_np), d=1.0 / sample_rate)
414
+ spec = np.abs(np.fft.rfft(segment_np))
415
+ if spec.sum() > 0:
416
+ centroid = float((freqs * spec).sum() / spec.sum()) / (sample_rate / 2)
417
+ else:
418
+ centroid = 0.0
419
+ vec.append(centroid)
420
+ vec = np.array(vec, dtype=float)
421
+
422
+ # Pad or trim to target_dim
423
+ if len(vec) < target_dim:
424
+ padded = np.zeros(target_dim, dtype=float)
425
+ padded[: len(vec)] = vec
426
+ vec = padded
427
+ elif len(vec) > target_dim:
428
+ vec = vec[:target_dim]
429
+
430
+ # normalize
431
+ norm = np.linalg.norm(vec) + 1e-12
432
+ return (vec / norm).astype(np.float32)
433
+
434
+ def process(
435
+ self,
436
+ waveform: torch.Tensor,
437
+ sample_rate: int = 16000,
438
+ num_speakers: Optional[int] = None,
439
+ cache_dir: Optional[str] = None,
440
+ audio_id: Optional[str] = None,
441
+ fast_mode: bool = False,
442
+ ) -> List[SpeakerSegment]:
443
+ """
444
+ Main diarization pipeline.
445
+
446
+ Args:
447
+ waveform: Audio waveform [1, T]
448
+ sample_rate: Audio sample rate
449
+ num_speakers: Known number of speakers (auto-detect if None)
450
+
451
+ Returns:
452
+ List of SpeakerSegment with speaker assignments
453
+ """
454
+ self._load_embedding_model()
455
+
456
+ # Step 1: Voice Activity Detection
457
+ speech_regions = self._detect_speech(waveform, sample_rate)
458
+
459
+ if not speech_regions:
460
+ self.logger.warning("No speech detected in audio")
461
+ return []
462
+
463
+ self.logger.info(f"Detected {len(speech_regions)} speech regions")
464
+
465
+ # Step 2: Create analysis windows
466
+ windows = self._create_windows(speech_regions)
467
+
468
+ if not windows:
469
+ self.logger.warning("No valid windows created")
470
+ return []
471
+
472
+ self.logger.info(f"Created {len(windows)} analysis windows")
473
+
474
+ # Step 3: Extract speaker embeddings
475
+ embeddings = self._extract_embeddings(waveform, windows, sample_rate, cache_dir=cache_dir, audio_id=audio_id, fast_mode=fast_mode)
476
+
477
+ self.logger.info(f"Extracted embeddings with shape: {embeddings.shape}")
478
+
479
+ # Step 4: Cluster embeddings
480
+ labels = self._cluster_embeddings(
481
+ embeddings, num_speakers=num_speakers or self.config.max_speakers
482
+ )
483
+
484
+ num_speakers_found = len(set(labels))
485
+ self.logger.info(f"Found {num_speakers_found} speakers")
486
+
487
+ # Step 5: Create segments from windows and labels
488
+ raw_segments = self._create_segments(windows, labels, embeddings)
489
+
490
+ # Step 6: Post-processing
491
+ processed_segments = self._postprocess_segments(raw_segments)
492
+
493
+ # Step 7: Detect overlapping speech
494
+ processed_segments = self._detect_overlaps(processed_segments)
495
+
496
+ self.logger.info(f"Final: {len(processed_segments)} segments")
497
+
498
+ return processed_segments
499
+
500
+ def auto_tune(
501
+ self, waveform: torch.Tensor, sample_rate: int = 16000, num_speakers: Optional[int] = None
502
+ ) -> dict:
503
+ """Auto-tune clustering-related hyperparameters by searching a simple parameter grid.
504
+
505
+ This method extracts embeddings and tries different clustering thresholds and
506
+ minimum cluster sizes, scoring candidates by silhouette score (and closeness
507
+ to `num_speakers` if provided). The best parameter set is applied to
508
+ `self.config` and returned for inspection.
509
+ """
510
+ # Quick extraction path
511
+ speech_regions = self._detect_speech(waveform, sample_rate)
512
+ if not speech_regions:
513
+ self.logger.warning("Auto-tune: no speech regions detected; aborting tuning")
514
+ return {}
515
+
516
+ windows = self._create_windows(speech_regions)
517
+ if not windows:
518
+ self.logger.warning("Auto-tune: no analysis windows created; aborting tuning")
519
+ return {}
520
+
521
+ embeddings = self._extract_embeddings(waveform, windows, sample_rate)
522
+ if embeddings is None or len(embeddings) < 4:
523
+ self.logger.warning("Auto-tune: insufficient embeddings for tuning; aborting tuning")
524
+ return {}
525
+
526
+ # Parameter grid (coarse)
527
+ clustering_thresholds = [0.95, 0.85, 0.7, 0.5, 0.3, 0.15]
528
+ min_cluster_sizes = [1, 2, 3, 4]
529
+
530
+ best_score = -1e9
531
+ best_params = {
532
+ "clustering_threshold": self.config.clustering_threshold,
533
+ "min_cluster_size": self.config.min_cluster_size,
534
+ "iterative_merge_threshold": self.config.iterative_merge_threshold,
535
+ }
536
+
537
+ # Save original values to restore if needed
538
+ orig_threshold = self.config.clustering_threshold
539
+ orig_min_size = self.config.min_cluster_size
540
+ orig_iter_thresh = self.config.iterative_merge_threshold
541
+
542
+ try:
543
+ for thr in clustering_thresholds:
544
+ for msize in min_cluster_sizes:
545
+ # Temporarily set
546
+ self.config.clustering_threshold = thr
547
+ self.config.min_cluster_size = msize
548
+
549
+ try:
550
+ labels = self._cluster_embeddings(embeddings, num_speakers=None)
551
+ k = len(np.unique(labels))
552
+ if k <= 1:
553
+ sil = 0.0
554
+ else:
555
+ try:
556
+ sil = silhouette_score(embeddings, labels, metric="cosine")
557
+ except Exception:
558
+ sil = 0.0
559
+
560
+ # Scoring: prefer higher silhouette and closeness to desired num_speakers
561
+ score = sil
562
+ if num_speakers is not None:
563
+ score -= 0.1 * abs(k - num_speakers)
564
+ # small penalty for many clusters
565
+ score -= 0.02 * k
566
+
567
+ self.logger.debug(
568
+ f"Auto-tune candidate: thr={thr}, min_size={msize} -> k={k}, sil={sil:.4f}, score={score:.4f}"
569
+ )
570
+
571
+ if score > best_score:
572
+ best_score = score
573
+ best_params = {
574
+ "clustering_threshold": thr,
575
+ "min_cluster_size": msize,
576
+ "achieved_k": k,
577
+ "silhouette": sil,
578
+ }
579
+ except Exception as e:
580
+ self.logger.debug(f"Auto-tune candidate failed: {e}")
581
+ continue
582
+
583
+ # Apply best params
584
+ self.config.clustering_threshold = float(
585
+ best_params.get("clustering_threshold", orig_threshold)
586
+ )
587
+ self.config.min_cluster_size = int(best_params.get("min_cluster_size", orig_min_size))
588
+ # If a desired num_speakers was provided, set target merge accordingly
589
+ if num_speakers is not None:
590
+ self.config.target_num_speakers = int(num_speakers)
591
+
592
+ self.logger.info(f"Auto-tune selected: {best_params}")
593
+ return best_params
594
+ finally:
595
+ # nothing to restore; we've intentionally applied best params
596
+ pass
597
+
598
+ def _detect_speech(self, waveform: torch.Tensor, sample_rate: int) -> List[Tuple[float, float]]:
599
+ """
600
+ Detect speech regions using energy-based VAD.
601
+
602
+ Args:
603
+ waveform: Audio waveform
604
+ sample_rate: Sample rate
605
+
606
+ Returns:
607
+ List of (start, end) tuples for speech regions
608
+ """
609
+ waveform_np = waveform.squeeze().cpu().numpy()
610
+
611
+ # Frame parameters
612
+ frame_length_ms = 25 # 25ms frames
613
+ hop_length_ms = 10 # 10ms hop
614
+
615
+ frame_length = int(frame_length_ms * sample_rate / 1000)
616
+ hop_length = int(hop_length_ms * sample_rate / 1000)
617
+
618
+ # Calculate energy per frame
619
+ num_frames = max(1, 1 + (len(waveform_np) - frame_length) // hop_length)
620
+ energies = np.zeros(num_frames)
621
+
622
+ for i in range(num_frames):
623
+ start_idx = i * hop_length
624
+ end_idx = min(start_idx + frame_length, len(waveform_np))
625
+ frame = waveform_np[start_idx:end_idx]
626
+
627
+ if len(frame) > 0:
628
+ energies[i] = np.sqrt(np.mean(frame**2) + 1e-10)
629
+
630
+ # Compute adaptive threshold
631
+ if len(energies) > 0:
632
+ energy_sorted = np.sort(energies)
633
+ # Use 30th percentile as noise floor estimate
634
+ noise_floor = energy_sorted[int(0.3 * len(energy_sorted))]
635
+ threshold = noise_floor + self.config.vad_threshold * np.std(energies)
636
+ else:
637
+ threshold = self.config.vad_threshold
638
+
639
+ # Find speech regions
640
+ is_speech = energies > threshold
641
+
642
+ # Apply morphological operations to smooth
643
+ # (simple dilation and erosion using convolution)
644
+ kernel_size = max(1, int(self.config.min_speech_duration * 1000 / hop_length_ms))
645
+
646
+ if kernel_size > 1 and len(is_speech) > kernel_size:
647
+ # Simple smoothing
648
+ kernel = np.ones(kernel_size) / kernel_size
649
+ smoothed = np.convolve(is_speech.astype(float), kernel, mode="same")
650
+ is_speech = smoothed > 0.5
651
+
652
+ # Convert to time regions
653
+ regions = []
654
+ in_speech = False
655
+ speech_start = 0.0
656
+
657
+ for i, speech in enumerate(is_speech):
658
+ time = i * hop_length / sample_rate
659
+
660
+ if speech and not in_speech:
661
+ speech_start = time
662
+ in_speech = True
663
+ elif not speech and in_speech:
664
+ duration = time - speech_start
665
+ if duration >= self.config.min_speech_duration:
666
+ regions.append((speech_start, time))
667
+ in_speech = False
668
+
669
+ # Handle last region
670
+ if in_speech:
671
+ end_time = len(waveform_np) / sample_rate
672
+ duration = end_time - speech_start
673
+ if duration >= self.config.min_speech_duration:
674
+ regions.append((speech_start, end_time))
675
+
676
+ # Merge nearby regions
677
+ regions = self._merge_nearby_regions(regions, self.config.min_silence_duration)
678
+
679
+ return regions
680
+
681
+ def _merge_nearby_regions(
682
+ self, regions: List[Tuple[float, float]], min_gap: float
683
+ ) -> List[Tuple[float, float]]:
684
+ """Merge regions that are close together"""
685
+ if not regions:
686
+ return []
687
+
688
+ merged = [regions[0]]
689
+
690
+ for start, end in regions[1:]:
691
+ last_start, last_end = merged[-1]
692
+
693
+ if start - last_end <= min_gap:
694
+ merged[-1] = (last_start, end)
695
+ else:
696
+ merged.append((start, end))
697
+
698
+ return merged
699
+
700
+ def _create_windows(
701
+ self, speech_regions: List[Tuple[float, float]]
702
+ ) -> List[Tuple[float, float]]:
703
+ """Create sliding windows over speech regions for embedding extraction"""
704
+ windows = []
705
+
706
+ for region_start, region_end in speech_regions:
707
+ t = region_start
708
+
709
+ while t < region_end:
710
+ window_end = min(t + self.config.segment_window, region_end)
711
+
712
+ # Only include windows with sufficient duration
713
+ if (window_end - t) >= self.config.min_segment_duration:
714
+ # Avoid creating too many tiny windows across short recordings
715
+ if (region_end - region_start) < (self.config.segment_window * 2):
716
+ # for short regions, use a single window covering the region
717
+ windows.append((region_start, region_end))
718
+ break
719
+ windows.append((t, window_end))
720
+
721
+ t += self.config.segment_hop
722
+
723
+ return windows
724
+
725
+ def _extract_embeddings(
726
+ self,
727
+ waveform: torch.Tensor,
728
+ windows: List[Tuple[float, float]],
729
+ sample_rate: int,
730
+ cache_dir: Optional[str] = None,
731
+ audio_id: Optional[str] = None,
732
+ fast_mode: bool = False,
733
+ ) -> np.ndarray:
734
+ """Extract speaker embeddings for each window.
735
+
736
+ Optimizations implemented:
737
+ - Disk cache (if enabled in config and cache_dir provided)
738
+ - Batch extraction using model's batch API when available
739
+ - Fast MFCC embedding path when `use_fast_embedding` is True
740
+ """
741
+ # Try disk cache first
742
+ if (
743
+ cache_dir
744
+ and audio_id
745
+ and self.config.embedding_cache
746
+ and getattr(self.config, "embedding_cache", True)
747
+ ):
748
+ try:
749
+ import os
750
+
751
+ cache_path = Path(cache_dir) / f"{audio_id}_embeddings.npy"
752
+ if cache_path.exists():
753
+ arr = np.load(str(cache_path))
754
+ if arr.shape[0] == len(windows):
755
+ self.logger.info(f"Loaded embeddings from cache: {cache_path}")
756
+ return arr
757
+ except Exception:
758
+ pass
759
+
760
+ n = len(windows)
761
+ embeddings = [None] * n
762
+
763
+ # If fallback or user requested fast embedding, compute MFCC-based embeddings vectorized
764
+ if (
765
+ (self._embedding_model == "FALLBACK" or self._embedding_model is None)
766
+ or getattr(self.config, "use_fast_embedding", False)
767
+ or fast_mode
768
+ ):
769
+ for i, (start, end) in enumerate(windows):
770
+ start_sample = int(start * sample_rate)
771
+ end_sample = int(end * sample_rate)
772
+ segment = waveform[:, start_sample:end_sample]
773
+ try:
774
+ seg_np = segment.squeeze().cpu().numpy()
775
+ emb = self._fallback_extractor(seg_np, sample_rate)
776
+ except Exception:
777
+ seg_np = segment.squeeze().cpu().numpy()
778
+ emb = self._mfcc_embedding(seg_np, sample_rate)
779
+ embeddings[i] = emb
780
+
781
+ embeddings = np.stack(embeddings, axis=0)
782
+
783
+ # Save to cache
784
+ try:
785
+ if cache_dir and audio_id and self.config.embedding_cache:
786
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
787
+ np.save(str(Path(cache_dir) / f"{audio_id}_embeddings.npy"), embeddings)
788
+ except Exception:
789
+ pass
790
+
791
+ return embeddings
792
+
793
+ # Otherwise use model batch encoding when available
794
+ batch_size = max(1, int(getattr(self.config, "embedding_batch_size", 32)))
795
+
796
+ # Prepare segment numpy arrays
797
+ segs = []
798
+ seg_indices = []
799
+ for i, (start, end) in enumerate(windows):
800
+ start_sample = int(start * sample_rate)
801
+ end_sample = int(end * sample_rate)
802
+ segment = waveform[:, start_sample:end_sample]
803
+ segs.append(segment)
804
+ seg_indices.append(i)
805
+
806
+ # Try batch processing
807
+ try:
808
+ # If model supports encode_batch on a list or stacked tensor, process in chunks
809
+ for i in range(0, len(segs), batch_size):
810
+ batch = segs[i : i + batch_size]
811
+ # Stack into a tensor batch
812
+ try:
813
+ batch_tensor = torch.stack(
814
+ [b.squeeze(0) if b.dim() == 2 else b for b in batch], dim=0
815
+ )
816
+ except Exception:
817
+ # Some models expect list of tensors; keep as list
818
+ batch_tensor = batch
819
+
820
+ with torch.no_grad():
821
+ try:
822
+ # Move to model device if available
823
+ if hasattr(self._embedding_model, "device") and isinstance(
824
+ batch_tensor, torch.Tensor
825
+ ):
826
+ batch_tensor = batch_tensor.to(self._embedding_model.device)
827
+
828
+ out = None
829
+ # Try the most common batch API names
830
+ if hasattr(self._embedding_model, "encode_batch"):
831
+ out = self._embedding_model.encode_batch(batch_tensor)
832
+ elif hasattr(self._embedding_model, "encode"):
833
+ out = self._embedding_model.encode(batch_tensor)
834
+ else:
835
+ # fallback: try to call on each separately
836
+ out = [self._embedding_model.encode_batch(x) for x in batch]
837
+
838
+ # Normalize outputs into numpy array
839
+ if isinstance(out, torch.Tensor):
840
+ out_np = out.cpu().numpy()
841
+ elif isinstance(out, list):
842
+ out_np = np.stack(
843
+ [
844
+ (
845
+ o.squeeze().cpu().numpy()
846
+ if isinstance(o, torch.Tensor)
847
+ else np.array(o)
848
+ )
849
+ for o in out
850
+ ],
851
+ axis=0,
852
+ )
853
+ else:
854
+ out_np = np.array(out)
855
+
856
+ # assign back to embeddings
857
+ for j, idx in enumerate(range(i, i + out_np.shape[0])):
858
+ embeddings[idx] = out_np[j]
859
+
860
+ except Exception as e:
861
+ # fallback to per-segment extraction for this batch
862
+ self.logger.debug(f"Batch embedding failed, falling back per-segment: {e}")
863
+ for bb_idx, seg in enumerate(batch):
864
+ try:
865
+ with torch.no_grad():
866
+ if hasattr(self._embedding_model, "device") and isinstance(
867
+ seg, torch.Tensor
868
+ ):
869
+ seg = seg.to(self._embedding_model.device)
870
+ emb = self._embedding_model.encode_batch(seg)
871
+ emb = emb.squeeze().cpu().numpy()
872
+ except Exception:
873
+ emb = np.random.randn(192).astype(np.float32)
874
+ embeddings[i + bb_idx] = emb
875
+
876
+ embeddings = np.stack(embeddings, axis=0)
877
+
878
+ # Save to cache
879
+ try:
880
+ if cache_dir and audio_id and self.config.embedding_cache:
881
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
882
+ np.save(str(Path(cache_dir) / f"{audio_id}_embeddings.npy"), embeddings)
883
+ except Exception:
884
+ pass
885
+
886
+ return embeddings
887
+
888
+ except Exception as e:
889
+ self.logger.warning(f"Batch embedding extraction failed: {e}")
890
+ # final fallback: single extraction loop
891
+ embeddings = []
892
+ for start, end in windows:
893
+ start_sample = int(start * sample_rate)
894
+ end_sample = int(end * sample_rate)
895
+ segment = waveform[:, start_sample:end_sample]
896
+ try:
897
+ with torch.no_grad():
898
+ if hasattr(self._embedding_model, "device"):
899
+ segment = segment.to(self._embedding_model.device)
900
+ emb = self._embedding_model.encode_batch(segment)
901
+ emb = emb.squeeze().cpu().numpy()
902
+ except Exception:
903
+ emb = np.random.randn(192).astype(np.float32)
904
+ embeddings.append(emb)
905
+
906
+ embeddings = np.stack(embeddings, axis=0)
907
+ return embeddings
908
+
909
+ def _cluster_embeddings(
910
+ self, embeddings: np.ndarray, num_speakers: Optional[int] = None, method_override: Optional[str] = None
911
+ ) -> np.ndarray:
912
+ """Cluster embeddings to assign speaker labels, with small-cluster merging.
913
+
914
+ Args:
915
+ embeddings: (N, D) array of embeddings
916
+ num_speakers: Optional target number of speakers
917
+ method_override: If set, use this clustering method ('agglomerative','spectral','kmeans')
918
+ """
919
+ if len(embeddings) < 2:
920
+ return np.zeros(len(embeddings), dtype=int)
921
+
922
+ # Normalize embeddings
923
+ scaler = StandardScaler()
924
+ embeddings_norm = scaler.fit_transform(embeddings)
925
+
926
+ # Support both nested (Config.diarization.clustering) and flat config shapes
927
+ if method_override is not None:
928
+ method = method_override
929
+ # default thresholds - allow config overrides below
930
+ threshold = getattr(self.config, "clustering_threshold", 0.7)
931
+ linkage = getattr(self.config, "clustering_linkage", "average")
932
+ min_size_cfg = getattr(self.config, "min_cluster_size", 2)
933
+ max_speakers_cfg = getattr(self.config, "max_speakers", None)
934
+ elif hasattr(self.config, "clustering"):
935
+ method = self.config.clustering.method
936
+ threshold = self.config.clustering.threshold
937
+ linkage = self.config.clustering.linkage
938
+ min_size_cfg = getattr(
939
+ self.config.clustering,
940
+ "min_cluster_size",
941
+ getattr(self.config, "min_cluster_size", 2),
942
+ )
943
+ max_speakers_cfg = getattr(self.config, "max_speakers", None)
944
+ else:
945
+ method = getattr(self.config, "clustering_method", "spectral")
946
+ threshold = getattr(self.config, "clustering_threshold", 0.7)
947
+ linkage = getattr(self.config, "clustering_linkage", "average")
948
+ min_size_cfg = getattr(self.config, "min_cluster_size", 2)
949
+ max_speakers_cfg = getattr(self.config, "max_speakers", None)
950
+
951
+ if method == "agglomerative":
952
+ if num_speakers is not None:
953
+ clustering = AgglomerativeClustering(
954
+ n_clusters=num_speakers, metric="cosine", linkage=linkage
955
+ )
956
+ else:
957
+ # If no target provided, estimate number of speakers via silhouette search
958
+ est_max = min(8, max(2, len(embeddings) // 2))
959
+ est_min = 2
960
+ best_k = None
961
+ best_score = -1.0
962
+ # Only try silhouette search on reasonably-sized inputs
963
+ if len(embeddings) >= 8:
964
+ for k in range(est_min, est_max + 1):
965
+ try:
966
+ tmp = AgglomerativeClustering(n_clusters=k, metric="cosine", linkage=linkage)
967
+ labels_tmp = tmp.fit_predict(embeddings_norm)
968
+ # silhouette requires at least 2 clusters and < n_samples clusters
969
+ if len(np.unique(labels_tmp)) > 1 and len(np.unique(labels_tmp)) < len(embeddings):
970
+ score = silhouette_score(embeddings_norm, labels_tmp, metric="cosine")
971
+ else:
972
+ score = -1.0
973
+ except Exception:
974
+ score = -1.0
975
+ if score > best_score:
976
+ best_score = score
977
+ best_k = k
978
+ # If silhouette search found a sensible k use it; else fallback to threshold style
979
+ if best_k is not None and best_score > 0.01:
980
+ clustering = AgglomerativeClustering(n_clusters=best_k, metric="cosine", linkage=linkage)
981
+ self.logger.info(f"Agglomerative autodetected k={best_k} (silhouette={best_score:.3f})")
982
+ else:
983
+ clustering = AgglomerativeClustering(
984
+ n_clusters=None,
985
+ distance_threshold=threshold,
986
+ metric="cosine",
987
+ linkage=linkage,
988
+ )
989
+
990
+ elif method == "spectral":
991
+ n_clusters = num_speakers or min(8, len(embeddings) // 2)
992
+ clustering = SpectralClustering(
993
+ n_clusters=n_clusters,
994
+ affinity="nearest_neighbors",
995
+ n_neighbors=min(10, len(embeddings) - 1),
996
+ )
997
+
998
+ elif method == "kmeans":
999
+ n_clusters = num_speakers or min(8, len(embeddings) // 2)
1000
+ clustering = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
1001
+
1002
+ else:
1003
+ raise ValueError(f"Unknown clustering method: {method}")
1004
+
1005
+ try:
1006
+ labels = clustering.fit_predict(embeddings_norm)
1007
+ except Exception as e:
1008
+ self.logger.error(f"Clustering failed: {e}")
1009
+ labels = np.array([i % 2 for i in range(len(embeddings))])
1010
+
1011
+ # Debug: cluster sizes
1012
+ unique, counts = np.unique(labels, return_counts=True)
1013
+ sizes = dict(zip(unique.tolist(), counts.tolist()))
1014
+ self.logger.debug(f"Initial clusters: {len(unique)}, sizes: {sizes}")
1015
+
1016
+ # Global check: if all embeddings are very similar, collapse directly to 1 speaker
1017
+ try:
1018
+ # First, perform a row-normalized (per-embedding) cosine check on raw embeddings
1019
+ row_norm = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
1020
+ n_sample = min(200, len(row_norm))
1021
+ idx = np.linspace(0, len(row_norm) - 1, n_sample).astype(int)
1022
+ sub = row_norm[idx]
1023
+ sims = np.dot(sub, sub.T)
1024
+ sims = np.clip(sims, -1.0, 1.0)
1025
+ dists = 1.0 - sims
1026
+ mean_row_dist = (
1027
+ float(np.mean(dists[np.triu_indices_from(dists, k=1)])) if n_sample > 1 else 1.0
1028
+ )
1029
+ global_row_threshold = getattr(self.config, "global_collapse_threshold", 0.03)
1030
+ # Be more permissive for short recordings (few windows)
1031
+ if len(embeddings) < 40:
1032
+ global_row_threshold = max(global_row_threshold, 0.08)
1033
+ if mean_row_dist < global_row_threshold:
1034
+ self.logger.info(
1035
+ f"Row-normalized embeddings too similar (mean dist={mean_row_dist:.6f}), collapsing to 1 speaker"
1036
+ )
1037
+ return np.zeros(len(embeddings), dtype=int)
1038
+
1039
+ # Next, check on scaled embeddings (existing logic)
1040
+ n_sample = min(200, len(embeddings_norm))
1041
+ idx = np.linspace(0, len(embeddings_norm) - 1, n_sample).astype(int)
1042
+ sub = embeddings_norm[idx]
1043
+ sims = np.dot(sub, sub.T)
1044
+ sims = np.clip(sims, -1.0, 1.0)
1045
+ dists = 1.0 - sims
1046
+ mean_global_dist = (
1047
+ float(np.mean(dists[np.triu_indices_from(dists, k=1)])) if n_sample > 1 else 1.0
1048
+ )
1049
+ global_collapse_threshold = getattr(self.config, "global_collapse_threshold", 0.03)
1050
+ if mean_global_dist < global_collapse_threshold:
1051
+ self.logger.info(
1052
+ f"Global embeddings too similar (mean dist={mean_global_dist:.4f}), collapsing to 1 speaker"
1053
+ )
1054
+ return np.zeros(len(embeddings), dtype=int)
1055
+
1056
+ # Additional small-variance heuristic: if feature-wise std is tiny, collapse as well
1057
+ mean_std = float(np.mean(np.std(embeddings_norm, axis=0)))
1058
+ std_threshold = getattr(self.config, "global_std_threshold", 1e-2)
1059
+ if mean_std < std_threshold:
1060
+ self.logger.info(
1061
+ f"Embeddings have tiny variance (mean std={mean_std:.6f}), collapsing to 1 speaker"
1062
+ )
1063
+ return np.zeros(len(embeddings), dtype=int)
1064
+ except Exception:
1065
+ pass
1066
+
1067
+ # If centroids are very close to each other, this is likely a single-speaker recording.
1068
+ # Compute mean pairwise centroid cosine distance; if below a threshold, collapse to 1 cluster.
1069
+ try:
1070
+ labels_unique = np.unique(labels)
1071
+ centroids = [embeddings_norm[labels == l].mean(axis=0) for l in labels_unique]
1072
+ if len(centroids) > 1:
1073
+ pair_dists = []
1074
+ for i in range(len(centroids)):
1075
+ for j in range(i + 1, len(centroids)):
1076
+ a = centroids[i] / (np.linalg.norm(centroids[i]) + 1e-12)
1077
+ b = centroids[j] / (np.linalg.norm(centroids[j]) + 1e-12)
1078
+ pair_dists.append(1.0 - float(np.dot(a, b)))
1079
+ mean_pair_dist = float(np.mean(pair_dists)) if pair_dists else 1.0
1080
+ else:
1081
+ mean_pair_dist = 1.0
1082
+
1083
+ collapse_threshold = getattr(self.config, "collapse_threshold", 0.15)
1084
+ if mean_pair_dist < collapse_threshold:
1085
+ self.logger.info(
1086
+ f"Centroids too similar (mean dist={mean_pair_dist:.3f}), collapsing to 1 speaker"
1087
+ )
1088
+ labels = np.zeros_like(labels)
1089
+
1090
+ # If SpeechBrain embeddings are used and clusters have a very low silhouette score,
1091
+ # it's likely that the recording is single-speaker and clustering is over-fragmenting.
1092
+ try:
1093
+ if getattr(self.config, "use_speechbrain", True) and getattr(
1094
+ self, "_embedding_model_is_speechbrain", False
1095
+ ):
1096
+ unique_labels = np.unique(labels)
1097
+ if len(unique_labels) > 1:
1098
+ try:
1099
+ score = silhouette_score(embeddings_norm, labels, metric="cosine")
1100
+ if score < getattr(self.config, "silhouette_collapse_threshold", 0.05):
1101
+ self.logger.info(
1102
+ f"Low silhouette score ({score:.4f}) detected with SpeechBrain embeddings; collapsing to 1 speaker"
1103
+ )
1104
+ return np.zeros(len(embeddings), dtype=int)
1105
+ except Exception:
1106
+ pass
1107
+ except Exception:
1108
+ pass
1109
+ except Exception:
1110
+ pass
1111
+
1112
+ # Merge clusters smaller than min_cluster_size
1113
+ min_size = min_size_cfg
1114
+ if min_size and min_size > 1:
1115
+ changed = True
1116
+ while changed:
1117
+ changed = False
1118
+ labels_unique, label_counts = np.unique(labels, return_counts=True)
1119
+ small_labels = [l for l, c in zip(labels_unique, label_counts) if c < min_size]
1120
+ if not small_labels:
1121
+ break
1122
+
1123
+ # compute centroids for existing labels
1124
+ centroids = {l: embeddings_norm[labels == l].mean(axis=0) for l in labels_unique}
1125
+
1126
+ for sl in small_labels:
1127
+ candidates = [l for l in labels_unique if l != sl]
1128
+ if not candidates:
1129
+ continue
1130
+
1131
+ # find nearest centroid (cosine distance)
1132
+ def cosine_dist(a, b):
1133
+ a_norm = a / (np.linalg.norm(a) + 1e-12)
1134
+ b_norm = b / (np.linalg.norm(b) + 1e-12)
1135
+ return 1.0 - float(np.dot(a_norm, b_norm))
1136
+
1137
+ distances = [(c, cosine_dist(centroids[sl], centroids[c])) for c in candidates]
1138
+ nearest = min(distances, key=lambda x: x[1])[0]
1139
+
1140
+ # reassign labels
1141
+ labels[labels == sl] = nearest
1142
+ changed = True
1143
+
1144
+ # Final cluster sizes
1145
+ unique2, counts2 = np.unique(labels, return_counts=True)
1146
+ sizes2 = dict(zip(unique2.tolist(), counts2.tolist()))
1147
+ self.logger.debug(f"Clusters after merge: {len(unique2)}, sizes: {sizes2}")
1148
+
1149
+ # Additional centroid-based merging: merge clusters whose centroids are very close
1150
+ try:
1151
+ labels_unique = np.unique(labels)
1152
+ centroids = {l: embeddings_norm[labels == l].mean(axis=0) for l in labels_unique}
1153
+ # compute pairwise centroid distances
1154
+ pairs = []
1155
+ for i, a in enumerate(labels_unique):
1156
+ for j, b in enumerate(labels_unique):
1157
+ if j <= i:
1158
+ continue
1159
+ dist = 1.0 - float(
1160
+ np.dot(
1161
+ centroids[a] / (np.linalg.norm(centroids[a]) + 1e-12),
1162
+ centroids[b] / (np.linalg.norm(centroids[b]) + 1e-12),
1163
+ )
1164
+ )
1165
+ pairs.append((dist, a, b))
1166
+
1167
+ # merge pairs with distance < threshold
1168
+ pairs.sort()
1169
+ merged = False
1170
+ for dist, a, b in pairs:
1171
+ if dist < threshold:
1172
+ # merge b into a
1173
+ labels[labels == b] = a
1174
+ merged = True
1175
+
1176
+ if merged:
1177
+ labels_unique2, counts2 = np.unique(labels, return_counts=True)
1178
+ sizes2 = dict(zip(labels_unique2.tolist(), counts2.tolist()))
1179
+ self.logger.debug(
1180
+ f"Clusters after centroid-merge: {len(labels_unique2)}, sizes: {sizes2}"
1181
+ )
1182
+
1183
+ # Iterative silhouette-guided merging: try merging closest centroid pairs while it improves or meets configured criteria
1184
+ try:
1185
+ iterative_thresh = getattr(self.config, "iterative_merge_threshold", threshold)
1186
+ silhouette_min = getattr(self.config, "iterative_merge_silhouette_threshold", 0.0)
1187
+ max_merge_iters = getattr(self.config, "iterative_merge_max_iters", 10)
1188
+
1189
+ def compute_centroids(curr_labels):
1190
+ uniq = np.unique(curr_labels)
1191
+ return {l: embeddings_norm[curr_labels == l].mean(axis=0) for l in uniq}
1192
+
1193
+ def pairwise_min_pair(centroids_dict):
1194
+ uniq = list(centroids_dict.keys())
1195
+ best = (1.0, None, None)
1196
+ for i, a in enumerate(uniq):
1197
+ for j in range(i + 1, len(uniq)):
1198
+ b = uniq[j]
1199
+ a_c = centroids_dict[a] / (np.linalg.norm(centroids_dict[a]) + 1e-12)
1200
+ b_c = centroids_dict[b] / (np.linalg.norm(centroids_dict[b]) + 1e-12)
1201
+ dist = 1.0 - float(np.dot(a_c, b_c))
1202
+ if dist < best[0]:
1203
+ best = (dist, a, b)
1204
+ return best
1205
+
1206
+ curr_labels = labels.copy()
1207
+ prev_score = None
1208
+ try:
1209
+ if len(np.unique(curr_labels)) > 1:
1210
+ prev_score = silhouette_score(embeddings_norm, curr_labels, metric="cosine")
1211
+ except Exception:
1212
+ prev_score = None
1213
+
1214
+ iters = 0
1215
+ while iters < max_merge_iters:
1216
+ iters += 1
1217
+ cent = compute_centroids(curr_labels)
1218
+ if len(cent) <= 1:
1219
+ break
1220
+ min_dist, a, b = pairwise_min_pair(cent)
1221
+ if min_dist >= iterative_thresh:
1222
+ break
1223
+
1224
+ # simulate merge and evaluate silhouette
1225
+ next_labels = curr_labels.copy()
1226
+ next_labels[next_labels == b] = a
1227
+
1228
+ try:
1229
+ if len(np.unique(next_labels)) > 1:
1230
+ next_score = silhouette_score(
1231
+ embeddings_norm, next_labels, metric="cosine"
1232
+ )
1233
+ else:
1234
+ next_score = 1.0
1235
+ except Exception:
1236
+ next_score = None
1237
+
1238
+ accept = False
1239
+ if next_score is not None:
1240
+ if prev_score is None:
1241
+ # accept merges that meet a minimum silhouette threshold
1242
+ if next_score >= silhouette_min:
1243
+ accept = True
1244
+ else:
1245
+ # accept if silhouette improves by a small margin or stays acceptable
1246
+ if next_score >= prev_score or next_score >= silhouette_min:
1247
+ accept = True
1248
+
1249
+ if accept:
1250
+ curr_labels = next_labels
1251
+ prev_score = next_score
1252
+ labels = curr_labels.copy()
1253
+ # continue iterating
1254
+ else:
1255
+ break
1256
+
1257
+ if iters > 1:
1258
+ labels_unique2, counts2 = np.unique(labels, return_counts=True)
1259
+ sizes2 = dict(zip(labels_unique2.tolist(), counts2.tolist()))
1260
+ self.logger.debug(
1261
+ f"Clusters after iterative-merge (iters={iters}): {len(labels_unique2)}, sizes: {sizes2}"
1262
+ )
1263
+
1264
+ # If user requested a target speaker count, greedily merge closest centroid pairs until we meet it
1265
+ try:
1266
+ target_k = getattr(self.config, "target_num_speakers", None)
1267
+ force_thresh = float(getattr(self.config, "target_force_threshold", 1.0))
1268
+ if target_k is not None:
1269
+ curr_labels = labels.copy()
1270
+
1271
+ def compute_centroids(curr):
1272
+ uniq = np.unique(curr)
1273
+ return {l: embeddings_norm[curr == l].mean(axis=0) for l in uniq}
1274
+
1275
+ merged_iters = 0
1276
+ while len(np.unique(curr_labels)) > target_k:
1277
+ cent = compute_centroids(curr_labels)
1278
+ if len(cent) <= 1:
1279
+ break
1280
+ # find closest pair
1281
+ uniq = list(cent.keys())
1282
+ best = (1.0, None, None)
1283
+ for i, a in enumerate(uniq):
1284
+ for j in range(i + 1, len(uniq)):
1285
+ b = uniq[j]
1286
+ a_c = cent[a] / (np.linalg.norm(cent[a]) + 1e-12)
1287
+ b_c = cent[b] / (np.linalg.norm(cent[b]) + 1e-12)
1288
+ dist = 1.0 - float(np.dot(a_c, b_c))
1289
+ if dist < best[0]:
1290
+ best = (dist, a, b)
1291
+
1292
+ min_dist, a, b = best
1293
+ # if min_dist is too large and force_thresh < 1.0, break
1294
+ if min_dist > force_thresh and force_thresh < 1.0:
1295
+ self.logger.warning(
1296
+ f"Stopping target-merge early: nearest cluster dist {min_dist:.3f} > force_thresh {force_thresh}"
1297
+ )
1298
+ break
1299
+
1300
+ # merge b into a
1301
+ curr_labels[curr_labels == b] = a
1302
+ merged_iters += 1
1303
+ # safety to avoid infinite loops
1304
+ if merged_iters > 1000:
1305
+ break
1306
+
1307
+ if merged_iters:
1308
+ labels = curr_labels.copy()
1309
+ labels_unique2, counts2 = np.unique(labels, return_counts=True)
1310
+ sizes2 = dict(zip(labels_unique2.tolist(), counts2.tolist()))
1311
+ self.logger.info(
1312
+ f"Clusters after target-merge (target={target_k}, iters={merged_iters}): {len(labels_unique2)}, sizes: {sizes2}"
1313
+ )
1314
+ except Exception:
1315
+ pass
1316
+
1317
+ except Exception:
1318
+ # don't let merging errors break the pipeline
1319
+ pass
1320
+
1321
+ # Heuristic fallback: if still too fragmented, run KMeans with estimated speaker count
1322
+ n_clusters_found = len(np.unique(labels))
1323
+ max_allowed = 20
1324
+ if n_clusters_found > max_allowed:
1325
+ est_k = min(12, max(2, int(len(embeddings) / 80)))
1326
+ self.logger.warning(
1327
+ f"Too many clusters ({n_clusters_found}), falling back to KMeans with k={est_k}"
1328
+ )
1329
+ try:
1330
+ km = KMeans(n_clusters=est_k, random_state=42, n_init=10)
1331
+ labels = km.fit_predict(embeddings_norm)
1332
+ # Re-merge small clusters after KMeans
1333
+ labels_unique2, counts2 = np.unique(labels, return_counts=True)
1334
+ sizes2 = dict(zip(labels_unique2.tolist(), counts2.tolist()))
1335
+ self.logger.info(
1336
+ f"Clusters after KMeans fallback: {len(labels_unique2)}, sizes: {sizes2}"
1337
+ )
1338
+ except Exception as e:
1339
+ self.logger.error(f"KMeans fallback failed: {e}")
1340
+ except Exception:
1341
+ pass
1342
+
1343
+ return labels
1344
+
1345
+ def _create_segments(
1346
+ self, windows: List[Tuple[float, float]], labels: np.ndarray, embeddings: np.ndarray
1347
+ ) -> List[SpeakerSegment]:
1348
+ """Create SpeakerSegment objects from windows and labels"""
1349
+ segments = []
1350
+
1351
+ for (start, end), label, emb in zip(windows, labels, embeddings):
1352
+ segments.append(
1353
+ SpeakerSegment(
1354
+ speaker_id=f"SPEAKER_{label:02d}",
1355
+ start=start,
1356
+ end=end,
1357
+ confidence=1.0,
1358
+ embedding=emb,
1359
+ )
1360
+ )
1361
+
1362
+ # If we used the fallback extractor, update segment embeddings to the deterministic MFCC embeddings
1363
+ if getattr(self, "_fallback_extractor", None) is not None:
1364
+ try:
1365
+ for i, seg in enumerate(segments):
1366
+ # reuse windows to create a deterministic embedding
1367
+ s, e = windows[i]
1368
+ # external code expects embeddings array, but ensure segment.embedding is deterministic
1369
+ if (
1370
+ segments[i].embedding is None
1371
+ or (isinstance(self._embedding_model, str)
1372
+ and self._embedding_model == "FALLBACK")
1373
+ ):
1374
+ # compute on-demand using fallback extractor
1375
+ seg_np = self._extract_waveform_segment(windows[i])
1376
+ segments[i].embedding = self._fallback_extractor(seg_np, getattr(self.config, "sample_rate", 16000))  # assumes a config sample_rate; defaults to 16 kHz
1377
+ except Exception:
1378
+ pass
1379
+
1380
+ return segments
1381
+
1382
+ def _postprocess_segments(self, segments: List[SpeakerSegment]) -> List[SpeakerSegment]:
1383
+ """Post-process segments: merge adjacent, filter short"""
1384
+ if not segments:
1385
+ return []
1386
+
1387
+ # Sort by start time
1388
+ segments = sorted(segments, key=lambda x: x.start)
1389
+
1390
+ # Merge adjacent segments from same speaker
1391
+ merged = [segments[0]]
1392
+
1393
+ for seg in segments[1:]:
1394
+ last = merged[-1]
1395
+ gap = seg.start - last.end
1396
+
1397
+ if seg.speaker_id == last.speaker_id and gap <= self.config.merge_gap_threshold:
1398
+ # Merge: extend last segment
1399
+ last.end = max(last.end, seg.end)
1400
+ last.confidence = (last.confidence + seg.confidence) / 2
1401
+ else:
1402
+ merged.append(seg)
1403
+
1404
+ # Smoothing: fix short isolated segments between identical speakers
1405
+ smoothed = merged
1406
+ if len(smoothed) >= 3:
1407
+ changed = False
1408
+ for i in range(1, len(smoothed) - 1):
1409
+ seg = smoothed[i]
1410
+ prev = smoothed[i - 1]
1411
+ nxt = smoothed[i + 1]
1412
+ threshold = max(1.0, self.config.min_segment_duration)
1413
+ if seg.duration < threshold and prev.speaker_id == nxt.speaker_id:
1414
+ seg.speaker_id = prev.speaker_id
1415
+ changed = True
1416
+
1417
+ if changed:
1418
+ # merge again after smoothing
1419
+ merged2 = [smoothed[0]]
1420
+ for seg in smoothed[1:]:
1421
+ last = merged2[-1]
1422
+ gap = seg.start - last.end
1423
+ if seg.speaker_id == last.speaker_id and gap <= self.config.merge_gap_threshold:
1424
+ last.end = max(last.end, seg.end)
1425
+ last.confidence = (last.confidence + seg.confidence) / 2
1426
+ else:
1427
+ merged2.append(seg)
1428
+ merged = merged2
1429
+
1430
+ # Filter short segments
1431
+ filtered = [seg for seg in merged if seg.duration >= self.config.min_segment_duration]
1432
+
1433
+ return filtered
1434
+
1435
+ def _merge_segments(
1436
+ self, segments: List[SpeakerSegment], max_gap: float = 0.5
1437
+ ) -> List[SpeakerSegment]:
1438
+ """Compatibility helper: merge adjacent segments from same speaker within max_gap"""
1439
+ if not segments:
1440
+ return []
1441
+
1442
+ segments = sorted(segments, key=lambda x: x.start)
1443
+ merged_list = [segments[0]]
1444
+
1445
+ for seg in segments[1:]:
1446
+ last = merged_list[-1]
1447
+ gap = seg.start - last.end
1448
+ if seg.speaker_id == last.speaker_id and gap <= max_gap:
1449
+ # Merge: extend last segment
1450
+ last.end = max(last.end, seg.end)
1451
+ last.confidence = (last.confidence + seg.confidence) / 2
1452
+ else:
1453
+ merged_list.append(seg)
1454
+
1455
+ return merged_list
1456
+
1457
+ def _detect_overlaps(self, segments: List[SpeakerSegment]) -> List[SpeakerSegment]:
1458
+ """Mark segments that overlap with other speakers"""
1459
+ for i, seg1 in enumerate(segments):
1460
+ for j, seg2 in enumerate(segments):
1461
+ if i != j and seg1.speaker_id != seg2.speaker_id:
1462
+ # Check for time overlap
1463
+ overlap_start = max(seg1.start, seg2.start)
1464
+ overlap_end = min(seg1.end, seg2.end)
1465
+
1466
+ if overlap_start < overlap_end:
1467
+ seg1.is_overlap = True
1468
+ seg2.is_overlap = True
1469
+
1470
+ return segments
1471
+
1472
+ def get_speaker_stats(self, segments: List[SpeakerSegment]) -> Dict[str, Dict[str, float]]:
1473
+ """
1474
+ Get statistics for each speaker.
1475
+
1476
+ Returns:
1477
+ Dict mapping speaker_id to stats (total_duration, num_segments, etc.)
1478
+ """
1479
+ stats = {}
1480
+
1481
+ for seg in segments:
1482
+ if seg.speaker_id not in stats:
1483
+ stats[seg.speaker_id] = {
1484
+ "total_duration": 0.0,
1485
+ "num_segments": 0,
1486
+ "avg_segment_duration": 0.0,
1487
+ "overlap_duration": 0.0,
1488
+ }
1489
+
1490
+ stats[seg.speaker_id]["total_duration"] += seg.duration
1491
+ stats[seg.speaker_id]["num_segments"] += 1
1492
+
1493
+ if seg.is_overlap:
1494
+ stats[seg.speaker_id]["overlap_duration"] += seg.duration
1495
+
1496
+ # Calculate averages
1497
+ for speaker_id in stats:
1498
+ num_segs = stats[speaker_id]["num_segments"]
1499
+ if num_segs > 0:
1500
+ stats[speaker_id]["avg_segment_duration"] = (
1501
+ stats[speaker_id]["total_duration"] / num_segs
1502
+ )
1503
+
1504
+ return stats
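
The clustering logic above combines cosine-distance agglomerative clustering with silhouette and centroid-distance heuristics that can collapse an over-fragmented result down to a single speaker. The following standalone sketch is not part of src/diarization.py; it only assumes NumPy and a scikit-learn version that accepts metric="cosine" (as the code above does), and the names base_a/base_b are illustrative. It shows the same core idea on synthetic embeddings:

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)

# Two synthetic "speakers": embeddings scattered around two distinct directions
base_a = np.zeros(16)
base_a[0] = 1.0
base_b = np.zeros(16)
base_b[1] = 1.0
emb = np.vstack([
    base_a + rng.normal(0.0, 0.05, size=(20, 16)),
    base_b + rng.normal(0.0, 0.05, size=(20, 16)),
])

clustering = AgglomerativeClustering(n_clusters=2, metric="cosine", linkage="average")
labels = clustering.fit_predict(emb)

# Collapse to a single speaker when clusters are poorly separated, mirroring
# the silhouette_collapse_threshold heuristic in _cluster_embeddings above
score = silhouette_score(emb, labels, metric="cosine")
if score < 0.05:
    labels = np.zeros(len(emb), dtype=int)

print("clusters:", np.unique(labels).size, "silhouette:", round(float(score), 3))

On this synthetic data the two clusters are well separated, so the silhouette check keeps both; replacing base_b with base_a collapses the result to one speaker, which mirrors the single-speaker heuristics used above.
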
src/document_generator.py ADDED
@@ -0,0 +1,852 @@
1
+ """
2
+ Document Generator Module
3
+ =========================
4
+ Exports meeting minutes to formatted .docx using python-docx.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import re
10
+ import warnings
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Dict, List, Optional
15
+
16
+ try:
17
+ from docx import Document
18
+ from docx.enum.table import WD_TABLE_ALIGNMENT
19
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
20
+ from docx.oxml import OxmlElement
21
+ from docx.oxml.ns import qn
22
+ from docx.shared import Cm, Pt, RGBColor
23
+
24
+ DOCX_AVAILABLE = True
25
+ except Exception:
26
+ # Minimal fallback implementations for environments without python-docx (used in tests)
27
+ DOCX_AVAILABLE = False
28
+
29
+ class Document:
30
+ def __init__(self):
31
+ self._paragraphs = []
32
+ self.sections = []
33
+
34
+ # Minimal styles container to mimic python-docx for tests
35
+ class DummyStyle:
36
+ def __init__(self):
37
+ self.font = type("F", (), {"name": None, "size": None})()
38
+
39
+ class RFonts:
40
+ def set(self, *args, **kwargs):
41
+ pass
42
+
43
+ class RPr:
44
+ def __init__(self):
45
+ self.rFonts = RFonts()
46
+
47
+ class Element:
48
+ def __init__(self):
49
+ self.rPr = RPr()
50
+
51
+ self._element = Element()
52
+
53
+ class Styles:
54
+ def __init__(self):
55
+ self._styles = {"Normal": DummyStyle()}
56
+
57
+ def __getitem__(self, key):
58
+ return self._styles.setdefault(key, DummyStyle())
59
+
60
+ self.styles = Styles()
61
+
62
+ class Run:
63
+ def __init__(self, text=""):
64
+ self.text = str(text)
65
+ self.bold = False
66
+ self.italic = False
67
+ self.font = type("F", (), {"size": None, "color": type("C", (), {"rgb": None})()})
68
+
69
+ class Paragraph:
70
+ def __init__(self, text=""):
71
+ self.runs = []
72
+ self.paragraph_format = type("PF", (), {"space_after": None})()
73
+ self.alignment = None
74
+ if text:
75
+ self.add_run(text)
76
+
77
+ def add_run(self, text=""):
78
+ # Create a lightweight run-like object for fallback
79
+ run = type(
80
+ "Run",
81
+ (),
82
+ {
83
+ "text": str(text),
84
+ "bold": False,
85
+ "italic": False,
86
+ "font": type(
87
+ "F", (), {"size": None, "color": type("C", (), {"rgb": None})()}
88
+ )(),
89
+ },
90
+ )()
91
+ self.runs.append(run)
92
+ return run
93
+
94
+ def add_paragraph(self, text="", **kwargs):
95
+ # Accept style and other kwargs for compatibility
96
+ para = self.Paragraph(text)
97
+ self._paragraphs.append(para)
98
+ return para
99
+
100
+ def add_heading(self, text, level=None, **kwargs):
101
+ para = self.Paragraph(text)
102
+ self._paragraphs.append(para)
103
+ return para
104
+
105
+ def add_table(self, rows, cols):
106
+ outer = self
107
+
108
+ class Cell:
109
+ def __init__(self):
110
+ self.paragraphs = [outer.Paragraph()]
111
+
112
+ # Minimal _tc structure to support shading and other docx operations in fallback
113
+ class TCPr:
114
+ def append(self, *args, **kwargs):
115
+ pass
116
+
117
+ class TC:
118
+ def get_or_add_tcPr(self):
119
+ return TCPr()
120
+
121
+ self._tc = TC()
122
+
123
+ @property
124
+ def text(self):
125
+ if self.paragraphs and self.paragraphs[0].runs:
126
+ return " ".join(run.text for run in self.paragraphs[0].runs)
127
+ return ""
128
+
129
+ @text.setter
130
+ def text(self, value):
131
+ # Create lightweight run-like object
132
+ self.paragraphs[0].runs = [
133
+ type(
134
+ "Run",
135
+ (),
136
+ {
137
+ "text": str(value),
138
+ "bold": False,
139
+ "italic": False,
140
+ "font": type(
141
+ "F", (), {"size": None, "color": type("C", (), {"rgb": None})()}
142
+ )(),
143
+ },
144
+ )()
145
+ ]
146
+
147
+ class Row:
148
+ def __init__(self, cols):
149
+ self.cells = [Cell() for _ in range(cols)]
150
+
151
+ table = type(
152
+ "Table",
153
+ (),
154
+ {"rows": [Row(cols) for _ in range(rows)], "style": None, "alignment": None},
155
+ )()
156
+ return table
157
+
158
+ def save(self, path):
159
+ # Save a plain text fallback document so tests can verify file exists
160
+ lines = []
161
+ for p in self._paragraphs:
162
+ if hasattr(p, "runs"):
163
+ lines.append(" ".join(getattr(r, "text", "") for r in p.runs))
164
+ else:
165
+ lines.append(str(p))
166
+ with open(path, "w", encoding="utf-8") as f:
167
+ f.write("\n".join(lines))
168
+
169
+ class Pt:
170
+ def __init__(self, value):
171
+ self.value = value
172
+
173
+ class Cm:
174
+ def __init__(self, value):
175
+ self.value = value
176
+
177
+ class RGBColor:
178
+ def __init__(self, r, g, b):
179
+ pass
180
+
181
+ class WD_ALIGN_PARAGRAPH:
182
+ CENTER = 1
183
+
184
+ class WD_TABLE_ALIGNMENT:
185
+ LEFT = 1
186
+
187
+ class OxmlElement:
188
+ def __init__(self, *args, **kwargs):
189
+ pass
190
+
191
+ def set(self, *args, **kwargs):
192
+ pass
193
+
194
+ def qn(x):
195
+ return x
196
+
197
+
198
+ from src.summarizer import MeetingSummary
199
+ from src.transcriber import TranscriptSegment
200
+
201
+
202
+ @dataclass
203
+ class MeetingMetadata:
204
+ """Meeting information for document header"""
205
+
206
+ title: str
207
+ date: str
208
+ time: str = ""
209
+ location: str = ""
210
+ duration: str = ""
211
+ participants: Optional[List[str]] = None
212
+ organizer: str = ""
213
+ agenda: str = ""
214
+
215
+ @classmethod
216
+ def create_default(cls, audio_duration_sec: float = 0) -> "MeetingMetadata":
217
+ """Create default metadata"""
218
+ duration_str = ""
219
+ if audio_duration_sec > 0:
220
+ hours = int(audio_duration_sec // 3600)
221
+ minutes = int((audio_duration_sec % 3600) // 60)
222
+ seconds = int(audio_duration_sec % 60)
223
+
224
+ if hours > 0:
225
+ duration_str = f"{hours} jam {minutes} menit {seconds} detik"
226
+ else:
227
+ duration_str = f"{minutes} menit {seconds} detik"
228
+
229
+ return cls(
230
+ title="Notulensi Rapat",
231
+ date=datetime.now().strftime("%d %B %Y"),
232
+ time=datetime.now().strftime("%H:%M"),
233
+ duration=duration_str,
234
+ )
235
+
236
+
237
+ @dataclass
238
+ class DocumentConfig:
239
+ """Configuration for document generation"""
240
+
241
+ # Font settings
242
+ title_font_size: int = 18
243
+ heading1_font_size: int = 14
244
+ heading2_font_size: int = 12
245
+ body_font_size: int = 11
246
+ font_family: str = "Calibri"
247
+
248
+ # Layout
249
+ page_width: float = 21.0 # cm (A4)
250
+ page_height: float = 29.7 # cm (A4)
251
+ margin_top: float = 2.5
252
+ margin_bottom: float = 2.5
253
+ margin_left: float = 3.0
254
+ margin_right: float = 2.5
255
+
256
+ # Content options
257
+ include_timestamps: bool = True
258
+ include_speaker_colors: bool = True
259
+ include_table_of_contents: bool = False
260
+ include_page_numbers: bool = True
261
+
262
+ # Sections to include
263
+ sections: Dict[str, bool] = field(
264
+ default_factory=lambda: {
265
+ "header": True,
266
+ "meeting_info": True,
267
+ "summary": True,
268
+ "decisions": True,
269
+ "action_items": True,
270
+ "transcript": True,
271
+ "footer": True,
272
+ }
273
+ )
274
+
275
+
276
+ class DocumentGenerator:
277
+ """
278
+ Generates formatted .docx meeting minutes.
279
+
280
+ Structure:
281
+ - Title
282
+ - Meeting Information
283
+ - Executive Summary
284
+ - Key Points
285
+ - Decisions
286
+ - Action Items
287
+ - Full Transcript
288
+ - Footer
289
+
290
+ Attributes:
291
+ config: DocumentConfig object
292
+ output_dir: Output directory path
293
+
294
+ Example:
295
+ >>> generator = DocumentGenerator()
296
+ >>> doc_path = generator.generate(metadata, summary, transcript)
297
+ >>> print(f"Document saved: {doc_path}")
298
+ """
299
+
300
+ # Speaker colors for visual distinction
301
+ SPEAKER_COLORS = [
302
+ RGBColor(0, 102, 204), # Blue
303
+ RGBColor(204, 51, 0), # Red
304
+ RGBColor(0, 153, 51), # Green
305
+ RGBColor(153, 51, 153), # Purple
306
+ RGBColor(204, 102, 0), # Orange
307
+ RGBColor(0, 153, 153), # Teal
308
+ RGBColor(102, 102, 0), # Olive
309
+ RGBColor(153, 0, 76), # Maroon
310
+ ]
311
+
312
+ def __init__(self, config: Optional[DocumentConfig] = None, output_dir: str = "./data/output"):
313
+ """
314
+ Initialize DocumentGenerator.
315
+
316
+ Args:
317
+ config: DocumentConfig object
318
+ output_dir: Directory for output files
319
+ """
320
+ self.config = config or DocumentConfig()
321
+ self.output_dir = Path(output_dir)
322
+ self.output_dir.mkdir(parents=True, exist_ok=True)
323
+
324
+ self._speaker_color_map: Dict[str, RGBColor] = {}
325
+
326
+ def generate(
327
+ self,
328
+ metadata: MeetingMetadata,
329
+ summary: MeetingSummary,
330
+ transcript: List[TranscriptSegment],
331
+ output_filename: Optional[str] = None,
332
+ ) -> str:
333
+ """
334
+ Generate complete meeting minutes document.
335
+
336
+ Args:
337
+ metadata: Meeting information
338
+ summary: Generated summary
339
+ transcript: Transcribed segments with speakers
340
+ output_filename: Output file name (auto-generated if None)
341
+
342
+ Returns:
343
+ Path to generated document
344
+ """
345
+ # Create document
346
+ doc = Document()
347
+
348
+ # Setup document
349
+ self._setup_document(doc)
350
+ self._setup_styles(doc)
351
+
352
+ # Build speaker color map
353
+ self._build_speaker_color_map(transcript)
354
+
355
+ # Add sections
356
+ if self.config.sections.get("header", True):
357
+ self._add_title(doc, metadata)
358
+
359
+ if self.config.sections.get("meeting_info", True):
360
+ self._add_meeting_info(doc, metadata)
361
+
362
+ if self.config.sections.get("summary", True):
363
+ self._add_summary_section(doc, summary)
364
+
365
+ if self.config.sections.get("decisions", True):
366
+ self._add_decisions_section(doc, summary.decisions)
367
+
368
+ if self.config.sections.get("action_items", True):
369
+ self._add_action_items_section(doc, summary.action_items)
370
+
371
+ if self.config.sections.get("transcript", True):
372
+ self._add_transcript_section(doc, transcript)
373
+
374
+ if self.config.sections.get("footer", True):
375
+ self._add_footer(doc)
376
+
377
+ # Generate filename if not provided
378
+ if output_filename is None:
379
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
380
+ safe_title = self._sanitize_filename(metadata.title)[:30]
381
+ ext = ".docx"  # the minimal fallback writer below also emits a .docx package
382
+ output_filename = f"notulensi_{safe_title}_{timestamp}{ext}"
383
+
384
+ # Ensure .docx extension
385
+ if not output_filename.endswith(".docx"):
386
+ output_filename = Path(output_filename).with_suffix(".docx").name
387
+
388
+ output_path = self.output_dir / output_filename
389
+
390
+ # Save document
391
+ if DOCX_AVAILABLE:
392
+ doc.save(str(output_path))
393
+ else:
394
+ # If python-docx is not available, build a minimal valid .docx package so Word can open it.
395
+ warnings.warn(
396
+ "python-docx is not available in the current environment; generating a minimal .docx package instead."
397
+ )
398
+ paragraphs = self._extract_paragraph_texts(doc)
399
+ self._save_minimal_docx(str(output_path), paragraphs)
400
+
401
+ return str(output_path)
402
+
403
+ def _setup_document(self, doc: Document):
404
+ """Configure document settings"""
405
+ # Set page margins
406
+ sections = doc.sections
407
+ for section in sections:
408
+ section.top_margin = Cm(self.config.margin_top)
409
+ section.bottom_margin = Cm(self.config.margin_bottom)
410
+ section.left_margin = Cm(self.config.margin_left)
411
+ section.right_margin = Cm(self.config.margin_right)
412
+
413
+ def _setup_styles(self, doc: Document):
414
+ """Configure document styles"""
415
+ # Normal style
416
+ style = doc.styles["Normal"]
417
+ style.font.name = self.config.font_family
418
+ style.font.size = Pt(self.config.body_font_size)
419
+
420
+ # Set font for East Asian text
421
+ style._element.rPr.rFonts.set(qn("w:eastAsia"), self.config.font_family)
422
+
423
+ def _build_speaker_color_map(self, transcript: List[TranscriptSegment]):
424
+ """Build consistent color mapping for speakers"""
425
+ speakers = sorted(set(seg.speaker_id for seg in transcript))
426
+
427
+ for i, speaker in enumerate(speakers):
428
+ self._speaker_color_map[speaker] = self.SPEAKER_COLORS[i % len(self.SPEAKER_COLORS)]
429
+
430
+ def _add_title(self, doc: Document, metadata: MeetingMetadata):
431
+ """Add document title"""
432
+ # Main title
433
+ title_para = doc.add_paragraph()
434
+ title_run = title_para.add_run("NOTULENSI RAPAT")
435
+ title_run.bold = True
436
+ title_run.font.size = Pt(self.config.title_font_size)
437
+ title_run.font.color.rgb = RGBColor(0, 51, 102)
438
+ title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
439
+
440
+ # Subtitle with meeting title
441
+ if metadata.title and metadata.title != "Notulensi Rapat":
442
+ subtitle_para = doc.add_paragraph()
443
+ subtitle_run = subtitle_para.add_run(metadata.title)
444
+ subtitle_run.bold = True
445
+ subtitle_run.font.size = Pt(self.config.heading1_font_size)
446
+ subtitle_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
447
+
448
+ # Generated by note
449
+ note_para = doc.add_paragraph()
450
+ note_run = note_para.add_run("Generated by AI Meeting Transcriber (SpeechBrain + BERT)")
451
+ note_run.italic = True
452
+ note_run.font.size = Pt(9)
453
+ note_run.font.color.rgb = RGBColor(128, 128, 128)
454
+ note_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
455
+
456
+ # Spacer
457
+ doc.add_paragraph()
458
+
459
+ def _add_meeting_info(self, doc: Document, metadata: MeetingMetadata):
460
+ """Add meeting information section"""
461
+ # Section heading
462
+ heading = doc.add_heading("Informasi Rapat", level=1)
463
+ heading.runs[0].font.size = Pt(self.config.heading1_font_size)
464
+
465
+ # Create info table
466
+ info_items = [
467
+ ("Tanggal", metadata.date),
468
+ ("Waktu", metadata.time or "-"),
469
+ ("Lokasi/Platform", metadata.location or "-"),
470
+ ("Durasi", metadata.duration or "-"),
471
+ ("Penyelenggara", metadata.organizer or "-"),
472
+ ]
473
+
474
+ # Filter out empty items
475
+ info_items = [(label, value) for label, value in info_items if value and value != "-"]
476
+
477
+ if info_items:
478
+ table = doc.add_table(rows=len(info_items), cols=2)
479
+ table.style = "Table Grid"
480
+ table.alignment = WD_TABLE_ALIGNMENT.LEFT
481
+
482
+ for i, (label, value) in enumerate(info_items):
483
+ row = table.rows[i]
484
+
485
+ # Label cell
486
+ cell_label = row.cells[0]
487
+ cell_label.text = label
488
+ cell_label.paragraphs[0].runs[0].bold = True
489
+ cell_label.width = Cm(4)
490
+
491
+ # Value cell
492
+ cell_value = row.cells[1]
493
+ cell_value.text = value
494
+
495
+ # Add participants if available
496
+ if metadata.participants:
497
+ doc.add_paragraph()
498
+ para = doc.add_paragraph()
499
+ para.add_run("Peserta Rapat: ").bold = True
500
+ para.add_run(", ".join(metadata.participants))
501
+
502
+ # Add agenda if available
503
+ if metadata.agenda:
504
+ doc.add_paragraph()
505
+ para = doc.add_paragraph()
506
+ para.add_run("Agenda: ").bold = True
507
+ para.add_run(metadata.agenda)
508
+
509
+ # Spacer
510
+ doc.add_paragraph()
511
+
512
+ def _add_summary_section(self, doc: Document, summary: MeetingSummary):
513
+ """Add executive summary section"""
514
+ # Section heading
515
+ heading = doc.add_heading("Ringkasan Eksekutif", level=1)
516
+ heading.runs[0].font.size = Pt(self.config.heading1_font_size)
517
+
518
+ # Overview
519
+ if summary.overview and not self._is_placeholder_text(summary.overview):
520
+ overview_para = doc.add_paragraph()
521
+ overview_para.add_run(summary.overview)
522
+ overview_para.paragraph_format.space_after = Pt(12)
523
+ else:
524
+ overview_para = doc.add_paragraph()
525
+ overview_para.add_run(
526
+ "Ringkasan tidak tersedia. (Model ringkasan tidak dimuat atau data tidak mencukupi.)"
527
+ )
528
+ overview_para.runs[0].italic = True
529
+ overview_para.runs[0].font.color.rgb = RGBColor(128, 128, 128)
530
+
531
+ # Key points (filter placeholders)
532
+ filtered_points = [
533
+ p for p in (summary.key_points or []) if not self._is_placeholder_text(p)
534
+ ]
535
+ if filtered_points:
536
+ subheading = doc.add_heading("Poin-Poin Penting", level=2)
537
+ subheading.runs[0].font.size = Pt(self.config.heading2_font_size)
538
+
539
+ for point in filtered_points:
540
+ para = doc.add_paragraph(point, style="List Bullet")
541
+ else:
542
+ para = doc.add_paragraph()
543
+ para.add_run("Tidak ada poin penting yang dihasilkan secara otomatis.")
544
+ para.runs[0].italic = True
545
+ para.runs[0].font.color.rgb = RGBColor(128, 128, 128)
546
+
547
+ # Topics discussed (filter placeholders)
548
+ topics_filtered = [t for t in (summary.topics or []) if not self._is_placeholder_text(t)]
549
+ if topics_filtered:
550
+ doc.add_paragraph()
551
+ para = doc.add_paragraph()
552
+ para.add_run("Topik yang dibahas: ").bold = True
553
+ para.add_run(", ".join(topics_filtered))
554
+
555
+ # Spacer
556
+ doc.add_paragraph()
557
+
558
+ def _add_decisions_section(self, doc: Document, decisions: List[str]):
559
+ """Add decisions section"""
560
+ # Section heading
561
+ heading = doc.add_heading("Keputusan Rapat", level=1)
562
+ heading.runs[0].font.size = Pt(self.config.heading1_font_size)
563
+
564
+ if decisions:
565
+ for i, decision in enumerate(decisions, 1):
566
+ para = doc.add_paragraph()
567
+ para.add_run(f"{i}. ").bold = True
568
+ para.add_run(decision)
569
+ else:
570
+ para = doc.add_paragraph()
571
+ para.add_run("Tidak ada keputusan yang teridentifikasi secara otomatis.")
572
+ para.runs[0].italic = True
573
+ para.runs[0].font.color.rgb = RGBColor(128, 128, 128)
574
+
575
+ # Spacer
576
+ doc.add_paragraph()
577
+
578
+ def _add_action_items_section(self, doc: Document, action_items: List[Dict[str, str]]):
579
+ """Add action items section"""
580
+ # Section heading
581
+ heading = doc.add_heading("Action Items / Tindak Lanjut", level=1)
582
+ heading.runs[0].font.size = Pt(self.config.heading1_font_size)
583
+
584
+ if action_items:
585
+ # Create table
586
+ table = doc.add_table(rows=len(action_items) + 1, cols=4)
587
+ table.style = "Table Grid"
588
+ table.alignment = WD_TABLE_ALIGNMENT.LEFT
589
+
590
+ # Header row
591
+ headers = ["No.", "Penanggung Jawab", "Tugas", "Deadline"]
592
+ header_row = table.rows[0]
593
+
594
+ for i, header_text in enumerate(headers):
595
+ cell = header_row.cells[i]
596
+ cell.text = header_text
597
+
598
+ # Style header
599
+ for paragraph in cell.paragraphs:
600
+ for run in paragraph.runs:
601
+ run.bold = True
602
+
603
+ # Set header background color
604
+ shading = OxmlElement("w:shd")
605
+ shading.set(qn("w:fill"), "D9E2F3")
606
+ cell._tc.get_or_add_tcPr().append(shading)
607
+
608
+ # Data rows
609
+ for i, item in enumerate(action_items, 1):
610
+ row = table.rows[i]
611
+
612
+ row.cells[0].text = str(i)
613
+ row.cells[1].text = item.get("owner", "-")
614
+ row.cells[2].text = item.get("task", "-")
615
+ row.cells[3].text = item.get("due", "-")
616
+
617
+ # Set column widths
618
+ for row in table.rows:
619
+ row.cells[0].width = Cm(1.0)
620
+ row.cells[1].width = Cm(3.5)
621
+ row.cells[2].width = Cm(9.0)
622
+ row.cells[3].width = Cm(2.5)
623
+ else:
624
+ para = doc.add_paragraph()
625
+ para.add_run("Tidak ada action item yang teridentifikasi secara otomatis.")
626
+ para.runs[0].italic = True
627
+ para.runs[0].font.color.rgb = RGBColor(128, 128, 128)
628
+
629
+ # Spacer
630
+ doc.add_paragraph()
631
+
632
+ def _add_transcript_section(self, doc: Document, transcript: List[TranscriptSegment]):
633
+ """Add full transcript section"""
634
+ # Section heading
635
+ heading = doc.add_heading("Transkrip Percakapan", level=1)
636
+ heading.runs[0].font.size = Pt(self.config.heading1_font_size)
637
+
638
+ if not transcript:
639
+ para = doc.add_paragraph()
640
+ para.add_run("Tidak ada transkrip yang tersedia.")
641
+ para.runs[0].italic = True
642
+ return
643
+
644
+ # Add each segment
645
+ for seg in transcript:
646
+ para = doc.add_paragraph()
647
+
648
+ # Timestamp
649
+ if self.config.include_timestamps:
650
+ timestamp = self._format_timestamp(seg.start, seg.end)
651
+
652
+ # Speaker label with color
653
+ speaker_run = para.add_run(f"{seg.speaker_id} [{timestamp}]: ")
654
+ speaker_run.bold = True
655
+
656
+ if self.config.include_speaker_colors:
657
+ color = self._speaker_color_map.get(seg.speaker_id, RGBColor(0, 0, 0))
658
+ speaker_run.font.color.rgb = color
659
+ else:
660
+ speaker_run = para.add_run(f"{seg.speaker_id}: ")
661
+ speaker_run.bold = True
662
+
663
+ # Transcript text (sanitize placeholder/fallback strings)
664
+ text = seg.text or ""
665
+ cleaned = self._clean_text_for_doc(text)
666
+ para.add_run(cleaned)
667
+
668
+ # Mark overlapping speech
669
+ if seg.is_overlap:
670
+ overlap_run = para.add_run(" [OVERLAP]")
671
+ overlap_run.italic = True
672
+ overlap_run.font.color.rgb = RGBColor(255, 102, 0)
673
+ overlap_run.font.size = Pt(9)
674
+
675
+ def _add_footer(self, doc: Document):
676
+ """Add document footer"""
677
+ # Separator line
678
+ doc.add_paragraph()
679
+ separator = doc.add_paragraph("─" * 70)
680
+ separator.alignment = WD_ALIGN_PARAGRAPH.CENTER
681
+
682
+ # Footer text
683
+ footer_para = doc.add_paragraph()
684
+
685
+ timestamp = datetime.now().strftime("%d %B %Y, %H:%M:%S")
686
+ footer_text = f"Dokumen ini dihasilkan secara otomatis pada {timestamp}"
687
+
688
+ footer_run = footer_para.add_run(footer_text)
689
+ footer_run.italic = True
690
+ footer_run.font.size = Pt(9)
691
+ footer_run.font.color.rgb = RGBColor(128, 128, 128)
692
+ footer_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
693
+
694
+ # Disclaimer
695
+ disclaimer_para = doc.add_paragraph()
696
+ disclaimer_text = (
697
+ "Hasil transkripsi dan ringkasan mungkin mengandung ketidakakuratan. "
698
+ "Harap verifikasi informasi penting."
699
+ )
700
+
701
+ disclaimer_run = disclaimer_para.add_run(disclaimer_text)
702
+ disclaimer_run.italic = True
703
+ disclaimer_run.font.size = Pt(8)
704
+ disclaimer_run.font.color.rgb = RGBColor(150, 150, 150)
705
+ disclaimer_para.alignment = WD_ALIGN_PARAGRAPH.CENTER
706
+
707
+ def _is_placeholder_text(self, text: Optional[str]) -> bool:
708
+ """Detect summarizer/ASR fallback placeholder text."""
709
+ if not text:
710
+ return True
711
+ t = str(text).strip()
712
+ # common placeholder patterns from summarizer / transcriber fallbacks
713
+ if re.search(r"\[\s*Transkripsi placeholder", t, re.I):
714
+ return True
715
+ if re.search(r"placeholder", t, re.I) and len(t) < 120:
716
+ return True
717
+ return False
718
+
719
+ def _clean_text_for_doc(self, text: Optional[str]) -> str:
720
+ """Clean text for document: replace raw placeholders with user-friendly notices."""
721
+ if not text or self._is_placeholder_text(text):
722
+ return "[transkripsi tidak tersedia]"
723
+ # Remove any bracketed placeholder fragments embedded in text
724
+ cleaned = re.sub(r"\[\s*Transkripsi placeholder[^\]]*\]", "", str(text), flags=re.I).strip()
725
+ return cleaned or "[transkripsi tidak tersedia]"
726
+
727
+ @staticmethod
728
+ def _format_timestamp(start: float, end: float) -> str:
729
+ """Format time range as HH:MM:SS"""
730
+
731
+ def sec_to_str(sec: float) -> str:
732
+ sec = max(0.0, float(sec))
733
+ h = int(sec // 3600)
734
+ m = int((sec % 3600) // 60)
735
+ s = int(sec % 60)
736
+
737
+ if h > 0:
738
+ return f"{h:02d}:{m:02d}:{s:02d}"
739
+ return f"{m:02d}:{s:02d}"
740
+
741
+ return f"{sec_to_str(start)}–{sec_to_str(end)}"
742
+
743
+ def _save_minimal_docx(self, path: str, paragraphs: List[str]):
744
+ """Create a minimal valid .docx (zip package) containing plain paragraphs.
745
+ This is a lightweight fallback when python-docx is not installed, to ensure
746
+ the generated file can be opened in Word.
747
+ """
748
+ import zipfile
749
+
750
+ def _escape_xml(s: str) -> str:
751
+ return (
752
+ s.replace("&", "&amp;")
753
+ .replace("<", "&lt;")
754
+ .replace(">", "&gt;")
755
+ .replace('"', "&quot;")
756
+ .replace("'", "&apos;")
757
+ )
758
+
759
+ content_types = (
760
+ '<?xml version="1.0" encoding="UTF-8"?>\n'
761
+ '<Types xmlns="http://schemas.openxmlformats.org/package/2006/content-types">\n'
762
+ ' <Default Extension="rels" ContentType="application/vnd.openxmlformats-package.relationships+xml"/>\n'
763
+ ' <Default Extension="xml" ContentType="application/xml"/>\n'
764
+ ' <Override PartName="/word/document.xml" ContentType="application/vnd.openxmlformats-officedocument.wordprocessingml.document.main+xml"/>\n'
765
+ "</Types>"
766
+ )
767
+
768
+ rels = (
769
+ '<?xml version="1.0" encoding="UTF-8"?>\n'
770
+ '<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">\n'
771
+ ' <Relationship Id="rId1" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/officeDocument" Target="word/document.xml"/>\n'
772
+ "</Relationships>"
773
+ )
774
+
775
+ doc_xml_header = (
776
+ '<?xml version="1.0" encoding="UTF-8" standalone="yes"?>\n'
777
+ '<w:document xmlns:wpc="http://schemas.microsoft.com/office/word/2010/wordprocessingCanvas" '
778
+ 'xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" '
779
+ 'xmlns:o="urn:schemas-microsoft-com:office:office" '
780
+ 'xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships" '
781
+ 'xmlns:m="http://schemas.openxmlformats.org/officeDocument/2006/math" '
782
+ 'xmlns:v="urn:schemas-microsoft-com:vml" '
783
+ 'xmlns:wp14="http://schemas.microsoft.com/office/word/2010/wordprocessingDrawing" '
784
+ 'xmlns:wp="http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing" '
785
+ 'xmlns:w10="urn:schemas-microsoft-com:office:word" '
786
+ 'xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main" '
787
+ 'xmlns:w14="http://schemas.microsoft.com/office/word/2010/wordml" '
788
+ 'xmlns:wpg="http://schemas.microsoft.com/office/word/2010/wordprocessingGroup" '
789
+ 'xmlns:wpi="http://schemas.microsoft.com/office/word/2010/wordprocessingInk" '
790
+ 'xmlns:wne="http://schemas.microsoft.com/office/word/2006/wordml" '
791
+ 'xmlns:wps="http://schemas.microsoft.com/office/word/2010/wordprocessingShape">\n'
792
+ " <w:body>\n"
793
+ )
794
+
795
+ doc_xml_footer = (
796
+ " <w:sectPr>\n"
797
+ ' <w:pgSz w:w="11900" w:h="16840"/>\n'
798
+ ' <w:pgMar w:top="1440" w:right="1440" w:bottom="1440" w:left="1440" w:header="720" w:footer="720" w:gutter="0"/>\n'
799
+ " </w:sectPr>\n"
800
+ " </w:body>\n"
801
+ "</w:document>"
802
+ )
803
+
804
+ # Build paragraphs as simple <w:p><w:r><w:t>text</w:t></w:r></w:p>
805
+ paras_xml = []
806
+ for p in paragraphs:
807
+ t = _escape_xml(p.strip())
808
+ if not t:
809
+ # preserve blank line
810
+ paras_xml.append(" <w:p/>\n")
811
+ else:
812
+ paras_xml.append(f' <w:p><w:r><w:t xml:space="preserve">{t}</w:t></w:r></w:p>\n')
813
+
814
+ doc_xml = doc_xml_header + "".join(paras_xml) + doc_xml_footer
815
+
816
+ with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as z:
817
+ z.writestr("[Content_Types].xml", content_types)
818
+ z.writestr("_rels/.rels", rels)
819
+ z.writestr("word/document.xml", doc_xml)
820
+
821
+ def _extract_paragraph_texts(self, doc: Document) -> List[str]:
822
+ """Return paragraph texts from a python-docx Document or the fallback Document"""
823
+ paras: List[str] = []
824
+ # python-docx Document
825
+ try:
826
+ # use the paragraphs attribute when present
827
+ if hasattr(doc, "paragraphs"):
828
+ for p in doc.paragraphs:
829
+ paras.append(p.text)
830
+ return paras
831
+ except Exception:
832
+ pass
833
+
834
+ # fallback minimal Document implementation
835
+ if hasattr(doc, "_paragraphs"):
836
+ for p in doc._paragraphs:
837
+ if hasattr(p, "runs"):
838
+ paras.append(" ".join(getattr(r, "text", "") for r in p.runs))
839
+ else:
840
+ paras.append(str(p))
841
+ return paras
842
+
843
+ @staticmethod
844
+ def _sanitize_filename(filename: str) -> str:
845
+ """Remove invalid characters from filename"""
846
+ import re
847
+
848
+ # Remove invalid characters
849
+ sanitized = re.sub(r'[<>:"/\\|?*]', "", filename)
850
+ # Replace spaces with underscores
851
+ sanitized = sanitized.replace(" ", "_")
852
+ return sanitized
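
When python-docx is missing, generate() falls back to _save_minimal_docx, which zips three OOXML parts ([Content_Types].xml, _rels/.rels, word/document.xml) so the output still opens in Word. A minimal sketch for inspecting such a package follows; the output path is hypothetical and only the standard-library zipfile module is assumed:

import zipfile

# Hypothetical path to a file produced by DocumentGenerator.generate() in fallback mode
path = "data/output/notulensi_demo.docx"

with zipfile.ZipFile(path) as z:
    names = set(z.namelist())
    # The three parts written by _save_minimal_docx
    assert "[Content_Types].xml" in names
    assert "_rels/.rels" in names
    assert "word/document.xml" in names

    xml = z.read("word/document.xml").decode("utf-8")
    # Count non-empty and empty paragraphs emitted by the fallback writer
    print("paragraphs:", xml.count("<w:p>") + xml.count("<w:p/>"))
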
src/evaluator.py ADDED
@@ -0,0 +1,797 @@
1
+ """
2
+ Evaluation Module
3
+ =================
4
+ Implements WER, DER, and other metrics for thesis validation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import csv
10
+ import re
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+
18
+ try:
19
+ from jiwer import cer, mer, process_words, wer, wil
20
+
21
+ JIWER_AVAILABLE = True
22
+ except ImportError:
23
+ JIWER_AVAILABLE = False
24
+ print("[Evaluator] Warning: jiwer not installed. WER calculation will use fallback.")
25
+
26
+
27
+ @dataclass
28
+ class WERResult:
29
+ """Word Error Rate evaluation result"""
30
+
31
+ wer: float
32
+ mer: float = 0.0 # Match Error Rate
33
+ wil: float = 0.0 # Word Information Lost
34
+ cer: float = 0.0 # Character Error Rate
35
+ substitutions: int = 0
36
+ deletions: int = 0
37
+ insertions: int = 0
38
+ hits: int = 0
39
+ reference_length: int = 0
40
+ hypothesis_length: int = 0
41
+
42
+ def to_dict(self) -> Dict[str, Any]:
43
+ """Convert to dictionary"""
44
+ return {
45
+ "wer": self.wer,
46
+ "mer": self.mer,
47
+ "wil": self.wil,
48
+ "cer": self.cer,
49
+ "substitutions": self.substitutions,
50
+ "deletions": self.deletions,
51
+ "insertions": self.insertions,
52
+ "hits": self.hits,
53
+ "reference_length": self.reference_length,
54
+ "hypothesis_length": self.hypothesis_length,
55
+ }
56
+
57
+
58
+ @dataclass
59
+ class DERResult:
60
+ """Diarization Error Rate evaluation result"""
61
+
62
+ der: float
63
+ missed_speech: float = 0.0
64
+ false_alarm: float = 0.0
65
+ speaker_confusion: float = 0.0
66
+ total_duration: float = 0.0
67
+ num_speakers_ref: int = 0
68
+ num_speakers_hyp: int = 0
69
+
70
+ def to_dict(self) -> Dict[str, Any]:
71
+ """Convert to dictionary"""
72
+ return {
73
+ "der": self.der,
74
+ "missed_speech": self.missed_speech,
75
+ "false_alarm": self.false_alarm,
76
+ "speaker_confusion": self.speaker_confusion,
77
+ "total_duration": self.total_duration,
78
+ "num_speakers_ref": self.num_speakers_ref,
79
+ "num_speakers_hyp": self.num_speakers_hyp,
80
+ }
81
+
82
+
83
+ @dataclass
84
+ class SummaryResult:
85
+ """Summary evaluation result (ROUGE/BERTScore)"""
86
+
87
+ rouge: Dict[str, float]
88
+ bertscore: Dict[str, float]
89
+
90
+
91
+ @dataclass
92
+ class EvaluationResult:
93
+ """Combined evaluation result"""
94
+
95
+ sample_name: str
96
+ condition: str
97
+ wer_result: Optional[WERResult] = None
98
+ der_result: Optional[DERResult] = None
99
+ summary_result: Optional[SummaryResult] = None
100
+ metadata: Dict[str, Any] = field(default_factory=dict)
101
+
102
+
103
+ class Evaluator:
104
+ """
105
+ Evaluation metrics calculator for ASR and Diarization.
106
+
107
+ Provides:
108
+ - WER (Word Error Rate) for ASR evaluation
109
+ - DER (Diarization Error Rate) for speaker diarization evaluation
110
+ - Report generation for thesis documentation
111
+
112
+ Example:
113
+ >>> evaluator = Evaluator()
114
+ >>> wer_result = evaluator.calculate_wer(reference, hypothesis)
115
+ >>> print(f"WER: {wer_result.wer:.2%}")
116
+ """
117
+
118
+ def __init__(self, output_dir: str = "./data/output"):
119
+ """
120
+ Initialize Evaluator.
121
+
122
+ Args:
123
+ output_dir: Directory for evaluation outputs
124
+ """
125
+ self.output_dir = Path(output_dir)
126
+ self.output_dir.mkdir(parents=True, exist_ok=True)
127
+
128
+ # =========================================================================
129
+ # Text Preprocessing
130
+ # =========================================================================
131
+
132
+ @staticmethod
133
+ def preprocess_text(
134
+ text: str,
135
+ lowercase: bool = True,
136
+ remove_punctuation: bool = True,
137
+ normalize_whitespace: bool = True,
138
+ remove_filler_words: bool = False,
139
+ ) -> str:
140
+ """
141
+ Preprocess text for fair WER comparison.
142
+
143
+ Args:
144
+ text: Input text
145
+ lowercase: Convert to lowercase
146
+ remove_punctuation: Remove punctuation marks
147
+ normalize_whitespace: Normalize whitespace
148
+ remove_filler_words: Remove filler words (eh, um, etc.)
149
+
150
+ Returns:
151
+ Preprocessed text
152
+ """
153
+ if not text:
154
+ return ""
155
+
156
+ # Lowercase
157
+ if lowercase:
158
+ text = text.lower()
159
+
160
+ # Remove punctuation
161
+ if remove_punctuation:
162
+ text = re.sub(r"[^\w\s]", " ", text)
163
+
164
+ # Remove filler words (common in Indonesian)
165
+ if remove_filler_words:
166
+ filler_words = ["eh", "em", "um", "uh", "ah", "hmm", "eee", "anu"]
167
+ pattern = r"\b(" + "|".join(filler_words) + r")\b"
168
+ text = re.sub(pattern, "", text, flags=re.IGNORECASE)
169
+
170
+ # Normalize whitespace
171
+ if normalize_whitespace:
172
+ text = " ".join(text.split())
173
+
174
+ return text.strip()
175
+
176
+ # =========================================================================
177
+ # WER Calculation
178
+ # =========================================================================
179
+
180
+ def calculate_wer(self, reference: str, hypothesis: str, preprocess: bool = True) -> WERResult:
181
+ """
182
+ Calculate Word Error Rate and related metrics.
183
+
184
+ WER = (S + D + I) / N
185
+ where:
186
+ S = Substitutions
187
+ D = Deletions
188
+ I = Insertions
189
+ N = Total words in reference
190
+
191
+ Args:
192
+ reference: Ground truth text
193
+ hypothesis: ASR output text
194
+ preprocess: Apply text preprocessing
195
+
196
+ Returns:
197
+ WERResult with detailed metrics
198
+ """
199
+ # Preprocess
200
+ if preprocess:
201
+ reference = self.preprocess_text(reference)
202
+ hypothesis = self.preprocess_text(hypothesis)
203
+
204
+ # Handle empty cases
205
+ if not reference:
206
+ return WERResult(
207
+ wer=1.0 if hypothesis else 0.0,
208
+ reference_length=0,
209
+ hypothesis_length=len(hypothesis.split()) if hypothesis else 0,
210
+ )
211
+
212
+ if not hypothesis:
213
+ return WERResult(
214
+ wer=1.0,
215
+ deletions=len(reference.split()),
216
+ reference_length=len(reference.split()),
217
+ hypothesis_length=0,
218
+ )
219
+
220
+ # Use jiwer if available
221
+ if JIWER_AVAILABLE:
222
+ try:
223
+ wer_score = wer(reference, hypothesis)
224
+ mer_score = mer(reference, hypothesis)
225
+ wil_score = wil(reference, hypothesis)
226
+ cer_score = cer(reference, hypothesis)
227
+
228
+ # Get detailed breakdown
229
+ output = process_words(reference, hypothesis)
230
+
231
+ return WERResult(
232
+ wer=wer_score,
233
+ mer=mer_score,
234
+ wil=wil_score,
235
+ cer=cer_score,
236
+ substitutions=output.substitutions,
237
+ deletions=output.deletions,
238
+ insertions=output.insertions,
239
+ hits=output.hits,
240
+ reference_length=len(reference.split()),
241
+ hypothesis_length=len(hypothesis.split()),
242
+ )
243
+ except Exception as e:
244
+ print(f"[Evaluator] jiwer calculation failed: {e}")
245
+
246
+ # Fallback: manual calculation using edit distance
247
+ return self._calculate_wer_manual(reference, hypothesis)
248
+
249
+ def _calculate_wer_manual(self, reference: str, hypothesis: str) -> WERResult:
250
+ """Calculate WER using manual edit distance (fallback)"""
251
+ ref_words = reference.split()
252
+ hyp_words = hypothesis.split()
253
+
254
+ # Dynamic programming for edit distance
255
+ m, n = len(ref_words), len(hyp_words)
256
+ dp = [[0] * (n + 1) for _ in range(m + 1)]
257
+
258
+ # Initialize
259
+ for i in range(m + 1):
260
+ dp[i][0] = i
261
+ for j in range(n + 1):
262
+ dp[0][j] = j
263
+
264
+ # Fill DP table
265
+ for i in range(1, m + 1):
266
+ for j in range(1, n + 1):
267
+ if ref_words[i - 1] == hyp_words[j - 1]:
268
+ dp[i][j] = dp[i - 1][j - 1]
269
+ else:
270
+ dp[i][j] = min(
271
+ dp[i - 1][j] + 1, # Deletion
272
+ dp[i][j - 1] + 1, # Insertion
273
+ dp[i - 1][j - 1] + 1, # Substitution
274
+ )
275
+
276
+ # Backtrack to count operations
277
+ i, j = m, n
278
+ substitutions = deletions = insertions = hits = 0
279
+
280
+ while i > 0 or j > 0:
281
+ if i > 0 and j > 0 and ref_words[i - 1] == hyp_words[j - 1]:
282
+ hits += 1
283
+ i -= 1
284
+ j -= 1
285
+ elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
286
+ substitutions += 1
287
+ i -= 1
288
+ j -= 1
289
+ elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
290
+ deletions += 1
291
+ i -= 1
292
+ else:
293
+ insertions += 1
294
+ j -= 1
295
+
296
+ total_errors = substitutions + deletions + insertions
297
+ wer_score = total_errors / len(ref_words) if ref_words else 0.0
298
+
299
+ return WERResult(
300
+ wer=wer_score,
301
+ substitutions=substitutions,
302
+ deletions=deletions,
303
+ insertions=insertions,
304
+ hits=hits,
305
+ reference_length=len(ref_words),
306
+ hypothesis_length=len(hyp_words),
307
+ )
308
+
309
+ def calculate_wer_batch(
310
+ self, references: List[str], hypotheses: List[str], preprocess: bool = True
311
+ ) -> Tuple[float, List[WERResult]]:
312
+ """
313
+ Calculate WER for multiple pairs and return aggregate.
314
+
315
+ Args:
316
+ references: List of reference texts
317
+ hypotheses: List of hypothesis texts
318
+ preprocess: Apply preprocessing
319
+
320
+ Returns:
321
+ Tuple of (weighted average WER, list of individual results)
322
+ """
323
+ if len(references) != len(hypotheses):
324
+ raise ValueError("Reference and hypothesis lists must have same length")
325
+
326
+ results = []
327
+ for ref, hyp in zip(references, hypotheses):
328
+ result = self.calculate_wer(ref, hyp, preprocess)
329
+ results.append(result)
330
+
331
+ # Calculate weighted average WER
332
+ total_ref_words = sum(r.reference_length for r in results)
333
+ total_errors = sum(r.substitutions + r.deletions + r.insertions for r in results)
334
+
335
+ avg_wer = total_errors / total_ref_words if total_ref_words > 0 else 0.0
336
+
337
+ return avg_wer, results
338
+
339
+ # =========================================================================
340
+ # DER Calculation
341
+ # =========================================================================
342
+
343
+ def calculate_der(
344
+ self,
345
+ reference_segments: List[Tuple[str, float, float]],
346
+ hypothesis_segments: List[Tuple[str, float, float]],
347
+ collar: float = 0.25,
348
+ ) -> DERResult:
349
+ """
350
+ Calculate Diarization Error Rate.
351
+
352
+ DER = (Missed Speech + False Alarm + Speaker Confusion) / Total Reference Duration
353
+
354
+ Args:
355
+ reference_segments: Ground truth [(speaker_id, start, end), ...]
356
+ hypothesis_segments: System output [(speaker_id, start, end), ...]
357
+ collar: Forgiveness collar in seconds (standard: 0.25s)
358
+
359
+ Returns:
360
+ DERResult with detailed breakdown
361
+ """
362
+ if not reference_segments:
363
+ return DERResult(
364
+ der=0.0,
365
+ total_duration=0.0,
366
+ num_speakers_ref=0,
367
+ num_speakers_hyp=(
368
+ len(set(s[0] for s in hypothesis_segments)) if hypothesis_segments else 0
369
+ ),
370
+ )
371
+
372
+ # Get unique speakers
373
+ ref_speakers = set(s[0] for s in reference_segments)
374
+ hyp_speakers = set(s[0] for s in hypothesis_segments) if hypothesis_segments else set()
375
+
376
+ # Calculate total reference duration
377
+ total_ref_duration = sum(end - start for _, start, end in reference_segments)
378
+
379
+ if total_ref_duration == 0:
380
+ return DERResult(
381
+ der=0.0,
382
+ total_duration=0.0,
383
+ num_speakers_ref=len(ref_speakers),
384
+ num_speakers_hyp=len(hyp_speakers),
385
+ )
386
+
387
+ # Frame-based evaluation
388
+ resolution = 0.01 # 10ms resolution
389
+
390
+ # Get time range
391
+ all_starts = [s[1] for s in reference_segments + (hypothesis_segments or [])]
392
+ all_ends = [s[2] for s in reference_segments + (hypothesis_segments or [])]
393
+
394
+ min_time = min(all_starts) if all_starts else 0
395
+ max_time = max(all_ends) if all_ends else 0
396
+
397
+ # Initialize counters
398
+ missed_speech = 0.0
399
+ false_alarm = 0.0
400
+ speaker_confusion = 0.0
401
+
402
+ # Frame-by-frame evaluation
403
+ t = min_time
404
+ while t < max_time:
405
+ t_mid = t + resolution / 2
406
+
407
+ # Get reference speakers at time t
408
+ ref_spk_at_t = set()
409
+ for speaker, start, end in reference_segments:
410
+ # Apply collar
411
+ if (start + collar) <= t_mid < (end - collar):
412
+ ref_spk_at_t.add(speaker)
413
+
414
+ # Get hypothesis speakers at time t
415
+ hyp_spk_at_t = set()
416
+ if hypothesis_segments:
417
+ for speaker, start, end in hypothesis_segments:
418
+ if start <= t_mid < end:
419
+ hyp_spk_at_t.add(speaker)
420
+
421
+ # Count errors
422
+ if ref_spk_at_t and not hyp_spk_at_t:
423
+ # Missed speech: reference has speech, hypothesis doesn't
424
+ missed_speech += resolution
425
+ elif hyp_spk_at_t and not ref_spk_at_t:
426
+ # False alarm: hypothesis has speech, reference doesn't
427
+ false_alarm += resolution
428
+ elif ref_spk_at_t and hyp_spk_at_t:
429
+ # Both have speech - check for speaker confusion
430
+ # Simplified: if number of speakers differs, count as confusion
431
+ ref_count = len(ref_spk_at_t)
432
+ hyp_count = len(hyp_spk_at_t)
433
+
434
+ if ref_count != hyp_count:
435
+ # Partial confusion
436
+ confusion_ratio = abs(ref_count - hyp_count) / max(ref_count, hyp_count)
437
+ speaker_confusion += resolution * confusion_ratio
438
+
439
+ t += resolution
440
+
441
+ # Calculate DER
442
+ total_error = missed_speech + false_alarm + speaker_confusion
443
+ der = total_error / total_ref_duration
444
+
445
+ return DERResult(
446
+ der=min(der, 1.0), # Cap at 100%
447
+ missed_speech=missed_speech / total_ref_duration,
448
+ false_alarm=false_alarm / total_ref_duration,
449
+ speaker_confusion=speaker_confusion / total_ref_duration,
450
+ total_duration=total_ref_duration,
451
+ num_speakers_ref=len(ref_speakers),
452
+ num_speakers_hyp=len(hyp_speakers),
453
+ )
454
+
455
+ # =========================================================================
456
+ # Summary evaluation (ROUGE, BERTScore)
457
+ # =========================================================================
458
+
459
+ def calculate_summary_metrics(self, reference: str, hypothesis: str) -> SummaryResult:
460
+ """Calculate ROUGE and BERTScore for summaries.
461
+
462
+ Returns a SummaryResult with compact numeric metrics (ROUGE-1/2/L F1 and BERTScore precision/recall/F1).
463
+ """
464
+ try:
465
+ import evaluate
466
+
467
+ rouge = evaluate.load("rouge")
468
+ bert = evaluate.load("bertscore")
469
+
470
+ # ROUGE expects lists
471
+ rouge_res = rouge.compute(predictions=[hypothesis], references=[reference])
472
+ # bertscore returns lists of precision/recall/f1
473
+ bert_res = bert.compute(predictions=[hypothesis], references=[reference], lang="id")
474
+
475
+ # pick common metrics; the `evaluate` ROUGE metric returns keys "rouge1", "rouge2", "rougeL"
476
+ rouge_out = {
477
+ "rouge1_f": float(rouge_res.get("rouge1", 0.0)),
478
+ "rouge2_f": float(rouge_res.get("rouge2", 0.0)),
479
+ "rougel_f": float(rouge_res.get("rougeL", 0.0)),
480
+ }
481
+
482
+ bert_out = {
483
+ "bertscore_precision": float(bert_res.get("precision", [0.0])[0]),
484
+ "bertscore_recall": float(bert_res.get("recall", [0.0])[0]),
485
+ "bertscore_f1": float(bert_res.get("f1", [0.0])[0]),
486
+ }
487
+
488
+ return SummaryResult(rouge=rouge_out, bertscore=bert_out)
489
+ except Exception as e:
490
+ print(f"[Evaluator] Summary metric computation failed: {e}")
491
+ # fallback: empty metrics
492
+ return SummaryResult(rouge={}, bertscore={})
493
+
494
+ # =========================================================================
495
+ # Report Generation
496
+ # =========================================================================
497
+
498
+ def generate_evaluation_report(
499
+ self,
500
+ wer_results: List[WERResult],
501
+ der_results: Optional[List[DERResult]] = None,
502
+ summary_results: Optional[List[SummaryResult]] = None,
503
+ sample_names: Optional[List[str]] = None,
504
+ condition_name: str = "Unknown",
505
+ metadata: Optional[Dict[str, Any]] = None,
506
+ ) -> str:
507
+ """
508
+ Generate formatted evaluation report for thesis.
509
+
510
+ Args:
511
+ wer_results: List of WER results
512
+ der_results: List of DER results (optional)
+ summary_results: List of summary metric results (ROUGE/BERTScore, optional)
513
+ sample_names: Names for each sample
514
+ condition_name: Name of test condition
515
+ metadata: Optional dictionary of hyperparameters / tuning info used during the run
516
+
517
+ Returns:
518
+ Formatted report string
519
+ """
520
+ lines = []
521
+ lines.append("=" * 70)
522
+ lines.append("LAPORAN EVALUASI SISTEM NOTULENSI RAPAT OTOMATIS")
523
+ lines.append(f"Kondisi: {condition_name}")
524
+ lines.append(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
525
+ lines.append("=" * 70)
526
+ lines.append("")
527
+
528
+ # WER Summary
529
+ lines.append("1. EVALUASI ASR (Word Error Rate)")
530
+ lines.append("-" * 50)
531
+
532
+ if wer_results:
533
+ wer_values = [r.wer for r in wer_results]
534
+ avg_wer = np.mean(wer_values)
535
+ std_wer = np.std(wer_values)
536
+ min_wer = np.min(wer_values)
537
+ max_wer = np.max(wer_values)
538
+
539
+ total_subs = sum(r.substitutions for r in wer_results)
540
+ total_dels = sum(r.deletions for r in wer_results)
541
+ total_ins = sum(r.insertions for r in wer_results)
542
+ total_hits = sum(r.hits for r in wer_results)
543
+
544
+ lines.append(f" Jumlah sampel : {len(wer_results)}")
545
+ lines.append(f" WER rata-rata : {avg_wer:.4f} ({avg_wer*100:.2f}%)")
546
+ lines.append(f" Standar deviasi : {std_wer:.4f}")
547
+ lines.append(f" WER minimum : {min_wer:.4f} ({min_wer*100:.2f}%)")
548
+ lines.append(f" WER maksimum : {max_wer:.4f} ({max_wer*100:.2f}%)")
549
+ lines.append("")
550
+ lines.append(" Detail Error Total:")
551
+ lines.append(f" - Substitutions : {total_subs}")
552
+ lines.append(f" - Deletions : {total_dels}")
553
+ lines.append(f" - Insertions : {total_ins}")
554
+ lines.append(f" - Correct (Hits) : {total_hits}")
555
+
556
+ # Per-sample details
557
+ if sample_names and len(sample_names) == len(wer_results):
558
+ lines.append("")
559
+ lines.append(" Detail per Sampel:")
560
+ for name, result in zip(sample_names, wer_results):
561
+ lines.append(f" - {name}: WER = {result.wer:.4f} ({result.wer*100:.2f}%)")
562
+ else:
563
+ lines.append(" Tidak ada data WER untuk dievaluasi.")
564
+
565
+ lines.append("")
566
+
567
+ # DER Summary
568
+ lines.append("2. EVALUASI DIARIZATION (Diarization Error Rate)")
569
+ lines.append("-" * 50)
570
+
571
+ if der_results:
572
+ der_values = [r.der for r in der_results]
573
+ avg_der = np.mean(der_values)
574
+ std_der = np.std(der_values)
575
+
576
+ avg_missed = np.mean([r.missed_speech for r in der_results])
577
+ avg_fa = np.mean([r.false_alarm for r in der_results])
578
+ avg_conf = np.mean([r.speaker_confusion for r in der_results])
579
+
580
+ lines.append(f" Jumlah sampel : {len(der_results)}")
581
+ lines.append(f" DER rata-rata : {avg_der:.4f} ({avg_der*100:.2f}%)")
582
+ lines.append(f" Standar deviasi : {std_der:.4f}")
583
+ lines.append("")
584
+ lines.append(" Komponen Error (rata-rata):")
585
+ lines.append(f" - Missed Speech : {avg_missed:.4f} ({avg_missed*100:.2f}%)")
586
+ lines.append(f" - False Alarm : {avg_fa:.4f} ({avg_fa*100:.2f}%)")
587
+ lines.append(f" - Speaker Confusion: {avg_conf:.4f} ({avg_conf*100:.2f}%)")
588
+
589
+ # Per-sample details
590
+ if sample_names and len(sample_names) == len(der_results):
591
+ lines.append("")
592
+ lines.append(" Detail per Sampel:")
593
+ for name, result in zip(sample_names, der_results):
594
+ lines.append(f" - {name}: DER = {result.der:.4f} ({result.der*100:.2f}%)")
595
+ else:
596
+ lines.append(" Tidak ada data DER untuk dievaluasi.")
597
+
598
+ lines.append("")
599
+ # Summary evaluation (ROUGE, BERTScore)
600
+ lines.append("3. EVALUASI RINGKASAN (Ringkasan/Abstraksi)")
601
+ lines.append("-" * 50)
602
+ if summary_results:
603
+ try:
604
+ avg_rouge1 = np.mean([r.rouge.get("rouge1_f", 0.0) for r in summary_results])
605
+ avg_rouge2 = np.mean([r.rouge.get("rouge2_f", 0.0) for r in summary_results])
606
+ avg_rougel = np.mean([r.rouge.get("rougel_f", 0.0) for r in summary_results])
607
+ avg_bertscore = np.mean([r.bertscore.get("bertscore_f1", 0.0) for r in summary_results])
608
+ lines.append(f" Jumlah sampel : {len(summary_results)}")
609
+ lines.append(f" ROUGE-1 F1 (avg) : {avg_rouge1:.4f}")
610
+ lines.append(f" ROUGE-2 F1 (avg) : {avg_rouge2:.4f}")
611
+ lines.append(f" ROUGE-L F1 (avg) : {avg_rougel:.4f}")
612
+ lines.append(f" BERTScore F1 (avg) : {avg_bertscore:.4f}")
613
+ except Exception as e:
614
+ lines.append(f" (summary metric aggregation failed: {e})")
615
+ else:
616
+ lines.append(" Tidak ada data ringkasan untuk dievaluasi.")
617
+
618
+ lines.append("")
619
+
620
+ # Include metadata/hyperparameters if provided
621
+ if metadata:
622
+ lines.append("4. CONFIGURATION & HYPERPARAMETERS")
623
+ lines.append("-" * 50)
624
+ try:
625
+ # Print metadata items in sorted order for consistency
626
+ for k in sorted(metadata.keys()):
627
+ v = metadata[k]
628
+ # For nested dicts, pretty-print a compact representation
629
+ if isinstance(v, dict):
630
+ if not v:
631
+ lines.append(f" - {k}: {{}}")
632
+ else:
633
+ lines.append(f" - {k}:")
634
+ for kk, vv in v.items():
635
+ lines.append(f" - {kk}: {vv}")
636
+ else:
637
+ lines.append(f" - {k}: {v}")
638
+ except Exception as e:
639
+ lines.append(f" - (metadata formatting failed: {e})")
640
+
641
+ lines.append("")
642
+
643
+ lines.append("=" * 70)
644
+ lines.append("Catatan:")
645
+ lines.append(
646
+ "- Evaluasi WER menggunakan preprocessing standar (lowercase, hapus tanda baca)"
647
+ )
648
+ lines.append("- Evaluasi DER menggunakan collar forgiveness 0.25 detik")
649
+ lines.append("=" * 70)
650
+
651
+ return "\n".join(lines)
652
+
653
+ def export_results_to_csv(
654
+ self, results: List[EvaluationResult], output_filename: str = "evaluation_results.csv"
655
+ ) -> str:
656
+ """
657
+ Export evaluation results to CSV for thesis appendix.
658
+
659
+ Args:
660
+ results: List of EvaluationResult objects
661
+ output_filename: Output CSV filename
662
+
663
+ Returns:
664
+ Path to saved CSV file
665
+ """
666
+ output_path = self.output_dir / output_filename
667
+
668
+ with open(output_path, "w", newline="", encoding="utf-8") as f:
669
+ writer = csv.writer(f)
670
+
671
+ # Header
672
+ writer.writerow(
673
+ [
674
+ "Sample",
675
+ "Condition",
676
+ "WER",
677
+ "MER",
678
+ "WIL",
679
+ "CER",
680
+ "Substitutions",
681
+ "Deletions",
682
+ "Insertions",
683
+ "Hits",
684
+ "Ref_Words",
685
+ "Hyp_Words",
686
+ "DER",
687
+ "Missed_Speech",
688
+ "False_Alarm",
689
+ "Speaker_Confusion",
690
+ # Summary metrics
691
+ "ROUGE1_F",
692
+ "ROUGE2_F",
693
+ "ROUGEL_F",
694
+ "BERTScore_F1",
695
+ "Duration_Sec",
696
+ "Num_Speakers_Ref",
697
+ "Num_Speakers_Hyp",
698
+ ]
699
+ )
700
+
701
+ # Data rows
702
+ for result in results:
703
+ wer = result.wer_result
704
+ der = result.der_result
705
+
706
+ row = [
707
+ result.sample_name,
708
+ result.condition,
709
+ # WER metrics
710
+ f"{wer.wer:.4f}" if wer else "",
711
+ f"{wer.mer:.4f}" if wer else "",
712
+ f"{wer.wil:.4f}" if wer else "",
713
+ f"{wer.cer:.4f}" if wer else "",
714
+ wer.substitutions if wer else "",
715
+ wer.deletions if wer else "",
716
+ wer.insertions if wer else "",
717
+ wer.hits if wer else "",
718
+ wer.reference_length if wer else "",
719
+ wer.hypothesis_length if wer else "",
720
+ # DER metrics
721
+ f"{der.der:.4f}" if der else "",
722
+ f"{der.missed_speech:.4f}" if der else "",
723
+ f"{der.false_alarm:.4f}" if der else "",
724
+ f"{der.speaker_confusion:.4f}" if der else "",
725
+ # Summary metrics
726
+ f"{result.summary_result.rouge.get('rouge1_f', ''):.4f}" if result.summary_result and result.summary_result.rouge else "",
727
+ f"{result.summary_result.rouge.get('rouge2_f', ''):.4f}" if result.summary_result and result.summary_result.rouge else "",
728
+ f"{result.summary_result.rouge.get('rougel_f', ''):.4f}" if result.summary_result and result.summary_result.rouge else "",
729
+ f"{result.summary_result.bertscore.get('bertscore_f1', ''):.4f}" if result.summary_result and result.summary_result.bertscore else "",
730
+ f"{der.total_duration:.2f}" if der else "",
731
+ der.num_speakers_ref if der else "",
732
+ der.num_speakers_hyp if der else "",
733
+ ]
734
+
735
+ writer.writerow(row)
736
+
737
+ return str(output_path)
738
+
739
+ def generate_summary_table(
740
+ self, results_by_condition: Dict[str, List[EvaluationResult]]
741
+ ) -> str:
742
+ """
743
+ Generate summary table comparing results across conditions.
744
+
745
+ Args:
746
+ results_by_condition: Dict mapping condition name to list of results
747
+
748
+ Returns:
749
+ Formatted table string
750
+ """
751
+ lines = []
752
+ lines.append("")
753
+ lines.append("TABEL RINGKASAN EVALUASI PER KONDISI")
754
+ lines.append("=" * 80)
755
+ lines.append("")
756
+
757
+ # Header
758
+ header = (
759
+ f"{'Kondisi':<20} {'N':>5} {'WER Mean':>10} {'WER Std':>10} "
760
+ f"{'DER Mean':>10} {'DER Std':>10}"
761
+ )
762
+ lines.append(header)
763
+ lines.append("-" * 80)
764
+
765
+ # Data rows
766
+ for condition, results in results_by_condition.items():
767
+ n = len(results)
768
+
769
+ # WER stats
770
+ wer_values = [r.wer_result.wer for r in results if r.wer_result]
771
+ wer_mean = np.mean(wer_values) if wer_values else 0
772
+ wer_std = np.std(wer_values) if wer_values else 0
773
+
774
+ # DER stats
775
+ der_values = [r.der_result.der for r in results if r.der_result]
776
+ der_mean = np.mean(der_values) if der_values else 0
777
+ der_std = np.std(der_values) if der_values else 0
778
+
779
+ row = (
780
+ f"{condition:<20} {n:>5} {wer_mean:>10.4f} {wer_std:>10.4f} "
781
+ f"{der_mean:>10.4f} {der_std:>10.4f}"
782
+ )
783
+ lines.append(row)
784
+
785
+ lines.append("-" * 80)
786
+ lines.append("")
787
+
788
+ return "\n".join(lines)
789
+
790
+ def save_report(self, report: str, filename: str = "evaluation_report.txt") -> str:
791
+ """Save evaluation report to file"""
792
+ output_path = self.output_dir / filename
793
+
794
+ with open(output_path, "w", encoding="utf-8") as f:
795
+ f.write(report)
796
+
797
+ return str(output_path)
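A minimal usage sketch for the Evaluator defined above. The reference/hypothesis strings and segment times are illustrative only; output paths follow the class defaults.

from src.evaluator import Evaluator

evaluator = Evaluator(output_dir="./data/output")

# WER: the hypothesis below has one substitution and one deletion against a
# four-word reference, so WER = (1 + 1 + 0) / 4 = 0.50.
wer = evaluator.calculate_wer(
    reference="rapat dimulai pukul sembilan",
    hypothesis="rapat dimulai jam",
)
print(f"WER: {wer.wer:.2%} (S={wer.substitutions}, D={wer.deletions}, I={wer.insertions})")

# DER: reference and hypothesis segments as (speaker_id, start_sec, end_sec).
der = evaluator.calculate_der(
    reference_segments=[("SPK_A", 0.0, 10.0), ("SPK_B", 10.0, 20.0)],
    hypothesis_segments=[("SPK_1", 0.0, 12.0), ("SPK_2", 12.0, 20.0)],
)
print(f"DER: {der.der:.2%}")

# Combine both metric lists into the plain-text report and save it.
report = evaluator.generate_evaluation_report(
    wer_results=[wer],
    der_results=[der],
    sample_names=["sample_01"],
    condition_name="clean-audio",
)
print(evaluator.save_report(report))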
src/nlp_utils.py ADDED
@@ -0,0 +1,243 @@
1
+ """
2
+ Advanced NLP utilities: NER + dependency parsing wrapper with graceful fallbacks.
3
+
4
+ Provides a small abstraction `AdvancedNLPExtractor` that will use spaCy if available
5
+ (or fallback regex/heuristic extractors) to extract structured action items and
6
+ decisions from sentence-level metadata.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import re
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ try:
16
+ import spacy
17
+ from spacy.language import Language
18
+
19
+ _HAS_SPACY = True
20
+ except Exception:
21
+ _HAS_SPACY = False
22
+
23
+ try:
24
+ from langdetect import detect as _detect_lang
25
+
26
+ _HAS_LANGDETECT = True
27
+ except Exception:
28
+ _HAS_LANGDETECT = False
29
+
30
+ logger = logging.getLogger("AdvancedNLP")
31
+
32
+
33
+ class AdvancedNLPExtractor:
34
+ """Wrapper providing NER and dependency-based extraction.
35
+
36
+ Usage:
37
+ extractor = AdvancedNLPExtractor()
38
+ items = extractor.extract_actions_from_sentences(sent_meta)
39
+
40
+ `sent_meta` is a list of dicts produced by `BERTSummarizer._get_sentences_with_meta`
41
+ where each dict contains at least `text`, `speaker_id`, `start`, `end`.
42
+ """
43
+
44
+ def __init__(self, lang: Optional[str] = None):
45
+ self.lang = lang
46
+ self._nlp: Optional[Language] = None
47
+ if _HAS_SPACY:
48
+ try:
49
+ model = self._choose_model(lang)
50
+ if model is not None:
51
+ self._nlp = spacy.load(model)
52
+ logger.info(f"Loaded spaCy model: {model}")
53
+ except Exception as e:
54
+ logger.warning(f"spaCy model load failed: {e}")
55
+ self._nlp = None
56
+ else:
57
+ logger.debug("spaCy not available; using heuristic fallbacks")
58
+
59
+ def _choose_model(self, lang: Optional[str]) -> Optional[str]:
60
+ # Prefer language-specific small models if available
61
+ if lang is None and _HAS_LANGDETECT:
62
+ return None # leave None to let caller decide based on text
63
+ if lang == "id":
64
+ return "id_core_news_sm"
65
+ if lang == "en":
66
+ return "en_core_web_sm"
67
+ # Fall back to cross-lingual entity model if present
68
+ return "xx_ent_wiki_sm"
69
+
70
+ def _detect_lang(self, text: str) -> Optional[str]:
71
+ if not _HAS_LANGDETECT:
72
+ return None
73
+ try:
74
+ return _detect_lang(text)
75
+ except Exception:
76
+ return None
77
+
78
+ def _get_doc(self, text: str):
79
+ # If spaCy is loaded, use it. Otherwise return None.
80
+ if self._nlp is None:
81
+ # try to lazily pick a model based on language
82
+ if _HAS_SPACY:
83
+ lang = self._detect_lang(text)
84
+ model = self._choose_model(lang)
85
+ if model:
86
+ try:
87
+ self._nlp = spacy.load(model)
88
+ logger.info(f"Lazy-loaded spaCy model: {model}")
89
+ except Exception:
90
+ self._nlp = None
91
+ return None
92
+ try:
93
+ return self._nlp(text)
94
+ except Exception:
95
+ return None
96
+
97
+ def extract_persons(self, text: str) -> List[str]:
98
+ doc = self._get_doc(text)
99
+ if doc is None:
100
+ # simple regex: capitalized words sequences
101
+ names = re.findall(r"\b([A-Z][a-z]{1,20}(?:\s+[A-Z][a-z]{1,20})*)\b", text)
102
+ return list(dict.fromkeys(names))
103
+ persons = [ent.text for ent in doc.ents if ent.label_ in ("PERSON", "PER")]
104
+ # preserve order, unique
105
+ return list(dict.fromkeys(persons))
106
+
107
+ def extract_actions_from_sentences(
108
+ self, sent_meta: List[Dict[str, Any]]
109
+ ) -> List[Dict[str, Any]]:
110
+ """Return candidate action items extracted from sentence metadata.
111
+
112
+ Each returned dict contains: {owner, task, sentence_idx, confidence}
113
+ """
114
+ results: List[Dict[str, Any]] = []
115
+
116
+ texts = [s["text"] for s in sent_meta]
117
+ full = " ".join(texts[: max(1, min(10, len(texts)))])
118
+ lang = self._detect_lang(full) if _HAS_LANGDETECT else None
119
+
120
+ for i, s in enumerate(sent_meta):
121
+ text = s.get("text", "").strip()
122
+ if not text:
123
+ continue
124
+
125
+ # Quick keyword filter (language-agnostic): if no action words, skip
126
+ if not re.search(
127
+ r"\b(akan|harus|perlu|tolong|mohon|harap|deadline|target|tugas|follow up|tindak lanjut|siapkan|buat|bikin|saya|aku|kami|kita)\b",
128
+ text,
129
+ flags=re.IGNORECASE,
130
+ ):
131
+ # also check for English keywords
132
+ if not re.search(
133
+ r"\b(will|shall|must|please|assign|task|deadline|action item|follow up|todo)\b",
134
+ text,
135
+ flags=re.IGNORECASE,
136
+ ):
137
+ continue
138
+
139
+ doc = self._get_doc(text)
140
+ owner: Optional[str] = None
141
+ task: Optional[str] = None
142
+ confidence = 0.5
143
+
144
+ # First, try to find PERSON entities in the sentence
145
+ if doc is not None:
146
+ persons = [ent.text for ent in doc.ents if ent.label_ in ("PERSON", "PER")]
147
+ if persons:
148
+ owner = persons[0]
149
+ confidence = 0.8
150
+
151
+ # dependency parse-based task extraction
152
+ try:
153
+ # find ROOT verb
154
+ root = None
155
+ for token in doc:
156
+ if token.dep_ == "ROOT" and token.pos_ in ("VERB", "AUX"):
157
+ root = token
158
+ break
159
+
160
+ if root is not None:
161
+ # look for direct objects / xcomp / ccomp
162
+ objs = [t for t in doc if t.dep_ in ("dobj", "obj", "xcomp", "ccomp")]
163
+ if objs:
164
+ task = " ".join([tok.text for tok in objs[0].subtree])
165
+ confidence = max(confidence, 0.7)
166
+ else:
167
+ # fallback: use root subtree as task
168
+ task = " ".join([tok.text for tok in root.subtree])
169
+ confidence = max(confidence, 0.6)
170
+
171
+ # If no owner found, search preceding tokens for personal pronouns
172
+ if owner is None:
173
+ pron = [t for t in doc if t.pos_ == "PRON"]
174
+ if pron:
175
+ owner = pron[0].text
176
+ confidence = 0.6
177
+ except Exception:
178
+ pass
179
+
180
+ # Regex fallback to capture "Name akan <action>" in many languages
181
+ if owner is None:
182
+ m = re.search(
183
+ r"\b([A-Z][a-z]{1,20})\b\s+(akan|will|harus|must|to)\s+(?P<task>.+)",
184
+ text,
185
+ flags=re.IGNORECASE,
186
+ )
187
+ if m:
188
+ owner = m.group(1)
189
+ task = m.group("task").strip(" .,:;-")
190
+ confidence = 0.7
191
+
192
+ # Otherwise, check for "Saya akan"/"Aku akan" and attribute to speaker
193
+ if owner is None and re.search(r"\b(saya|aku|kami|kita)\b", text, flags=re.IGNORECASE):
194
+ owner = s.get("speaker_id")
195
+ # try extract phrase after 'akan' or commit verb
196
+ m2 = re.search(
197
+ r"\b(?:akan|saya akan|aku akan|saya akan membuat|aku akan membuat|tolong|siapkan|buat|bikin)\b\s*(?P<task>.+)$",
198
+ text,
199
+ flags=re.IGNORECASE,
200
+ )
201
+ if m2:
202
+ task = m2.group("task").strip(" .,:;-")
203
+ confidence = 0.7
204
+
205
+ # final fallback: if sentence contains action keywords, use whole sentence
206
+ if task is None:
207
+ # trim connectors and filler
208
+ t = re.sub(r"^(oke|ya|nah|baik)\b[:,-]*", "", text, flags=re.IGNORECASE).strip()
209
+ task = t[:300]
210
+
211
+ # Basic length filter
212
+ if task and len(task.split()) < 3:
213
+ continue
214
+
215
+ results.append(
216
+ {
217
+ "owner": owner or s.get("speaker_id"),
218
+ "task": task,
219
+ "sentence_idx": i,
220
+ "confidence": confidence,
221
+ }
222
+ )
223
+
224
+ return results
225
+
226
+
227
+ def extract_decisions_from_sentences(sent_meta: List[Dict[str, Any]]) -> List[str]:
228
+ """Simple decision extraction: look for decision keywords and return cleaned contexts."""
229
+ results: List[str] = []
230
+ decision_kw = re.compile(
231
+ r"\b(diputuskan|disepakati|kesimpulan|keputusan|sepakat|setuju|disetujui|putus|decided|decision)\b",
232
+ flags=re.IGNORECASE,
233
+ )
234
+
235
+ for i, s in enumerate(sent_meta):
236
+ text = s.get("text", "").strip()
237
+ if not text:
238
+ continue
239
+ if decision_kw.search(text):
240
+ cleaned = re.sub(r"\[.*?\]", "", text)
241
+ results.append(cleaned.strip())
242
+
243
+ return results
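A minimal usage sketch for the extractors above. The sentence metadata is illustrative and only mimics the dicts described in the AdvancedNLPExtractor docstring (text, speaker_id, start, end); when spaCy or its models are absent, the class falls back to the regex heuristics.

from src.nlp_utils import AdvancedNLPExtractor, extract_decisions_from_sentences

sent_meta = [
    {"text": "Budi akan menyiapkan laporan penjualan minggu depan.",
     "speaker_id": "SPK_1", "start": 12.0, "end": 16.5},
    {"text": "Disepakati bahwa rapat berikutnya diadakan hari Jumat.",
     "speaker_id": "SPK_2", "start": 30.0, "end": 34.0},
]

extractor = AdvancedNLPExtractor(lang="id")
for item in extractor.extract_actions_from_sentences(sent_meta):
    # e.g. owner="Budi", task="menyiapkan laporan penjualan minggu depan"
    print(item["owner"], "->", item["task"], f"(confidence={item['confidence']})")

print(extract_decisions_from_sentences(sent_meta))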
src/pipeline.py ADDED
@@ -0,0 +1,1121 @@
1
+ """
2
+ Main Pipeline Module
3
+ ====================
4
+ Orchestrates all components for end-to-end meeting transcription.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+
18
+ from src.audio_processor import AudioConfig, AudioProcessor
19
+ from src.diarization import DiarizationConfig, SpeakerDiarizer, SpeakerSegment
20
+ from src.document_generator import DocumentGenerator, MeetingMetadata
21
+ from src.evaluator import EvaluationResult, Evaluator
22
+ from src.summarizer import BERTSummarizer, MeetingSummary, SummarizationConfig
23
+ from src.transcriber import ASRConfig, ASRTranscriber, TranscriptSegment
24
+
25
+ # Optional speechbrain adapter
26
+ try:
27
+ from src.transcriber_speechbrain import ( # type: ignore
28
+ SpeechBrainASRConfig,
29
+ SpeechBrainTranscriber,
30
+ )
31
+ except Exception:
32
+ SpeechBrainTranscriber = None
33
+ SpeechBrainASRConfig = None
34
+ from src.utils import (
35
+ Timer,
36
+ ensure_dir,
37
+ format_duration,
38
+ sanitize_filename,
39
+ save_json,
40
+ setup_logger,
41
+ )
42
+
43
+
44
+ @dataclass
45
+ class PipelineConfig:
46
+ """Configuration for the complete pipeline"""
47
+
48
+ # Paths
49
+ models_dir: str = "./models"
50
+ output_dir: str = "./data/output"
51
+ cache_dir: str = "./cache"
52
+
53
+ # Audio settings
54
+ sample_rate: int = 16000
55
+
56
+ # Diarization settings
57
+ num_speakers: Optional[int] = None
58
+ min_speech_duration: float = 0.3
59
+ # Target speaker enforcement (convenience wrapper for DiarizationConfig.target_num_speakers)
60
+ target_speakers: Optional[int] = None
61
+
62
+ # ASR settings
63
+ # Default to Whisper Large v3 Turbo for better accuracy (may be slower)
64
+ asr_model_id: str = "large-v3-turbo"
65
+ asr_backend: str = "whisperx" # whisperx preferred for Large models
66
+ asr_language: str = "id"
67
+ whisperx_compute_type: str = "auto"
68
+ whisperx_vad_filter: bool = True
69
+
70
+ # Summarization settings
71
+ num_summary_sentences: int = 5
72
+
73
+ # Device
74
+ device: str = "auto"
75
+
76
+ # Flags
77
+ save_intermediate: bool = True
78
+ verbose: bool = True
79
+
80
+ # Performance options
81
+ fast_mode: bool = False # reduce accuracy for speed
82
+ quick_asr: bool = False # use lightweight ASR where possible
83
+ embedding_cache: bool = True # cache diarization embeddings to disk
84
+
85
+ # Preset mode (deployment = recommended default for production: WhisperX large-v3-turbo int8)
86
+ # Set default to 'fast' to prefer lightweight models (whisper-small) and avoid heavy WhisperX defaults
87
+ preset: str = "fast" # choices: deployment|balanced|fast|accurate
88
+
89
+ # Quick ASR options
90
+ prefer_whisper_small: bool = True
91
+ # Approximate Continuous Speech Tokenizer token rate in Hz (e.g., 7.5). When set,
92
+ # ASR will apply a lossy preprocessor to compress audio for speed. Use with care.
93
+ cst_hz: Optional[float] = 7.5
94
+
95
+ # Compare diarization methods during evaluation
96
+ diarization_compare: bool = False
97
+
98
+ # Allow explicit override for ASR parallel workers (None = auto)
99
+ asr_parallel_workers: Optional[int] = None # override for per-segment ASR parallelism
100
+
101
+ # Optional speaker mapping & diarization tuning
102
+ speaker_map_path: Optional[str] = None
103
+ tune_diarization: bool = False
107
+
108
+ def __post_init__(self):
109
+ # Auto-detect device
110
+ if self.device == "auto":
111
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
112
+
113
+ # Create directories
114
+ ensure_dir(self.models_dir)
115
+ ensure_dir(self.output_dir)
116
+ ensure_dir(self.cache_dir)
117
+
118
+
119
+ @dataclass
120
+ class PipelineResult:
121
+ """Complete result from pipeline processing"""
122
+
123
+ # Input info
124
+ audio_path: str
125
+ audio_duration: float
126
+
127
+ # Processing info
128
+ num_speakers: int
129
+ num_segments: int
130
+ total_words: int
131
+ processing_time: float
132
+
133
+ # Outputs
134
+ segments: List[Dict[str, Any]]
135
+ transcript_text: str
136
+ summary: Dict[str, Any]
137
+ document_path: str
138
+
139
+ # Metadata
140
+ metadata: Dict[str, Any] = field(default_factory=dict)
141
+
142
+ def to_dict(self) -> Dict[str, Any]:
143
+ """Convert to dictionary"""
144
+ return {
145
+ "audio_path": self.audio_path,
146
+ "audio_duration": self.audio_duration,
147
+ "num_speakers": self.num_speakers,
148
+ "num_segments": self.num_segments,
149
+ "total_words": self.total_words,
150
+ "processing_time": self.processing_time,
151
+ "transcript_text": self.transcript_text,
152
+ "summary": self.summary,
153
+ "document_path": self.document_path,
154
+ "metadata": self.metadata,
155
+ }
156
+
157
+ def save(self, filepath: str):
158
+ """Save result to JSON file"""
159
+ save_json(self.to_dict(), filepath)
160
+
161
+
162
+ class MeetingTranscriberPipeline:
163
+ """
164
+ End-to-end pipeline for automatic meeting transcription.
165
+
166
+ Pipeline Flow:
167
+ 1. Audio Loading & Preprocessing
168
+ 2. Speaker Diarization (VAD + Embedding + Clustering)
169
+ 3. ASR Transcription (per speaker segment)
170
+ 4. BERT Summarization (extractive)
171
+ 5. Document Generation (.docx)
172
+
173
+ Attributes:
174
+ config: PipelineConfig object
175
+
176
+ Example:
177
+ >>> pipeline = MeetingTranscriberPipeline()
178
+ >>> result = pipeline.process("meeting.wav", title="Team Meeting")
179
+ >>> print(f"Document saved: {result.document_path}")
180
+ """
181
+
182
+ def __init__(self, config: Optional[PipelineConfig] = None):
183
+ """
184
+ Initialize pipeline.
185
+
186
+ Args:
187
+ config: PipelineConfig object (uses defaults if None)
188
+ """
189
+ self.config = config or PipelineConfig()
190
+
191
+ # Setup logger
192
+ self.logger = setup_logger(
193
+ "MeetingTranscriber",
194
+ log_file=(
195
+ os.path.join(self.config.cache_dir, "pipeline.log")
196
+ if self.config.save_intermediate
197
+ else None
198
+ ),
199
+ )
200
+
201
+ # Component placeholders (lazy loading)
202
+ self._audio_processor = None
203
+ self._diarizer = None
204
+ self._transcriber = None
205
+ self._summarizer = None
206
+ self._doc_generator = None
207
+ self._evaluator = None
208
+
209
+ # Processing state
210
+ self._waveform = None
211
+ self._sample_rate = None
212
+ self._diarization_segments = None
213
+ self._transcript_segments = None
214
+ self._summary = None
215
+ # Diarization tuning result (if autotune was run)
216
+ self._diarization_tune_result = None
217
+
218
+ if self.config.verbose:
219
+ self._log(f"Pipeline initialized with device: {self.config.device}")
220
+ # Log effective CST value for diagnostics
221
+ self._log(f"Pipeline effective cst_hz: {getattr(self.config, 'cst_hz', None)} Hz")
222
+
223
+ # =========================================================================
224
+ # Properties (Lazy Loading)
225
+ # =========================================================================
226
+
227
+ @property
228
+ def audio_processor(self) -> AudioProcessor:
229
+ """Get audio processor (lazy loaded)"""
230
+ if self._audio_processor is None:
231
+ self._audio_processor = AudioProcessor(
232
+ AudioConfig(sample_rate=self.config.sample_rate, mono=True, normalize=True)
233
+ )
234
+ return self._audio_processor
235
+
236
+ @property
237
+ def diarizer(self) -> SpeakerDiarizer:
238
+ """Get diarizer (lazy loaded)"""
239
+ if self._diarizer is None:
240
+ dz_cfg = DiarizationConfig(
241
+ min_speech_duration=self.config.min_speech_duration,
242
+ device=self.config.device,
243
+ )
244
+ # If pipeline has target_speakers configured, propagate to diarizer config
245
+ if getattr(self.config, "target_speakers", None) is not None:
246
+ dz_cfg.target_num_speakers = int(self.config.target_speakers)
247
+ self._diarizer = SpeakerDiarizer(config=dz_cfg, models_dir=self.config.models_dir)
248
+ return self._diarizer
249
+
250
+ @property
251
+ def transcriber(self) -> ASRTranscriber:
252
+ """Get transcriber (lazy loaded)"""
253
+ if self._transcriber is None:
254
+ # Instantiate ASR transcriber; if configured to use SpeechBrain backend prefer adapter
255
+ asr_cfg = ASRConfig(
256
+ model_id=self.config.asr_model_id,
257
+ device=self.config.device,
258
+ backend=getattr(self.config, "asr_backend", "whisper"),
259
+ language=getattr(self.config, "asr_language", "id"),
260
+ whisperx_compute_type=getattr(self.config, "whisperx_compute_type", "auto"),
261
+ whisperx_vad_filter=bool(getattr(self.config, "whisperx_vad_filter", True)),
262
+ )
263
+
264
+ # Apply preset defaults (deployment/balanced/fast/accurate)
265
+ preset = getattr(self.config, "preset", None)
266
+ if preset == "deployment":
267
+ # Deployment preset: prefer WhisperX large-v3-turbo (int8 on CPU), full-audio mapping, tuned parallelism
268
+ asr_cfg.backend = "whisperx"
269
+
270
+ # If user did not explicitly provide a WhisperX-compatible model (e.g. the
271
+ # configured model contains 'wav2vec' or is an existing TF checkpoint),
272
+ # override to a known WhisperX-compatible model id. This avoids trying to
273
+ # load a Transformers checkpoint with WhisperX which expects CTranslate2 format
274
+ # (contains 'model.bin').
275
+ user_model = getattr(self.config, "asr_model_id", "") or ""
276
+ user_model_l = user_model.lower()
277
+ if (
278
+ (not user_model_l)
279
+ or ("wav2vec" in user_model_l)
280
+ or user_model_l.startswith("models/")
281
+ ):
282
+ asr_cfg.model_id = "large-v3-turbo"
283
+ self._log(
284
+ "Preset 'deployment' selected: overriding ASR model to 'large-v3-turbo' for WhisperX compatibility."
285
+ )
286
+ else:
287
+ asr_cfg.model_id = user_model
288
+
289
+ asr_cfg.use_full_audio_for_segments = True
290
+ asr_cfg.whisperx_compute_type = (
291
+ getattr(self.config, "whisperx_compute_type", "int8") or "int8"
292
+ )
293
+ try:
294
+ import os
295
+
296
+ asr_cfg.parallel_workers = min(8, max(1, (os.cpu_count() or 4) - 1))
297
+ except Exception:
298
+ pass
299
+ elif getattr(self.config, "quick_asr", False) or getattr(self.config, "prefer_whisper_small", False):
300
+ # Quick/Lightweight ASR: prefer Whisper small for speed and low memory
301
+ try:
302
+ asr_cfg.model_id = "openai/whisper-small"
303
+ asr_cfg.backend = "whisper"
304
+ # For speed, avoid the costly full-audio alignment step
305
+ asr_cfg.use_full_audio_for_segments = False
306
+ # Increase parallel workers conservatively for per-segment transcription
307
+ import os
308
+
309
+ asr_cfg.parallel_workers = min(8, max(1, (os.cpu_count() or 4) - 1))
310
+ # Larger chunk lengths reduce per-chunk overhead (helps CPU-bound runs)
311
+ asr_cfg.chunk_length_s = max(asr_cfg.chunk_length_s, 60.0)
312
+ # If Pipeline requested CST approximation, propagate to ASR config
313
+ if getattr(self.config, "cst_hz", None) is not None:
314
+ asr_cfg.cst_hz = float(self.config.cst_hz)
315
+ except Exception:
316
+ pass
317
+
318
+ # Allow explicit override from pipeline config
319
+ if getattr(self.config, "asr_parallel_workers", None) is not None:
320
+ try:
321
+ asr_cfg.parallel_workers = int(self.config.asr_parallel_workers)
322
+ except Exception:
323
+ pass
324
+
331
+ if (
332
+ getattr(self.config, "asr_backend", None) == "speechbrain"
333
+ and SpeechBrainTranscriber is not None
334
+ ):
335
+ # Create SpeechBrain adapter and wrap it with existing ASRTranscriber interface by setting backend
336
+ self._transcriber = ASRTranscriber(
337
+ config=asr_cfg, models_dir=self.config.models_dir
338
+ )
339
+ self._transcriber.config.backend = "speechbrain"
340
+ else:
341
+ self._transcriber = ASRTranscriber(
342
+ config=asr_cfg,
343
+ models_dir=self.config.models_dir,
344
+ )
345
+ return self._transcriber
346
+
347
+ @property
348
+ def summarizer(self) -> BERTSummarizer:
349
+ """Get summarizer (lazy loaded)"""
350
+ if self._summarizer is None:
351
+ self._summarizer = BERTSummarizer(
352
+ config=SummarizationConfig(num_sentences=self.config.num_summary_sentences)
353
+ )
354
+ return self._summarizer
355
+
356
+ @property
357
+ def doc_generator(self) -> DocumentGenerator:
358
+ """Get document generator (lazy loaded)"""
359
+ if self._doc_generator is None:
360
+ self._doc_generator = DocumentGenerator(output_dir=self.config.output_dir)
361
+ return self._doc_generator
362
+
363
+ @property
364
+ def evaluator(self) -> Evaluator:
365
+ """Get evaluator (lazy loaded)"""
366
+ if self._evaluator is None:
367
+ self._evaluator = Evaluator(output_dir=self.config.output_dir)
368
+ return self._evaluator
369
+
370
+ # =========================================================================
371
+ # Main Processing Methods
372
+ # =========================================================================
373
+
374
+ def process(
375
+ self,
376
+ audio_path: str,
377
+ title: str = "Notulensi Rapat",
378
+ date: Optional[str] = None,
379
+ location: str = "",
380
+ num_speakers: Optional[int] = None,
381
+ output_filename: Optional[str] = None,
382
+ progress_callback: Optional[Callable[[str, int, int], None]] = None,
383
+ ) -> PipelineResult:
384
+ """
385
+ Process audio file through complete pipeline.
386
+
387
+ Args:
388
+ audio_path: Path to audio file
389
+ title: Meeting title for document
390
+ date: Meeting date (default: today)
391
+ location: Meeting location/platform
392
+ num_speakers: Known number of speakers (auto-detect if None)
393
+ output_filename: Output .docx filename (auto-generated if None)
394
+ progress_callback: Callback function(step_name, current, total)
395
+
396
+ Returns:
397
+ PipelineResult with all outputs and metadata
398
+ """
399
+ start_time = time.time()
400
+
401
+ def update_progress(step: str, current: int, total: int):
402
+ if progress_callback:
403
+ progress_callback(step, current, total)
404
+ if self.config.verbose:
405
+ self._log(f"Step {current}/{total}: {step}")
406
+
407
+ self._log("=" * 60)
408
+ self._log(f"Processing: {audio_path}")
409
+ self._log("=" * 60)
410
+
411
+ # =====================================================================
412
+ # Step 1: Load and preprocess audio
413
+ # =====================================================================
414
+ update_progress("Loading audio", 1, 5)
415
+
416
+ with Timer("Audio loading"):
417
+ self._waveform, self._sample_rate = self.audio_processor.load_audio(audio_path)
418
+
419
+ duration = self.audio_processor.get_duration(self._waveform, self._sample_rate)
420
+ self._log(f"Audio loaded: {format_duration(duration)} ({duration:.2f}s)")
421
+
422
+ # Validate audio duration
423
+ max_duration_minutes = getattr(self.config, "max_duration_minutes", 60)
424
+ max_duration_seconds = max_duration_minutes * 60
425
+ if duration > max_duration_seconds:
426
+ error_msg = (
427
+ f"Audio duration ({duration:.1f}s) exceeds maximum allowed duration "
428
+ f"({max_duration_seconds}s / {max_duration_minutes} minutes). "
429
+ "Please split the audio or increase max_duration_minutes in config."
430
+ )
431
+ self.logger.error(error_msg)
432
+ raise ValueError(error_msg)
433
+
434
+ # =====================================================================
435
+ # Step 2: Speaker diarization (optionally tune hyperparameters first)
436
+ # =====================================================================
437
+ update_progress("Speaker diarization", 2, 5)
438
+
439
+ # Optional automatic tuning step
440
+ if getattr(self.config, "tune_diarization", False):
441
+ self._log("Tuning diarization hyperparameters...")
442
+ try:
443
+ tune_res = self.diarizer.auto_tune(
444
+ self._waveform, self._sample_rate, num_speakers=num_speakers
445
+ )
446
+ # store tuning result for later reporting
447
+ self._diarization_tune_result = tune_res or {}
448
+ except Exception as e:
449
+ self._diarization_tune_result = {}
450
+ self._log(f"Diarization tuning failed (continuing with defaults): {e}")
451
+
452
+ with Timer("Diarization"):
453
+ # Pass cache directory and audio id so diarizer can cache embeddings
454
+ self._diarization_segments = self.diarizer.process(
455
+ self._waveform,
456
+ self._sample_rate,
457
+ num_speakers=num_speakers or self.config.num_speakers,
458
+ cache_dir=self.config.cache_dir,
459
+ audio_id=Path(audio_path).stem,
460
+ fast_mode=self.config.fast_mode,
461
+ )
462
+
463
+ unique_speakers = set(seg.speaker_id for seg in self._diarization_segments)
464
+ self._log(
465
+ f"Found {len(unique_speakers)} speakers, {len(self._diarization_segments)} segments"
466
+ )
467
+
468
+ # =====================================================================
469
+ # Step 3: ASR transcription
470
+ # =====================================================================
471
+ update_progress("Transcribing speech", 3, 5)
472
+
473
+ with Timer("Transcription"):
474
+ self._transcript_segments = self.transcriber.transcribe_segments(
475
+ self._waveform, self._diarization_segments, self._sample_rate
476
+ )
477
+
478
+ total_words = sum(seg.word_count for seg in self._transcript_segments)
479
+ self._log(f"Transcribed {len(self._transcript_segments)} segments, ~{total_words} words")
480
+
481
+ # =====================================================================
482
+ # Step 4: BERT summarization
483
+ # =====================================================================
484
+ update_progress("Generating summary", 4, 5)
485
+
486
+ # If a manual speaker map was provided via config, apply it so summarizer sees mapped names
487
+ if getattr(self.config, "speaker_map_path", None):
488
+ try:
489
+ speaker_map = self._load_speaker_map(self.config.speaker_map_path)
490
+ self._apply_speaker_map(speaker_map)
491
+ except Exception as e:
492
+ self._log(f"Failed to load/apply speaker map: {e}")
493
+
494
+ with Timer("Summarization"):
495
+ self._summary = self.summarizer.summarize(self._transcript_segments)
496
+
497
+ self._log(f"Generated summary with {len(self._summary.key_points)} key points")
498
+
499
+ # =====================================================================
500
+ # Step 5: Generate document
501
+ # =====================================================================
502
+ update_progress("Generating document", 5, 5)
503
+
504
+ # Prepare metadata
505
+ participants = list(unique_speakers)
506
+ # If speaker map provided, map participants accordingly
507
+ if getattr(self.config, "speaker_map_path", None):
508
+ try:
509
+ speaker_map = self._load_speaker_map(self.config.speaker_map_path)
510
+ participants = [speaker_map.get(p, p) for p in participants]
511
+ except Exception:
512
+ pass
513
+
514
+ metadata = MeetingMetadata(
515
+ title=title,
516
+ date=date or datetime.now().strftime("%d %B %Y"),
517
+ time=datetime.now().strftime("%H:%M"),
518
+ location=location,
519
+ duration=format_duration(duration),
520
+ participants=participants,
521
+ )
522
+
523
+ # Generate filename if not provided
524
+ if output_filename is None:
525
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
526
+ safe_title = sanitize_filename(title)[:30]
527
+ output_filename = f"notulensi_{safe_title}_{timestamp}.docx"
528
+
529
+ with Timer("Document generation"):
530
+ doc_path = self.doc_generator.generate(
531
+ metadata=metadata,
532
+ summary=self._summary,
533
+ transcript=self._transcript_segments,
534
+ output_filename=output_filename,
535
+ )
536
+
537
+ self._log(f"Document saved: {doc_path}")
538
+
539
+ # =====================================================================
540
+ # Save intermediate results
541
+ # =====================================================================
542
+ if self.config.save_intermediate:
543
+ self._save_intermediate_results(audio_path, metadata)
544
+
545
+ # Save speaker map alongside results if provided
546
+ if getattr(self.config, "speaker_map_path", None):
547
+ try:
548
+ speaker_map = self._load_speaker_map(self.config.speaker_map_path)
549
+ save_json(
550
+ speaker_map,
551
+ Path(self.config.cache_dir) / f"{Path(audio_path).stem}_speaker_map.json",
552
+ )
553
+ except Exception:
554
+ pass
555
+
556
+ # =====================================================================
557
+ # Build result
558
+ # =====================================================================
559
+ processing_time = time.time() - start_time
560
+
561
+ result = PipelineResult(
562
+ audio_path=audio_path,
563
+ audio_duration=duration,
564
+ num_speakers=len(unique_speakers),
565
+ num_segments=len(self._transcript_segments),
566
+ total_words=total_words,
567
+ processing_time=processing_time,
568
+ segments=[seg.to_dict() for seg in self._transcript_segments],
569
+ transcript_text=self.get_transcript_text(),
570
+ summary=self._summary.to_dict(),
571
+ document_path=doc_path,
572
+ metadata={
573
+ "title": title,
574
+ "date": date or datetime.now().strftime("%Y-%m-%d"),
575
+ "location": location,
576
+ "device": self.config.device,
577
+ "asr_model": self.config.asr_model_id,
578
+ },
579
+ )
580
+
581
+ self._log("=" * 60)
582
+ self._log(f"Processing complete! Total time: {format_duration(processing_time)}")
583
+ self._log(f"Output: {doc_path}")
584
+ self._log("=" * 60)
585
+
586
+ return result
587
+
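A short sketch of wiring a progress callback into process(); the callback signature (step_name, current, total) follows the type hint above, and the audio path, title, and speaker count are illustrative.

from src.pipeline import MeetingTranscriberPipeline, PipelineConfig

def show_progress(step: str, current: int, total: int) -> None:
    # Called once per pipeline stage (loading, diarization, ASR, summary, document).
    print(f"[{current}/{total}] {step}")

pipeline = MeetingTranscriberPipeline(PipelineConfig(preset="fast"))
result = pipeline.process(
    "meeting.wav",
    title="Rapat Koordinasi",
    num_speakers=3,
    progress_callback=show_progress,
)
print(f"Done in {result.processing_time:.1f}s")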
588
+ # =========================================================================
589
+ # Individual Step Methods
590
+ # =========================================================================
591
+
592
+ def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]:
593
+ """Load and preprocess audio file"""
594
+ self._waveform, self._sample_rate = self.audio_processor.load_audio(audio_path)
595
+ return self._waveform, self._sample_rate
596
+
597
+ def run_diarization(self, num_speakers: Optional[int] = None) -> List[SpeakerSegment]:
598
+ """Run diarization on loaded audio"""
599
+ if self._waveform is None:
600
+ raise ValueError("Audio not loaded. Call load_audio() first.")
601
+
602
+ self._diarization_segments = self.diarizer.process(
603
+ self._waveform, self._sample_rate, num_speakers=num_speakers
604
+ )
605
+ return self._diarization_segments
606
+
607
+ def run_transcription(self) -> List[TranscriptSegment]:
608
+ """Run ASR on diarized segments"""
609
+ if self._diarization_segments is None:
610
+ raise ValueError("Diarization not done. Call run_diarization() first.")
611
+
612
+ self._transcript_segments = self.transcriber.transcribe_segments(
613
+ self._waveform, self._diarization_segments, self._sample_rate
614
+ )
615
+ return self._transcript_segments
616
+
617
+ def run_summarization(self) -> MeetingSummary:
618
+ """Generate summary from transcript"""
619
+ if self._transcript_segments is None:
620
+ raise ValueError("Transcription not done. Call run_transcription() first.")
621
+
622
+ self._summary = self.summarizer.summarize(self._transcript_segments)
623
+ return self._summary
624
+
625
+ def generate_document(
626
+ self, metadata: MeetingMetadata, output_filename: str = "notulensi.docx"
627
+ ) -> str:
628
+ """Generate .docx document"""
629
+ if self._transcript_segments is None or self._summary is None:
630
+ raise ValueError("Transcript and summary required.")
631
+
632
+ return self.doc_generator.generate(
633
+ metadata=metadata,
634
+ summary=self._summary,
635
+ transcript=self._transcript_segments,
636
+ output_filename=output_filename,
637
+ )
638
+
639
+ # =========================================================================
640
+ # Evaluation Methods
641
+ # =========================================================================
642
+
643
+ def evaluate(
644
+ self,
645
+ reference_transcript: Optional[str] = None,
646
+ reference_diarization: Optional[List[Tuple[str, float, float]]] = None,
647
+ reference_summary: Optional[str] = None,
648
+ sample_name: str = "sample",
649
+ condition: str = "unknown",
650
+ ) -> EvaluationResult:
651
+ """
652
+ Evaluate pipeline output against ground truth.
653
+
654
+ Args:
655
+ reference_transcript: Ground truth transcript text
656
+ reference_diarization: Ground truth diarization [(speaker, start, end), ...]
657
+ reference_summary: Ground truth summary text (for summary evaluation)
658
+ sample_name: Name for this sample
659
+ condition: Test condition name
660
+
661
+ Returns:
662
+ EvaluationResult with WER, DER, and optional summary metrics
663
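+
+ Example (illustrative; the reference file and speaker labels are hypothetical):
+ >>> ref_rttm = [("SPK_A", 0.0, 12.5), ("SPK_B", 12.5, 30.0)]
+ >>> eval_result = pipeline.evaluate(
+ ... reference_transcript=open("ref_transcript.txt").read(),
+ ... reference_diarization=ref_rttm,
+ ... sample_name="meeting_01",
+ ... condition="quiet_room",
+ ... )
+ >>> print(eval_result.wer_result.wer, eval_result.der_result.der)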
+ """
664
+ wer_result = None
665
+ der_result = None
666
+
667
+ # Calculate WER if reference transcript provided
668
+ if reference_transcript and self._transcript_segments:
669
+ hypothesis = self.get_transcript_text()
670
+ wer_result = self.evaluator.calculate_wer(reference_transcript, hypothesis)
671
+ self._log(f"WER: {wer_result.wer:.4f} ({wer_result.wer*100:.2f}%)")
672
+
673
+ # Calculate DER if reference diarization provided
674
+ if reference_diarization and self._diarization_segments:
675
+ hypothesis_diarization = [
676
+ (seg.speaker_id, seg.start, seg.end) for seg in self._diarization_segments
677
+ ]
678
+ der_result = self.evaluator.calculate_der(reference_diarization, hypothesis_diarization)
679
+ self._log(f"DER: {der_result.der:.4f} ({der_result.der*100:.2f}%)")
680
+
681
+ # If reference diarization not provided but reference transcript contains speaker labels,
682
+ # attempt to build a reference diarization by aligning the labeled transcript to the
683
+ # pipeline's transcript segments. This often improves DER accuracy when GT RTTM is missing.
684
+ if not reference_diarization and reference_transcript and self._diarization_segments:
685
+ # Heuristic detection: presence of 'Name:' lines
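+ # e.g. reference lines of the form "Budi: Selamat pagi semua" (speaker name illustrative)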
686
+ if ":" in reference_transcript and any(
687
+ line.strip().endswith(":") or ":" in line
688
+ for line in reference_transcript.splitlines()[:20]
689
+ ):
690
+ try:
691
+ from src.utils import (
692
+ align_reference_to_segments,
693
+ parse_speaker_labeled_text,
694
+ )
695
+
696
+ utterances = parse_speaker_labeled_text(reference_transcript)
697
+ if utterances:
698
+ hyp_segs = self._transcript_segments or []
699
+ # Build reference diarization from alignment
700
+ derived_ref = align_reference_to_segments(utterances, hyp_segs)
701
+ if derived_ref:
702
+ hypothesis_diarization = [
703
+ (seg.speaker_id, seg.start, seg.end)
704
+ for seg in self._diarization_segments
705
+ ]
706
+ der_result = self.evaluator.calculate_der(
707
+ derived_ref, hypothesis_diarization
708
+ )
709
+ self._log(
710
+ f"Derived RTTM used for DER (from speaker-labeled transcript). DER: {der_result.der:.4f} ({der_result.der*100:.2f}%)"
711
+ )
712
+ except Exception as e:
713
+ self._log(f"Auto-alignment for RTTM failed: {e}")
714
+ pass
715
+
716
+ # Summary evaluation (if reference_summary provided)
717
+ summary_result = None
718
+ if reference_summary and self._summary:
719
+ try:
720
+ # Prefer overview text if available, otherwise join key points
721
+ hyp_summary = getattr(self._summary, "overview", "") or " ".join(getattr(self._summary, "key_points", []))
722
+ summary_result = self.evaluator.calculate_summary_metrics(reference_summary, hyp_summary)
723
+ self._log(
724
+ f"Summary metrics - ROUGE1_F: {summary_result.rouge.get('rouge1_f', 0.0):.4f}, BERTScore_F1: {summary_result.bertscore.get('bertscore_f1', 0.0):.4f}"
725
+ )
726
+ except Exception as e:
727
+ self._log(f"Summary evaluation failed: {e}")
728
+
729
+ # Build evaluation metadata: include relevant hyperparameters and tuning info
730
+ metadata: Dict[str, Any] = {}
731
+
732
+ try:
733
+ # ASR config
734
+ asr_cfg = getattr(self.transcriber, "config", None)
735
+ if asr_cfg is not None:
736
+ metadata["asr_backend"] = getattr(asr_cfg, "backend", None)
737
+ metadata["asr_model_id"] = getattr(asr_cfg, "model_id", None)
738
+ metadata["asr_language"] = getattr(asr_cfg, "language", None)
739
+ metadata["asr_use_full_audio_for_segments"] = getattr(
740
+ asr_cfg, "use_full_audio_for_segments", None
741
+ )
742
+ metadata["asr_whisperx_compute_type"] = getattr(asr_cfg, "whisperx_compute_type", None)
743
+ metadata["asr_whisperx_vad_filter"] = getattr(asr_cfg, "whisperx_vad_filter", None)
744
+ metadata["asr_parallel_workers"] = getattr(asr_cfg, "parallel_workers", None)
745
+ except Exception:
746
+ pass
747
+
748
+ try:
749
+ dz_cfg = getattr(self.diarizer, "config", None)
750
+ if dz_cfg is not None:
751
+ # pick a sensible subset of diarizer params
752
+ metadata["diarizer_vad_threshold"] = getattr(dz_cfg, "vad_threshold", None)
753
+ metadata["diarizer_min_speech_duration"] = getattr(dz_cfg, "min_speech_duration", None)
754
+ metadata["diarizer_segment_window"] = getattr(dz_cfg, "segment_window", None)
755
+ metadata["diarizer_segment_hop"] = getattr(dz_cfg, "segment_hop", None)
756
+ metadata["diarizer_clustering_method"] = getattr(dz_cfg, "clustering_method", None)
757
+ metadata["diarizer_clustering_threshold"] = getattr(dz_cfg, "clustering_threshold", None)
758
+ metadata["diarizer_min_cluster_size"] = getattr(dz_cfg, "min_cluster_size", None)
759
+ metadata["diarizer_iterative_merge_threshold"] = getattr(
760
+ dz_cfg, "iterative_merge_threshold", None
761
+ )
762
+ metadata["diarizer_target_num_speakers"] = getattr(dz_cfg, "target_num_speakers", None)
763
+ metadata["diarizer_target_force_threshold"] = getattr(dz_cfg, "target_force_threshold", None)
764
+ metadata["diarizer_merge_gap_threshold"] = getattr(dz_cfg, "merge_gap_threshold", None)
765
+ metadata["diarizer_use_fast_embedding"] = getattr(dz_cfg, "use_fast_embedding", None)
766
+ metadata["diarizer_embedding_model_id"] = getattr(dz_cfg, "embedding_model_id", None)
767
+ except Exception:
768
+ pass
769
+
770
+ metadata["tune_diarization_requested"] = bool(getattr(self.config, "tune_diarization", False))
771
+ metadata["diarization_tune_result"] = self._diarization_tune_result or {}
772
+
773
+ # Reference information
774
+ metadata["reference_transcript_provided"] = bool(reference_transcript)
775
+ metadata["reference_diarization_provided"] = bool(reference_diarization)
776
+ metadata["used_derived_rttm"] = bool("derived_ref" in locals() and derived_ref)
777
+
778
+ # Optional diarization method comparison (agglomerative vs spectral)
779
+ if getattr(self.config, "diarization_compare", False) and reference_diarization:
780
+ try:
781
+ # Recompute speech regions/windows/embeddings for re-clustering
782
+ speech_regions = self.diarizer._detect_speech(self._waveform, self._sample_rate)
783
+ windows = self.diarizer._create_windows(speech_regions)
784
+ embeddings = self.diarizer._extract_embeddings(
785
+ self._waveform, windows, self._sample_rate, cache_dir=self.config.cache_dir, audio_id=Path(sample_name).stem
786
+ )
787
+
788
+ comp_results = {}
789
+ for method in ("agglomerative", "spectral"):
790
+ try:
791
+ labels = self.diarizer._cluster_embeddings(embeddings, num_speakers=None, method_override=method)
792
+ hyp_segments = self.diarizer._create_segments(windows, labels, embeddings)
793
+ hyp_rttm = [(s.speaker_id, s.start, s.end) for s in hyp_segments]
794
+ der_res = self.evaluator.calculate_der(reference_diarization, hyp_rttm)
795
+ comp_results[method] = der_res.to_dict()
796
+ except Exception as e:
797
+ comp_results[method] = {"error": str(e)}
798
+
799
+ metadata["diarization_comparison"] = comp_results
800
+ except Exception as e:
801
+ self._log(f"Diarization comparison failed: {e}")
802
+
803
+ return EvaluationResult(
804
+ sample_name=sample_name,
805
+ condition=condition,
806
+ wer_result=wer_result,
807
+ der_result=der_result,
808
+ summary_result=summary_result,
809
+ metadata=metadata,
810
+ )
811
+
812
+ # =========================================================================
813
+ # Utility Methods
814
+ # =========================================================================
815
+
816
+ def get_transcript_text(self) -> str:
817
+ """Get full transcript as plain text"""
818
+ if self._transcript_segments is None:
819
+ return ""
820
+ return " ".join(seg.text for seg in self._transcript_segments if seg.text)
821
+
822
+ def get_formatted_transcript(self) -> str:
823
+ """Get transcript with speaker labels and timestamps"""
824
+ if self._transcript_segments is None:
825
+ return ""
826
+
827
+ lines = []
828
+ for seg in self._transcript_segments:
829
+ timestamp = format_duration(seg.start)
830
+ lines.append(f"[{timestamp}] {seg.speaker_id}: {seg.text}")
831
+
832
+ return "\n".join(lines)
833
+
834
+ def get_speaker_stats(self) -> Dict[str, Dict[str, Any]]:
835
+ """Get statistics per speaker"""
836
+ if self._transcript_segments is None:
837
+ return {}
838
+
839
+ stats = {}
840
+ for seg in self._transcript_segments:
841
+ if seg.speaker_id not in stats:
842
+ stats[seg.speaker_id] = {"word_count": 0, "duration": 0.0, "segment_count": 0}
843
+
844
+ stats[seg.speaker_id]["word_count"] += seg.word_count
845
+ stats[seg.speaker_id]["duration"] += seg.duration
846
+ stats[seg.speaker_id]["segment_count"] += 1
847
+
848
+ return stats
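+ # Illustrative return shape (values are hypothetical):
+ # {"SPK_0": {"word_count": 812, "duration": 421.7, "segment_count": 35},
+ #  "SPK_1": {"word_count": 540, "duration": 298.2, "segment_count": 28}}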
849
+
850
+ def clear_state(self):
851
+ """Clear internal state for fresh processing"""
852
+ self._waveform = None
853
+ self._sample_rate = None
854
+ self._diarization_segments = None
855
+ self._transcript_segments = None
856
+ self._summary = None
857
+
858
+ def _log(self, message: str):
859
+ """Log message"""
860
+ if self.config.verbose:
861
+ print(f"[Pipeline] {message}")
862
+ self.logger.info(message)
863
+
864
+ def _load_speaker_map(self, path: str) -> dict:
865
+ """Load a speaker map from JSON or YAML file."""
866
+ p = Path(path)
867
+ if not p.exists():
868
+ raise FileNotFoundError(f"Speaker map file not found: {path}")
869
+ try:
870
+ import json
871
+
872
+ with open(p, "r", encoding="utf-8") as fh:
873
+ data = json.load(fh)
874
+ if not isinstance(data, dict):
875
+ raise ValueError("Speaker map must be a JSON object mapping labels to names")
876
+ return data
877
+ except Exception:
878
+ try:
879
+ import yaml
880
+
881
+ with open(p, "r", encoding="utf-8") as fh:
882
+ data = yaml.safe_load(fh)
883
+ if not isinstance(data, dict):
884
+ raise ValueError("Speaker map must be a mapping in YAML/JSON format")
885
+ return data
886
+ except Exception as e:
887
+ raise ValueError(f"Failed to parse speaker map: {e}")
888
+
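+ # Example speaker map file (JSON); labels and names are illustrative:
+ # {"SPK_0": "Budi", "SPK_1": "Sari", "SPK_2": "Andi"}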
889
+ def _apply_speaker_map(self, mapping: dict):
890
+ """Apply speaker mapping to transcript segments and summary action items.
891
+
892
+ This replaces `seg.speaker_id` with the provided name and stores the original id in
893
+ `seg.metadata['original_speaker_id']` for traceability.
894
+ """
895
+ if not mapping:
896
+ return
897
+
898
+ # Update transcript segments if they exist
899
+ if getattr(self, "_transcript_segments", None):
900
+ for seg in self._transcript_segments:
901
+ orig = seg.speaker_id
902
+ mapped = mapping.get(orig)
903
+ if mapped and mapped != orig:
904
+ seg.metadata["original_speaker_id"] = orig
905
+ seg.speaker_id = mapped
906
+
907
+ # Update action item owners in summary
908
+ try:
909
+ for ai in self._summary.action_items or []:
910
+ owner = ai.get("owner")
911
+ if owner and owner in mapping:
912
+ ai["owner"] = mapping[owner]
913
+ except Exception:
914
+ pass
915
+
916
+ # Finally update diarization segments as well (if present)
917
+ try:
918
+ self._log(f"Applying speaker mapping to diarization segments: {mapping}")
919
+ for dseg in self._diarization_segments or []:
920
+ orig = dseg.speaker_id
921
+ mapped = mapping.get(orig)
922
+ self._log(f"Segment {orig} -> mapped: {mapped}")
923
+ if mapped and mapped != orig:
924
+ dseg.metadata["original_speaker_id"] = orig
925
+ dseg.speaker_id = mapped
926
+ self._log(
927
+ f"Post-map speaker ids: {[d.speaker_id for d in (self._diarization_segments or [])]}"
928
+ )
929
+ except Exception as e:
930
+ self._log(f"Error applying speaker map to diarization segments: {e}")
931
+ pass
932
+
933
+ def _save_intermediate_results(self, audio_path: str, metadata: MeetingMetadata):
934
+ """Save intermediate results to JSON"""
935
+ base_name = Path(audio_path).stem
936
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
937
+
938
+ results = {
939
+ "audio_path": audio_path,
940
+ "timestamp": timestamp,
941
+ "metadata": {
942
+ "title": metadata.title,
943
+ "date": metadata.date,
944
+ "duration": metadata.duration,
945
+ },
946
+ "config": {
947
+ "sample_rate": self.config.sample_rate,
948
+ "asr_model": self.config.asr_model_id,
949
+ "device": self.config.device,
950
+ },
951
+ "diarization": [
952
+ {
953
+ "speaker_id": seg.speaker_id,
954
+ "start": seg.start,
955
+ "end": seg.end,
956
+ "is_overlap": seg.is_overlap,
957
+ }
958
+ for seg in (self._diarization_segments or [])
959
+ ],
960
+ "transcript": [seg.to_dict() for seg in (self._transcript_segments or [])],
961
+ "summary": self._summary.to_dict() if self._summary else None,
962
+ }
963
+
964
+ output_path = Path(self.config.cache_dir) / f"{base_name}_{timestamp}_results.json"
965
+ save_json(results, output_path)
966
+
967
+ self._log(f"Intermediate results saved: {output_path}")
968
+
969
+ # ------------------------------------------------------------------
970
+ # Convenience methods for interactive flows (UI, Streamlit)
971
+ # ------------------------------------------------------------------
972
+ def run_diarization(self, audio_path: str) -> dict:
973
+ """Run loading + diarization steps and return a dict with summary info.
974
+
975
+ Returns: {"audio_duration": float, "num_windows": int, "num_speech_regions": int, "unique_speakers": [..], "segments": [..]}
976
+ """
977
+ # Load audio
978
+ self._audio_path = audio_path  # remember the source path for later steps (e.g. continue_from_diarization)
+ self._waveform, self._sample_rate = self.audio_processor.load_audio(audio_path)
979
+ duration = self.audio_processor.get_duration(self._waveform, self._sample_rate)
980
+
981
+ # Run diarization
982
+ self._diarization_segments = self.diarizer.process(
983
+ self._waveform,
984
+ self._sample_rate,
985
+ num_speakers=None,
986
+ cache_dir=self.config.cache_dir,
987
+ audio_id=Path(audio_path).stem,
988
+ fast_mode=self.config.fast_mode,
989
+ )
990
+
991
+ unique_speakers = sorted(list(set(seg.speaker_id for seg in self._diarization_segments)))
992
+
993
+ return {
994
+ "audio_duration": duration,
995
+ "num_segments": len(self._diarization_segments),
996
+ "unique_speakers": unique_speakers,
997
+ "segments": [
998
+ {"speaker_id": s.speaker_id, "start": s.start, "end": s.end}
999
+ for s in self._diarization_segments
1000
+ ],
1001
+ }
1002
+
1003
+ def apply_speaker_map(
1004
+ self, mapping: dict, save_to_cache: bool = False, audio_id: Optional[str] = None
1005
+ ):
1006
+ """Apply a manual speaker mapping to internal state and optionally save the map to cache.
1007
+
1008
+ mapping: dict mapping original speaker id -> desired display name
1009
+ """
1010
+ self._apply_speaker_map(mapping)
1011
+ if save_to_cache and audio_id:
1012
+ try:
1013
+ save_json(mapping, Path(self.config.cache_dir) / f"{audio_id}_speaker_map.json")
1014
+ except Exception:
1015
+ pass
1016
+
1017
+ def continue_from_diarization(
1018
+ self,
1019
+ title: str = "Notulensi Rapat",
1020
+ date: Optional[str] = None,
1021
+ location: str = "",
1022
+ output_filename: Optional[str] = None,
1023
+ progress_callback: Optional[Callable[[str, int, int], None]] = None,
1024
+ ) -> PipelineResult:
1025
+ """Continue processing from the current _waveform and _diarization_segments.
1026
+
1027
+ Runs ASR, summarization, and document generation using existing in-memory diarization.
1028
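+
+ Example (illustrative interactive flow; the mapping values are hypothetical):
+ >>> pipeline.run_diarization("meeting.wav")
+ >>> pipeline.apply_speaker_map({"SPK_0": "Budi", "SPK_1": "Sari"})
+ >>> result = pipeline.continue_from_diarization(title="Rapat Mingguan")
+ >>> print(result.document_path)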
+ """
1029
+ if (
1030
+ getattr(self, "_waveform", None) is None
1031
+ or getattr(self, "_diarization_segments", None) is None
1032
+ ):
1033
+ raise RuntimeError(
1034
+ "Diarization state not found. Run run_diarization(audio_path) first."
1035
+ )
1036
+
1037
+ update_progress = lambda step, cur, total: (
1038
+ progress_callback(step, cur, total) if progress_callback else None
1039
+ )
1040
+
1041
+ # Step 3: ASR
1042
+ update_progress("Transcribing speech", 3, 5)
1043
+ with Timer("Transcription"):
1044
+ self._transcript_segments = self.transcriber.transcribe_segments(
1045
+ self._waveform, self._diarization_segments, self._sample_rate
1046
+ )
1047
+
1048
+ total_words = sum(seg.word_count for seg in self._transcript_segments)
1049
+ self._log(f"Transcribed {len(self._transcript_segments)} segments, ~{total_words} words")
1050
+
1051
+ # Apply speaker map if configured
1052
+ if getattr(self.config, "speaker_map_path", None):
1053
+ try:
1054
+ speaker_map = self._load_speaker_map(self.config.speaker_map_path)
1055
+ self._apply_speaker_map(speaker_map)
1056
+ except Exception as e:
1057
+ self._log(f"Failed to load/apply speaker map: {e}")
1058
+
1059
+ # Step 4: Summarization
1060
+ update_progress("Generating summary", 4, 5)
1061
+ with Timer("Summarization"):
1062
+ self._summary = self.summarizer.summarize(self._transcript_segments)
1063
+
1064
+ self._log(f"Generated summary with {len(self._summary.key_points)} key points")
1065
+
1066
+ # Step 5: Document generation
1067
+ update_progress("Generating document", 5, 5)
1068
+
1069
+ participants = list(set(seg.speaker_id for seg in self._diarization_segments))
1070
+ if getattr(self.config, "speaker_map_path", None):
1071
+ try:
1072
+ speaker_map = self._load_speaker_map(self.config.speaker_map_path)
1073
+ participants = [speaker_map.get(p, p) for p in participants]
1074
+ except Exception:
1075
+ pass
1076
+
1077
+ metadata = MeetingMetadata(
1078
+ title=title,
1079
+ date=date or datetime.now().strftime("%d %B %Y"),
1080
+ time=datetime.now().strftime("%H:%M"),
1081
+ location=location,
1082
+ duration=format_duration(
1083
+ self.audio_processor.get_duration(self._waveform, self._sample_rate)
1084
+ ),
1085
+ participants=participants,
1086
+ )
1087
+
1088
+ if output_filename is None:
1089
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
1090
+ safe_title = sanitize_filename(title)[:30]
1091
+ output_filename = f"notulensi_{safe_title}_{timestamp}.docx"
1092
+
1093
+ with Timer("Document generation"):
1094
+ doc_path = self.doc_generator.generate(
1095
+ metadata=metadata,
1096
+ summary=self._summary,
1097
+ transcript=self._transcript_segments,
1098
+ output_filename=output_filename,
1099
+ )
1100
+
1101
+ self._log(f"Document saved: {doc_path}")
1102
+
1103
+ # Save intermediate results
1104
+ if self.config.save_intermediate:
1105
+ self._save_intermediate_results(getattr(self, "_audio_path", output_filename), metadata)
1106
+
1107
+ processing_time = 0.0
1108
+ result = PipelineResult(
1109
+ audio_path=getattr(self, "_audio_path", output_filename),
1110
+ audio_duration=self.audio_processor.get_duration(self._waveform, self._sample_rate),
1111
+ num_speakers=len(set(seg.speaker_id for seg in self._diarization_segments)),
1112
+ num_segments=len(self._transcript_segments),
1113
+ total_words=total_words,
1114
+ processing_time=processing_time,
1115
+ segments=[seg.to_dict() for seg in (self._transcript_segments or [])],
1116
+ transcript_text="\n".join([s.text for s in (self._transcript_segments or [])]),
1117
+ summary=self._summary.to_dict() if self._summary else {},
1118
+ document_path=str(doc_path),
1119
+ )
1120
+
1121
+ return result
src/speaker.py ADDED
@@ -0,0 +1,69 @@
1
+ """Speaker classifier scaffold for multi-task training and evaluation.
2
+
3
+ This module provides a small PyTorch `SpeakerClassifier` that maps embeddings
4
+ (or pooled encoder outputs) to speaker logits, plus helpers to build speaker
5
+ mappings from manifests.
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Dict, List
11
+
12
+ try:
13
+ import torch
14
+ import torch.nn as nn
15
+ except Exception:
16
+ torch = None
17
+ nn = None
18
+
19
+
20
+ class SpeakerClassifier:
21
+ """A light-weight wrapper that exposes an API-compatible classifier.
22
+
23
+ If PyTorch is available, `SpeakerClassifier.model` is a `nn.Module`.
24
+ Otherwise this is a placeholder to keep the dependency optional in tests.
25
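+
+ Example (sketch; the 192-dim input is only an assumed ECAPA-style embedding size):
+ >>> clf = SpeakerClassifier(input_dim=192, num_speakers=4)
+ >>> logits = clf.forward(torch.randn(8, 192))  # shape (8, 4); requires PyTorch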
+ """
26
+
27
+ def __init__(self, input_dim: int, num_speakers: int, dropout: float = 0.1):
28
+ self.input_dim = input_dim
29
+ self.num_speakers = num_speakers
30
+ self.dropout = dropout
31
+ if torch is not None and nn is not None:
32
+ self.model = nn.Sequential(
33
+ nn.Dropout(p=dropout),
34
+ nn.Linear(input_dim, num_speakers),
35
+ )
36
+ else:
37
+ self.model = None
38
+
39
+ def forward(self, x):
40
+ if self.model is None:
41
+ raise RuntimeError("PyTorch not available for SpeakerClassifier")
42
+ return self.model(x)
43
+
44
+
45
+ def build_speaker_map(manifest_paths: List[str]) -> Dict[str, int]:
46
+ """Read JSONL manifest(s) and return a speaker->id mapping.
47
+
48
+ The manifest format: each line is JSON with optional "speaker" key.
49
+ Labels are returned in deterministic sorted order.
50
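+
+ Example (illustrative manifest line and resulting map; file names are hypothetical):
+ {"audio": "clips/0001.wav", "text": "selamat pagi", "speaker": "SPK_A"}
+ >>> build_speaker_map(["train.jsonl"])
+ {'SPK_A': 0, 'SPK_B': 1}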
+ """
51
+ speakers = set()
52
+ for p in manifest_paths:
53
+ pth = Path(p)
54
+ if not pth.exists():
55
+ continue
56
+ with open(pth, "r", encoding="utf-8") as fh:
57
+ for line in fh:
58
+ line = line.strip()
59
+ if not line:
60
+ continue
61
+ try:
62
+ obj = json.loads(line)
63
+ except Exception:
64
+ continue
65
+ spk = obj.get("speaker")
66
+ if spk is not None:
67
+ speakers.add(str(spk))
68
+ sorted_spks = sorted(speakers)
69
+ return {s: i for i, s in enumerate(sorted_spks)}
src/summarizer.py ADDED
@@ -0,0 +1,1783 @@
1
+ """
2
+ BERT Extractive Summarization Module
3
+ ====================================
4
+ Implements extractive summarization using IndoBERT/mBERT for meeting minutes.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import re
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+
16
+ def _collapse_repeated_phrases_global(text: str, max_ngram: int = 6, min_repeats: int = 2) -> str:
17
+ """Module-level helper to collapse repeated n-gram phrases.
18
+
19
+ Iteratively collapses repeated adjacent n-gram phrases into a single occurrence.
20
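+
+ Example:
+ >>> _collapse_repeated_phrases_global("jadi contohnya jadi contohnya jadi contohnya")
+ 'jadi contohnya'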
+ """
21
+ if not text or min_repeats < 2:
22
+ return text
23
+ pattern = re.compile(r"(\b(?:\w+\s+){0,%d}\w+\b)(?:\s+\1){%d,}" % (max_ngram - 1, min_repeats - 1), flags=re.IGNORECASE)
24
+ prev = None
25
+ out = text
26
+ while prev != out:
27
+ prev = out
28
+ out = pattern.sub(r"\1", out)
29
+ return out
30
+
31
+ from src.transcriber import TranscriptSegment
32
+
33
+
34
+ @dataclass
35
+ class SummarizationConfig:
36
+ """Configuration for summarization"""
37
+
38
+ # Method: 'extractive' (BERT embeddings) or 'abstractive' (seq2seq model)
39
+ method: str = "extractive"
40
+
41
+ # Models
42
+ # Use a cached/available model for reliability in offline environments
43
+ sentence_model_id: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
44
+ abstractive_model_id: str = "google/mt5-base"
45
+
46
+ # Extractive settings (increase to capture more key points)
47
+ num_sentences: int = 7
48
+ min_sentence_length: int = 6
49
+ max_sentence_length: int = 300
50
+
51
+ # Abstractive settings
52
+ max_input_chars: int = 1000
53
+ max_summary_length: int = 128
54
+ min_summary_length: int = 30
55
+
56
+ # Light abstractive refinement step (run on condensed extractive overview)
57
+ do_abstractive_refinement: bool = True
58
+ abstractive_refine_max_len: int = 80
59
+
60
+ # Generate a comprehensive executive overview (long, covering entire meeting)
61
+ comprehensive_overview: bool = True
62
+ comprehensive_max_length: int = 512
63
+
64
+ # Post-processing options
65
+ polish_overview: bool = True
66
+ semantic_dedup_threshold: float = 0.75
67
+
68
+ # Scoring weights
69
+ position_weight: float = 0.15
70
+ length_weight: float = 0.10
71
+ similarity_weight: float = 0.75
72
+
73
+ # Keywords for detection
74
+ decision_keywords: List[str] = field(
75
+ default_factory=lambda: [
76
+ "diputuskan",
77
+ "disepakati",
78
+ "kesimpulan",
79
+ "keputusan",
80
+ "jadi",
81
+ "maka",
82
+ "sepakat",
83
+ "setuju",
84
+ "final",
85
+ "kesepakatan",
86
+ "disimpulkan",
87
+ "ditetapkan",
88
+ "disetujui",
89
+ "putus",
90
+ ]
91
+ )
92
+
93
+ action_keywords: List[str] = field(
94
+ default_factory=lambda: [
95
+ "akan",
96
+ "harus",
97
+ "perlu",
98
+ "tolong",
99
+ "mohon",
100
+ "harap",
101
+ "deadline",
102
+ "target",
103
+ "tugas",
104
+ "tanggung jawab",
105
+ "action item",
106
+ "follow up",
107
+ "tindak lanjut",
108
+ "dikerjakan",
109
+ "selesaikan",
110
+ "lakukan",
111
+ "siapkan",
112
+ "minggu depan",
113
+ "besok",
114
+ "segera",
115
+ "bikin",
116
+ "buat",
117
+ ]
118
+ )
119
+
120
+ # Device
121
+ device: str = "cpu"
122
+
123
+
124
+ @dataclass
125
+ class MeetingSummary:
126
+ """Structured meeting summary"""
127
+
128
+ overview: str
129
+ key_points: List[str]
130
+ decisions: List[str]
131
+ action_items: List[Dict[str, str]]
132
+ topics: List[str] = field(default_factory=list)
133
+
134
+ def to_dict(self) -> Dict[str, Any]:
135
+ """Convert to dictionary"""
136
+ return {
137
+ "overview": self.overview,
138
+ "key_points": self.key_points,
139
+ "decisions": self.decisions,
140
+ "action_items": self.action_items,
141
+ "topics": self.topics,
142
+ "keywords": getattr(self, "keywords", []),
143
+ }
144
+
145
+ def __str__(self) -> str:
146
+ """String representation"""
147
+ lines = []
148
+ lines.append("=== RINGKASAN RAPAT ===\n")
149
+ lines.append(f"Overview:\n{self.overview}\n")
150
+
151
+ if self.key_points:
152
+ lines.append("Poin-Poin Penting:")
153
+ for i, point in enumerate(self.key_points, 1):
154
+ lines.append(f" {i}. {point}")
155
+ lines.append("")
156
+
157
+ if self.decisions:
158
+ lines.append("Keputusan:")
159
+ for i, decision in enumerate(self.decisions, 1):
160
+ lines.append(f" {i}. {decision}")
161
+ lines.append("")
162
+
163
+ if self.action_items:
164
+ lines.append("Action Items:")
165
+ for i, item in enumerate(self.action_items, 1):
166
+ owner = item.get("owner", "TBD")
167
+ task = item.get("task", "")
168
+ due = item.get("due", "")
169
+ if due:
170
+ lines.append(f" {i}. [{owner}] {task} (Due: {due})")
171
+ else:
172
+ lines.append(f" {i}. [{owner}] {task}")
173
+
174
+ if self.topics:
175
+ lines.append("")
176
+ lines.append("Topik:")
177
+ lines.append(", ".join(self.topics))
178
+
179
+ return "\n".join(lines)
180
+
181
+ def to_json(self) -> str:
182
+ """Return a JSON string for machine-readable outputs."""
183
+ import json
184
+
185
+ return json.dumps(self.to_dict(), ensure_ascii=False, indent=2)
186
+
187
+ def to_yaml(self) -> str:
188
+ """Return a YAML string (requires PyYAML)."""
189
+ try:
190
+ import yaml
191
+
192
+ return yaml.safe_dump(self.to_dict(), allow_unicode=True)
193
+ except Exception:
194
+ # Fallback to JSON if YAML not available
195
+ return self.to_json()
196
+
197
+
198
+ class AbstractiveSummarizer:
199
+ """Abstractive summarizer using HuggingFace transformers pipeline (mt5/mbart/etc)."""
200
+
201
+ def __init__(self, config: Optional[SummarizationConfig] = None):
202
+ self.config = config or SummarizationConfig()
203
+ self._pipeline = None
204
+
205
+ def _load_model(self):
206
+ if self._pipeline is None:
207
+ try:
208
+ from transformers import pipeline
209
+
210
+ device = 0 if self.config.device.startswith("cuda") else -1
211
+ print(f"[Summarizer] Loading abstractive model: {self.config.abstractive_model_id}")
212
+ self._pipeline = pipeline(
213
+ "summarization",
214
+ model=self.config.abstractive_model_id,
215
+ tokenizer=self.config.abstractive_model_id,
216
+ device=device,
217
+ truncation=True,
218
+ )
219
+ print("[Summarizer] Abstractive model loaded successfully")
220
+ except Exception as e:
221
+ print(f"[Summarizer] Warning: abstractive model load failed: {e}")
222
+ self._pipeline = None
223
+
224
+ def _chunk_text(self, text: str) -> List[str]:
225
+ max_chars = int(self.config.max_input_chars)
226
+ if len(text) <= max_chars:
227
+ return [text]
228
+ chunks = []
229
+ start = 0
230
+ while start < len(text):
231
+ end = min(len(text), start + max_chars)
232
+ # try to cut at sentence boundary
233
+ cut = text.rfind(".", start, end)
234
+ if cut <= start:
235
+ cut = end
236
+ chunk = text[start:cut].strip()
237
+ if chunk:
238
+ # prevent repeating identical chunks
239
+ chunk = self._collapse_repeated_phrases(chunk)
240
+ chunks.append(chunk)
241
+ start = cut
242
+ return chunks
243
+
244
+ def _clean_abstractive_output(self, overview: str, full_text: str) -> Tuple[str, List[str]]:
245
+ """Clean artifacts from abstractive model output and produce fallback key points.
246
+
247
+ Returns (overview_clean, key_points)
248
+ """
249
+ overview_clean = self._clean_abstractive_text(overview)
250
+
251
+ # If abstract output is still noisy (placeholders remain or too few alpha tokens), fallback to extractive
252
+ if "<extra_id" in overview or len(re.findall(r"[a-zA-Z]{2,}", overview_clean)) < 10 or re.search(r"\b(\w+)(?:\s+\1){2,}", overview_clean.lower()):
253
+ sentences = BERTSummarizer(self.config)._split_sentences(full_text)
254
+ key_points = [s for s in sentences[: self.config.num_sentences]]
255
+ overview_clean = " ".join(key_points[:3])
256
+ return overview_clean, key_points
257
+
258
+ # Otherwise make sure key points are meaningful and deduplicated
259
+ parts = [s.strip() for s in re.split(r"\.|!|\?", overview_clean) if s.strip()]
260
+ seen_kp = set()
261
+ key_points: List[str] = []
262
+ for p in parts:
263
+ p_clean = re.sub(r"[^\w\s]", "", p) if p else p
264
+ p_clean = re.sub(r"\s+", " ", p_clean).strip()
265
+ if len(p_clean.split()) < 3:
266
+ continue
267
+ low = p_clean.lower()
268
+ if low in seen_kp:
269
+ continue
270
+ seen_kp.add(low)
271
+ key_points.append(p_clean)
272
+ if len(key_points) >= self.config.num_sentences:
273
+ break
274
+
275
+ return overview_clean, key_points
276
+
277
+ def _clean_abstractive_text(self, text: str) -> str:
278
+ """Lightweight cleaning of abstractive text outputs (remove placeholders, collapse punctuation).
279
+
280
+ Kept as a separate method for unit testing/backwards compatibility with older tests.
281
+ Also collapses repeated trivial tokens and reduces punctuation runs.
282
+ """
283
+ t = re.sub(r"<extra_id_\d+>", "", text)
284
+ t = re.sub(r"\)\s*<extra_id_\d+>", "", t)
285
+ # collapse repeated short filler words sequences e.g. "Jadi contohnya Jadi contohnya ..."
286
+ t = self._collapse_repeated_phrases(t)
287
+ t = re.sub(r"\s*[\.]{2,}\s*", ". ", t)
288
+ t = re.sub(r"[!?]{2,}", ".", t)
289
+ t = re.sub(r"\s+", " ", t).strip()
290
+ # Remove leading/trailing hyphens and stray punctuation
291
+ t = re.sub(r"^[-\s]+|[-\s]+$", "", t)
292
+ if not re.search(r"[.!?]$", t):
293
+ t = t + "."
294
+ return t
295
+
296
+ def _generate_keywords(self, text: str, top_k: int = 8) -> List[str]:
297
+ """Generate simple keywords by frequency (fallback)."""
298
+ toks = re.findall(r"\b[a-zA-Z]{4,}\b", text.lower())
299
+ freq = {}
300
+ stop = {"yang","dan","ini","itu","untuk","dengan","juga","sudah","ada","kita","saya","kamu"}
301
+ for w in toks:
302
+ if w in stop:
303
+ continue
304
+ freq[w] = freq.get(w, 0) + 1
305
+ sorted_words = sorted(freq.items(), key=lambda x: x[1], reverse=True)
306
+ return [w for w, _ in sorted_words[:top_k]]
307
+
308
+ def _collapse_repeated_phrases(self, text: str, max_ngram: int = 6, min_repeats: int = 2) -> str:
309
+ """Delegates to module-level collapse helper"""
310
+ return _collapse_repeated_phrases_global(text, max_ngram=max_ngram, min_repeats=min_repeats)
311
+
+
320
+ def _parse_structured_output(self, raw: str, defaults: Dict[str, Any]) -> Tuple[str, List[str]]:
321
+ """Try to parse YAML/JSON or simple structured text into (overview, keywords).
322
+
323
+ If parsing fails, return (cleaned_raw, fallback_keywords)
324
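+
+ Example of an input that parses as YAML (illustrative):
+ overview: Rapat membahas anggaran dan jadwal rilis.
+ keywords: [anggaran, jadwal, rilis]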
+ """
325
+ cleaned = raw.strip()
326
+
327
+ # Try YAML first (if available)
328
+ try:
329
+ import yaml
330
+
331
+ parsed = yaml.safe_load(cleaned)
332
+ if isinstance(parsed, dict):
333
+ ov = parsed.get("overview", "")
334
+ kws = parsed.get("keywords", None)
335
+ if kws is None:
336
+ kws = self._generate_keywords(ov or " ".join(defaults.get("key_points", [])))
337
+ return (ov.strip() if isinstance(ov, str) else "", kws)
338
+ except Exception:
339
+ pass
340
+
341
+ # Try JSON
342
+ try:
343
+ import json
344
+
345
+ parsed = json.loads(cleaned)
346
+ if isinstance(parsed, dict):
347
+ ov = parsed.get("overview", "")
348
+ kws = parsed.get("keywords", None)
349
+ if kws is None:
350
+ kws = self._generate_keywords(ov or " ".join(defaults.get("key_points", [])))
351
+ return (ov.strip() if isinstance(ov, str) else "", kws)
352
+ except Exception:
353
+ pass
354
+
355
+ # Simple heuristic: look for header 'overview:' or 'Ringkasan:' in text
356
+ m = re.search(r"(?im)^(overview|ringkasan)\s*:\s*(.*)$", cleaned)
357
+ if m:
358
+ ov = m.group(2).strip()
359
+ kws = self._generate_keywords(ov or " ".join(defaults.get("key_points", [])))
360
+ return ov, kws
361
+
362
+ # If nothing recognized, return fallback cleaned text and keywords
363
+ return cleaned, self._generate_keywords(cleaned or " ".join(defaults.get("key_points", [])))
364
+
365
+ def _sanitize_for_prompt(self, text: str) -> str:
366
+ """Sanitize text before injecting into the prompt: remove model placeholders, URLs/domains/emails,
367
+ common web-article boilerplate (closing lines like "Semoga bermanfaat"), and collapse repeats."""
368
+ if not text:
369
+ return text
370
+ t = re.sub(r"<extra_id_\d+>", "", text)
371
+ # remove emails
372
+ t = re.sub(r"\b\S+@\S+\.\S+\b", " ", t)
373
+ # remove domain-like tokens (e.g., Eksekutif.com.co.id)
374
+ t = re.sub(r"\b\S+\.(?:com|co\.id|info|id|net|org)(?:\.[a-z]{2,})*\b", " ", t, flags=re.IGNORECASE)
375
+ # remove common article/web boilerplate short phrases that often appear as closings
376
+ t = re.sub(r"(?i)\b(semoga artikel ini bermanfaat(?: bagi anda semua)?|semoga bermanfaat|terima kasih(?: atas masukannya| juga)?)\b[.!\s,]*", " ", t)
377
+ t = re.sub(r"\s+", " ", t).strip()
378
+ t = _collapse_repeated_phrases_global(t)
379
+ return t
380
+
381
+ def _is_repetitive_text(self, text: str, max_run: int = 6) -> bool:
382
+ """Detect highly repetitive model outputs (including repeated n-gram phrases).
383
+
384
+ Returns True if repetition patterns exceed thresholds.
385
+ """
386
+ if not text:
387
+ return False
388
+ # check placeholder presence quickly
389
+ if re.search(r"<extra_id_\d+>", text):
390
+ return True
391
+ # Tokenize
392
+ tokens = re.findall(r"\w+", text.lower())
393
+ if not tokens:
394
+ return False
395
+ # Check simple token runs
396
+ run = 1
397
+ last = tokens[0]
398
+ for tok in tokens[1:]:
399
+ if tok == last:
400
+ run += 1
401
+ if run >= max_run:
402
+ return True
403
+ else:
404
+ last = tok
405
+ run = 1
406
+ # Check n-gram repeated phrase runs for n=1..4
407
+ max_ngram = 4
408
+ n_tokens = len(tokens)
409
+ for n in range(1, max_ngram + 1):
410
+ i = 0
411
+ while i + 2 * n <= n_tokens:
412
+ # compare tokens[i:i+n] with subsequent repeated occurrences
413
+ pattern = tokens[i:i + n]
414
+ run = 1
415
+ j = i + n
416
+ while j + n <= n_tokens and tokens[j:j + n] == pattern:
417
+ run += 1
418
+ j += n
419
+ if run >= max_run:
420
+ return True
421
+ i += 1
422
+ # fallback regex for single-token repetition
423
+ if re.search(r"(\b\w+\b)(?:\s+\1\b){%d,}" % (max_run - 1), text.lower()):
424
+ return True
425
+ return False
426
+
427
+ def _contains_domain_noise(self, text: str) -> bool:
428
+ """Detect domain-like or short web boilerplate noise (e.g., 'Eksekutif.com', 'Semoga artikel ini bermanfaat').
429
+
430
+ Returns True if common domain patterns or boilerplate phrases are found.
431
+ """
432
+ if not text:
433
+ return False
434
+ if re.search(r"\b\S+\.(?:com|co\.id|info|id|net|org)(?:\.[a-z]{2,})*\b", text, flags=re.IGNORECASE):
435
+ return True
436
+ if re.search(r"(?i)\b(semoga artikel ini bermanfaat(?: bagi anda semua)?|semoga bermanfaat|terima kasih)\b", text):
437
+ return True
438
+ return False
439
+
440
+ def _normalize_overview_text(self, text: str) -> str:
441
+ """Normalize overview into a readable paragraph or keep structured lists tidy."""
442
+ if not text:
443
+ return text
444
+ t = text.strip()
445
+ # collapse repeated fragments first
446
+ t = _collapse_repeated_phrases_global(t)
447
+
448
+ # If text contains list markers or section headers, tidy spacing and return
449
+ if "\n-" in t or "Poin-Poin Penting" in t or "Keputusan" in t or "Action Items" in t:
450
+ # normalize newlines and strip extra spaces
451
+ t = re.sub(r"\n\s+", "\n", t)
452
+ t = re.sub(r"\n{2,}", "\n\n", t)
453
+ return t.strip()
454
+
455
+ # Otherwise make a single paragraph and deduplicate near-duplicate fragments
456
+ # split by common separators (newline, bullet, or hyphen sequences)
457
+ if " - " in t:
458
+ parts = [p.strip(" -" ) for p in re.split(r"\s*-\s*", t) if p.strip()]
459
+ else:
460
+ parts = [p.strip() for p in re.split(r"(?<=[.!?])\s+", t) if p.strip()]
461
+
462
+ seen = set()
463
+ uniq = []
464
+ for p in parts:
465
+ norm = re.sub(r"[^a-z0-9 ]", "", p.lower())
466
+ norm = re.sub(r"\s+", " ", norm).strip()
467
+ if not norm:
468
+ continue
469
+ if norm in seen:
470
+ continue
471
+ seen.add(norm)
472
+ uniq.append(p.strip(" -."))
473
+
474
+ para = " ".join(uniq)
475
+ para = re.sub(r"\s+", " ", para).strip()
476
+
477
+ # Remove any leftover emails/domains or short web boilerplate that slipped through
478
+ para = re.sub(r"\b\S+@\S+\.\S+\b", " ", para)
479
+ para = re.sub(r"\b\S+\.(?:com|co\.id|info|id|net|org)(?:\.[a-z]{2,})*\b", " ", para, flags=re.IGNORECASE)
480
+ para = re.sub(r"(?i)\b(semoga artikel ini bermanfaat(?: bagi anda semua)?|semoga bermanfaat|terima kasih(?: atas masukannya| juga)?)\b[.!\s,]*", " ", para)
481
+ para = re.sub(r"\s+", " ", para).strip()
482
+
483
+ if para and not re.search(r"[.!?]$", para):
484
+ para = para + "."
485
+ if para:
486
+ para = para[0].upper() + para[1:]
487
+ return para
488
+
489
+ def _polish_overview(self, overview: str, full_text: str) -> str:
490
+ """Polish overview into an executive, coherent paragraph using abstractive model (if available).
491
+
492
+ Falls back to normalization and deduplication if model not available.
493
+ """
494
+ if not overview:
495
+ return overview
496
+ # Basic normalization first
497
+ overview = _collapse_repeated_phrases_global(overview)
498
+ overview = self._normalize_overview_text(overview)
499
+
500
+ # If model available and config allows, ask for paraphrase/expansion
501
+ if getattr(self.config, "polish_overview", True):
502
+ try:
503
+ self._load_model()
504
+ if self._pipeline is not None:
505
+ prompt = (
506
+ "Paraphrase dan perluas teks berikut menjadi paragraf eksekutif yang jelas, ringkas, dan mudah dibaca. "
507
+ "Jangan sertakan header."
508
+ "\n\nTeks:\n" + overview
509
+ )
510
+ out = self._pipeline(
511
+ prompt,
512
+ max_length=min(getattr(self.config, "comprehensive_max_length", 512), 350),
513
+ min_length=40,
514
+ truncation=True,
515
+ do_sample=False,
516
+ )
517
+ if isinstance(out, list) and out:
518
+ candidate = out[0].get("summary_text", "").strip()
519
+ candidate = self._clean_abstractive_text(candidate)
520
+ candidate = _collapse_repeated_phrases_global(candidate)
521
+ candidate = self._normalize_overview_text(candidate)
522
+ return candidate
523
+ except Exception:
524
+ pass
525
+
526
+ return overview
527
+
528
+ def _semantic_deduplicate(self, items: List[str], threshold: Optional[float] = None) -> List[str]:
529
+ """Deduplicate similar items using sentence-transformer embeddings + cosine similarity.
530
+
531
+ Returns the first occurrence for each semantic group.
532
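+
+ Illustrative behaviour (actual grouping depends on the embedding model and threshold):
+ >>> self._semantic_deduplicate(["Anggaran sudah disetujui", "Anggaran disetujui", "Jadwal rilis digeser"])
+ ['Anggaran sudah disetujui', 'Jadwal rilis digeser']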
+ """
533
+ if not items:
534
+ return []
535
+ thr = threshold if threshold is not None else getattr(self.config, "semantic_dedup_threshold", 0.75)
536
+ # try embeddings
537
+ try:
538
+ embs = self._compute_embeddings(items)
539
+ if embs is not None:
540
+ from sklearn.metrics.pairwise import cosine_similarity
541
+
542
+ sim = cosine_similarity(embs)
543
+ n = len(items)
544
+ taken = set()
545
+ result = []
546
+ for i in range(n):
547
+ if i in taken:
548
+ continue
549
+ result.append(items[i])
550
+ for j in range(i + 1, n):
551
+ if sim[i, j] >= thr:
552
+ taken.add(j)
553
+ # If embeddings didn't merge anything useful, fallback to token-jaccard grouping
554
+ if len(result) == len(items) and len(items) > 1:
555
+ # token Jaccard
556
+ token_sets = [set(re.findall(r"\w+", it.lower())) for it in items]
557
+ taken2 = set()
558
+ result2 = []
559
+ for i in range(len(items)):
560
+ if i in taken2:
561
+ continue
562
+ result2.append(items[i])
563
+ for j in range(i + 1, len(items)):
564
+ if j in taken2:
565
+ continue
566
+ si = token_sets[i]
567
+ sj = token_sets[j]
568
+ if not si or not sj:
569
+ continue
570
+ jacc = len(si & sj) / float(len(si | sj))
571
+ if jacc >= 0.45:
572
+ taken2.add(j)
573
+ return result2
574
+ return result
575
+ else:
576
+ raise ValueError("No embeddings")
577
+ except Exception:
578
+ # fallback to token-jaccard grouping first (robust when embeddings aren't available)
579
+ try:
580
+ token_sets = [set(re.findall(r"\w+", it.lower())) for it in items]
581
+ taken = set()
582
+ res = []
583
+ for i in range(len(items)):
584
+ if i in taken:
585
+ continue
586
+ res.append(items[i])
587
+ si = token_sets[i]
588
+ for j in range(i + 1, len(items)):
589
+ if j in taken:
590
+ continue
591
+ sj = token_sets[j]
592
+ if not si or not sj:
593
+ continue
594
+ jacc = len(si & sj) / float(len(si | sj))
595
+ if jacc >= 0.45:
596
+ taken.add(j)
597
+ return res
598
+ except Exception:
599
+ # final fallback to naive textual deduplication
600
+ seen = set()
601
+ res = []
602
+ for it in items:
603
+ low = re.sub(r"\s+", " ", it.lower()).strip()
604
+ if low in seen:
605
+ continue
606
+ seen.add(low)
607
+ res.append(it)
608
+ return res
609
+
610
+ def _semantic_dedup_action_items(self, actions: List[Dict[str, str]], threshold: Optional[float] = None) -> List[Dict[str, str]]:
611
+ """Deduplicate action items by task text; merge owners when necessary."""
612
+ if not actions:
613
+ return []
614
+ tasks = [a.get("task", "") for a in actions]
615
+ groups = self._semantic_deduplicate(tasks, threshold=threshold)
616
+ # groups contains first representative tasks; now build merged items
617
+ merged = []
618
+ for rep in groups:
619
+ owners = []
620
+ timestamps = []
621
+ dues = set()
622
+ for a in actions:
623
+ if a.get("task", "") == rep or (rep and rep in a.get("task", "")):
624
+ if a.get("owner") and a.get("owner") not in owners:
625
+ owners.append(a.get("owner"))
626
+ if a.get("timestamp"):
627
+ timestamps.append(a.get("timestamp"))
628
+ if a.get("due"):
629
+ dues.add(a.get("due"))
630
+ owner_str = " / ".join(owners) if owners else "TBD"
631
+ merged.append({
632
+ "owner": owner_str,
633
+ "task": rep,
634
+ "timestamp": timestamps[0] if timestamps else "",
635
+ "due": ", ".join(sorted(list(dues))) if dues else "",
636
+ })
637
+ return merged
638
+
639
+ def generate_comprehensive_summary(self, full_text: str, key_points: List[str], decisions: List[str], action_items: List[Dict[str, str]], topics: List[str]) -> Tuple[str, List[str]]:
640
+ """Generate a comprehensive executive summary covering the meeting.
641
+
642
+ Uses the abstractive pipeline with a guided prompt built from extracted components.
643
+ Attempts to request YAML-structured output for reliable parsing; falls back to rule-based assembly.
644
+ Returns (overview_text, keywords)
645
+ """
646
+ # Sanitize inputs to avoid placeholder tokens and repeated garbage before building the prompt
+ key_points = [self._sanitize_for_prompt(k) for k in key_points if k and k.strip()]
+ decisions = [self._sanitize_for_prompt(d) for d in decisions if d and d.strip()]
+ for a in action_items:
+ a['task'] = self._sanitize_for_prompt(a.get('task',''))
+
+ # Deduplicate before sending to model
+ try:
+ key_points = self._semantic_deduplicate(key_points)
+ decisions = self._semantic_deduplicate(decisions)
+ except Exception:
+ key_points = list(dict.fromkeys(key_points))
+ decisions = list(dict.fromkeys(decisions))
+
+ # Build a structured prompt from the sanitized, deduplicated inputs; request YAML output for safe parsing
+ prompt_parts = [
+ "Anda adalah asisten yang menulis ringkasan rapat yang komprehensif dan terstruktur.",
+ "Output harus dalam format YAML dengan kunci: overview, key_points (list), decisions (list), action_items (list of {owner, task, due}), keywords (list).",
+ "Berikan overview naratif yang jelas, serta daftar poin penting, keputusan, dan tindak lanjut.",
+ "Topik yang dibahas:",
+ ", ".join(topics) if topics else "-",
+ "Poin-poin penting:\n" + "\n".join([f"- {p}" for p in key_points]) if key_points else "",
+ "Keputusan:\n" + "\n".join([f"- {d}" for d in decisions]) if decisions else "",
+ "Tindak lanjut (Action Items):\n" + "\n".join([f"- [{a.get('owner','TBD')}] {a.get('task','')}" for a in action_items]) if action_items else "",
+ "Tuliskan field 'overview' minimal 80 kata sebagai paragraf naratif yang merangkum seluruh rapat dengan jelas.",
+ "Mohon hasilkan YAML yang valid."
+ ]
+ prompt = "\n\n".join([p for p in prompt_parts if p])
674
+
675
+ # Use pipeline if available
676
+ try:
677
+ self._load_model()
678
+ if self._pipeline is not None:
679
+ # Try up to 2 attempts: first deterministic, second sampled if repetition/shortness detected
680
+ attempts = 2
681
+ for attempt in range(attempts):
682
+ gen_kwargs = dict(
683
+ max_length=getattr(self.config, "comprehensive_max_length", 512),
684
+ min_length=max(80, int(getattr(self.config, "comprehensive_max_length", 512) * 0.12)),
685
+ truncation=True,
686
+ do_sample=False,
687
+ no_repeat_ngram_size=4,
688
+ repetition_penalty=1.3,
689
+ )
690
+ if attempt == 1:
691
+ # more creative generation if deterministic attempt failed
692
+ gen_kwargs.update({"do_sample": True, "temperature": 0.7, "top_p": 0.9})
693
+
694
+ out = self._pipeline(prompt, **gen_kwargs)
695
+ text = out[0].get("summary_text", "").strip()
696
+
697
+ # collapse repeated fragments, then clean
698
+ text = self._collapse_repeated_phrases(text)
699
+ cleaned = self._clean_abstractive_text(text)
700
+
701
+ # Quick heuristic checks (repetition, too short, or domain-like web boilerplate -> retry)
702
+ if self._is_repetitive_text(cleaned) or len(cleaned.split()) < 20 or self._contains_domain_noise(cleaned):
703
+ # try again (next attempt) with sampling
704
+ if attempt + 1 < attempts:
705
+ continue
706
+
707
+ # Attempt to parse structured YAML/JSON
708
+ overview, keywords = self._parse_structured_output(cleaned, {
709
+ "key_points": key_points,
710
+ "decisions": decisions,
711
+ "action_items": action_items,
712
+ })
713
+
714
+ # Final normalization / optional polish
715
+ overview = self._normalize_overview_text(overview)
716
+ if getattr(self.config, "polish_overview", True):
717
+ overview = self._polish_overview(overview, full_text)
718
+
719
+ # Validate overview quality: non-empty, not too short, not repetitive
720
+ if overview and len(overview.split()) >= 10 and not self._is_repetitive_text(overview):
721
+ return overview, keywords
722
+ else:
723
+ # Try next attempt if available, otherwise break to fallback
724
+ if attempt + 1 < attempts:
725
+ continue
726
+ else:
727
+ break
728
+ except Exception:
729
+ pass
730
+
731
+ # Fallback rule-based assembly: construct a narrative paragraph summarizing meeting,
732
+ # rather than repeating the list headers. Use polishing to turn it into an executive paragraph.
733
+ def _format_action_items(ai_list):
734
+ pairs = []
735
+ for a in ai_list:
736
+ owner = a.get('owner', 'TBD')
737
+ task = a.get('task', '').strip()
738
+ if task:
739
+ pairs.append(f"{owner} akan {task.rstrip('.')}.")
740
+ return " ".join(pairs)
741
+
742
+ def _join_points(pts):
743
+ # join key points into a sentence
744
+ if not pts:
745
+ return ""
746
+ # take up to 4 points to avoid overly long lists
747
+ pts_sample = pts[:4]
748
+ return "; ".join([p.rstrip('.') for p in pts_sample]) + ""
749
+
750
+ narrative_parts = []
751
+ if topics:
752
+ narrative_parts.append("Topik utama yang dibahas meliputi: " + ", ".join(topics) + ".")
753
+ if key_points:
754
+ narrative_parts.append("Beberapa poin penting termasuk: " + _join_points(key_points) + ".")
755
+ if decisions:
756
+ narrative_parts.append("Keputusan utama yang dicapai termasuk: " + ", ".join([d.rstrip('.') for d in decisions]) + ".")
757
+ if action_items:
758
+ narrative_parts.append("Tindak lanjut yang disepakati di antaranya: " + _format_action_items(action_items))
759
+
760
+ assembled = " ".join([p for p in narrative_parts if p]).strip()
761
+ # Normalize and then optionally polish into a smooth executive paragraph
762
+ assembled = self._normalize_overview_text(assembled)
763
+ if getattr(self.config, "polish_overview", True):
764
+ assembled = self._polish_overview(assembled, full_text)
765
+
766
+ keywords = self._generate_keywords(assembled, top_k=8)
767
+ return assembled, keywords
768
+
769
+ def summarize(self, transcript_segments: List[TranscriptSegment]) -> MeetingSummary:
770
+ self._load_model()
771
+
772
+ full_text = " ".join([seg.text for seg in transcript_segments if seg.text])
773
+ if not full_text.strip():
774
+ return MeetingSummary(
775
+ overview="Tidak ada konten yang dapat diringkas.",
776
+ key_points=[],
777
+ decisions=[],
778
+ action_items=[],
779
+ )
780
+
781
+ # Clean up common disfluencies/politeness tokens and ASR annotations
782
+ full_text = re.sub(r"\[OVERLAP\]|\[NOISE\]|<.*?>", "", full_text)
783
+ full_text = re.sub(
784
+ r"\b(oke|ya|oke,|baik|sekarang|sekarang kita|nah|jadi|oke\.|jadi\.)\b",
785
+ "",
786
+ full_text,
787
+ flags=re.IGNORECASE,
788
+ )
789
+ full_text = re.sub(r"\s+", " ", full_text).strip()
790
+
791
+ # Chunk and summarize
792
+ if self._pipeline is None:
793
+ # fallback: return first few sentences
794
+ sentences = BERTSummarizer(self.config)._split_sentences(full_text)
795
+ overview = " ".join(sentences[: min(3, len(sentences))])
796
+ else:
797
+ chunks = self._chunk_text(full_text)
798
+ partial_summaries = []
799
+ for chunk in chunks:
800
+ try:
801
+ out = self._pipeline(
802
+ chunk,
803
+ max_length=self.config.max_summary_length,
804
+ min_length=self.config.min_summary_length,
805
+ truncation=True,
806
+ do_sample=False,
807
+ )
808
+ partial_summaries.append(out[0]["summary_text"].strip())
809
+ except Exception as e:
810
+ print(f"[Summarizer] chunk summarization failed: {e}")
811
+ continue
812
+
813
+ # If multiple partial summaries, join and optionally summarize again
814
+ combined = " ".join(partial_summaries)
815
+ if len(combined) > self.config.max_input_chars and self._pipeline:
816
+ try:
817
+ out = self._pipeline(
818
+ combined,
819
+ max_length=self.config.max_summary_length,
820
+ min_length=self.config.min_summary_length,
821
+ truncation=True,
822
+ do_sample=False,
823
+ )
824
+ overview = out[0]["summary_text"].strip()
825
+ except Exception:
826
+ overview = combined
827
+ else:
828
+ overview = combined
829
+
830
+ # Clean abstractive overview and produce robust key points (use helper)
831
+ overview, key_points = self._clean_abstractive_output(overview, full_text)
832
+
833
+ # Extract decisions and actions via keywords
834
+ sentences = BERTSummarizer(self.config)._split_sentences(full_text)
835
+ decisions = BERTSummarizer(self.config)._extract_decisions(sentences)
836
+ action_items = BERTSummarizer(self.config)._extract_action_items(transcript_segments)
837
+ topics = BERTSummarizer(self.config)._extract_topics(full_text)
838
+
839
+ # Optionally produce a comprehensive overview (uses abstractive pipeline)
840
+ if getattr(self.config, "comprehensive_overview", False):
841
+ try:
842
+ comp_overview, keywords = self.generate_comprehensive_summary(full_text, key_points, decisions, action_items, topics)
843
+ overview = comp_overview
844
+ except Exception:
845
+ keywords = []
846
+
847
+ ms = MeetingSummary(
848
+ overview=overview,
849
+ key_points=key_points,
850
+ decisions=decisions,
851
+ action_items=action_items,
852
+ topics=topics,
853
+ )
854
+ if 'keywords' in locals():
855
+ setattr(ms, 'keywords', keywords)
856
+ return ms
857
+
858
+
859
+ class BERTSummarizer:
860
+ """
861
+ Extractive Summarization using BERT sentence embeddings.
862
+
863
+ Selects most important sentences based on semantic similarity
864
+ to document centroid and other features.
865
+
866
+ Attributes:
867
+ config: SummarizationConfig object
868
+
869
+ Example:
870
+ >>> summarizer = BERTSummarizer()
871
+ >>> summary = summarizer.summarize(transcript_segments)
872
+ >>> print(summary.overview)
873
+ >>> print(summary.decisions)
874
+ """
875
+
876
+ def __init__(self, config: Optional[SummarizationConfig] = None):
877
+ """
878
+ Initialize BERTSummarizer.
879
+
880
+ Args:
881
+ config: SummarizationConfig object
882
+ """
883
+ self.config = config or SummarizationConfig()
884
+ self._model = None
885
+
886
+ def _load_model(self):
887
+ """Lazy load sentence transformer model"""
888
+ if self._model is None:
889
+ try:
890
+ from sentence_transformers import SentenceTransformer
891
+
892
+ print(f"[Summarizer] Loading model: {self.config.sentence_model_id}")
893
+
894
+ self._model = SentenceTransformer(self.config.sentence_model_id)
895
+
896
+ print("[Summarizer] Model loaded successfully")
897
+
898
+ except Exception as e:
899
+ print(f"[Summarizer] Warning: Could not load model: {e}")
900
+ print("[Summarizer] Using fallback mode")
901
+ self._model = "FALLBACK"
902
+
903
+ def _semantic_deduplicate(self, items: List[str], threshold: Optional[float] = None) -> List[str]:
904
+ """Delegate to AbstractiveSummarizer semantic dedup for compatibility."""
905
+ return AbstractiveSummarizer(self.config)._semantic_deduplicate(items, threshold)
906
+
907
+ def _semantic_dedup_action_items(self, actions: List[Dict[str, str]], threshold: Optional[float] = None) -> List[Dict[str, str]]:
908
+ """Delegate to AbstractiveSummarizer action-item dedup for compatibility."""
909
+ return AbstractiveSummarizer(self.config)._semantic_dedup_action_items(actions, threshold)
910
+
911
+ def _collapse_repeated_phrases(self, text: str, max_ngram: int = 6, min_repeats: int = 2) -> str:
912
+ """Delegates to module-level collapse helper for compatibility."""
913
+ return _collapse_repeated_phrases_global(text, max_ngram=max_ngram, min_repeats=min_repeats)
914
+
915
+ def summarize(self, transcript_segments: List[TranscriptSegment]) -> MeetingSummary:
916
+ """
917
+ Generate meeting summary from transcript.
918
+
919
+ Args:
920
+ transcript_segments: List of transcript segments with speaker info
921
+
922
+ Returns:
923
+ MeetingSummary with overview, key points, decisions, and action items
924
+ """
925
+ # If configuration prefers abstractive summarization, delegate to AbstractiveSummarizer
926
+ if getattr(self.config, "method", "extractive") == "abstractive":
927
+ try:
928
+ return AbstractiveSummarizer(self.config).summarize(transcript_segments)
929
+ except Exception as e:
930
+ print(
931
+ f"[Summarizer] Abstractive summarization failed, falling back to extractive: {e}"
932
+ )
933
+
934
+ self._load_model()
935
+
936
+ # Combine all text
937
+ full_text = " ".join([seg.text for seg in transcript_segments if seg.text])
938
+ # Clean up disfluencies and annotations commonly appearing in ASR output
939
+ full_text = re.sub(r"\[OVERLAP\]|\[NOISE\]|<.*?>", "", full_text)
940
+ full_text = re.sub(r"\s+", " ", full_text).strip()
941
+
942
+ if not full_text.strip():
943
+ return MeetingSummary(
944
+ overview="Tidak ada konten yang dapat diringkas.",
945
+ key_points=[],
946
+ decisions=[],
947
+ action_items=[],
948
+ )
949
+
950
+ # Get sentence-level metadata by merging speaker turns
951
+ sent_meta = self._get_sentences_with_meta(transcript_segments)
952
+
953
+ if not sent_meta:
954
+ return MeetingSummary(
955
+ overview="Tidak ada kalimat yang dapat diidentifikasi.",
956
+ key_points=[],
957
+ decisions=[],
958
+ action_items=[],
959
+ )
960
+
961
+ sentences = [s["text"] for s in sent_meta]
962
+
963
+ # Compute embeddings and select a diverse set of representative sentences via MMR
964
+ embeddings = self._compute_embeddings(sentences)
965
+ num_select = min(max(5, self.config.num_sentences + 2), len(sentences))
966
+
967
+ if embeddings is not None:
968
+ selected_idx = self._mmr_selection(sentences, embeddings, k=num_select)
969
+ key_sentences = [sentences[i] for i in selected_idx]
970
+ else:
971
+ # fallback: use earlier scoring
972
+ key_sentences = self._extract_key_sentences(sentences)
973
+
974
+ # Generate a multi-sentence overview with some ordering and cleaning
975
+ overview = self._generate_overview(key_sentences[:3])
976
+
977
+ # Optionally perform a light abstractive refinement on the extractive overview
978
+ if getattr(self.config, "do_abstractive_refinement", False):
979
+ try:
980
+ abs_sum = AbstractiveSummarizer(self.config)
981
+ abs_sum._load_model()
982
+ if abs_sum._pipeline is not None and overview:
983
+ out = abs_sum._pipeline(
984
+ overview,
985
+ max_length=getattr(self.config, "abstractive_refine_max_len", 80),
986
+ min_length=30,
987
+ truncation=True,
988
+ do_sample=False,
989
+ )
990
+ # Expect a single summary text
991
+ if isinstance(out, list) and out:
992
+ raw_overview = out[0].get("summary_text", overview).strip()
993
+ # Use AbstractiveSummarizer's cleaning & fallback logic
994
+ overview_cleaned, _ = abs_sum._clean_abstractive_output(raw_overview, full_text)
995
+ overview = overview_cleaned
996
+ except Exception:
997
+ # Fail silently and use extractive overview
998
+ pass
999
+
1000
+ # Build richer key points: include speaker attribution and short cleaned sentences
1001
+ key_points = []
1002
+ for i in selected_idx if embeddings is not None else list(range(len(key_sentences))):
1003
+ s = sentences[i]
1004
+ sp = sent_meta[i]["speaker_id"]
1005
+ # Short clean
1006
+ s_clean = re.sub(r"\s+", " ", s).strip()
1007
+ key_points.append(f"{s_clean} (oleh {sp})")
1008
+
1009
+ # Extract decisions using expanded context (look for decision keywords and enumerations)
1010
+ decisions = []
1011
+ seen_decisions = set()
1012
+ for i, s in enumerate(sentences):
1013
+ s_clean = re.sub(r"\s+", " ", s).strip()
1014
+ s_lower = s_clean.lower()
1015
+ if any(kw in s_lower for kw in self.config.decision_keywords) or re.match(
1016
+ r"^(pertama|kedua|ketiga|keempat|kelima)\b", s_lower
1017
+ ):
1018
+ context = self._expand_context_for_sentence(sent_meta, i, window=1)
1019
+ dec_text = re.sub(r"\[.*?\]", "", context)
1020
+ dec_text = re.sub(r"\s+", " ", dec_text).strip()
1021
+ # Truncate to a reasonable length (35 words) and remove trailing punctuation
1022
+ words = dec_text.split()
1023
+ dec_text = " ".join(words[:35]).rstrip(" ,.;:")
1024
+ if len(dec_text.split()) < 3:
1025
+ continue
1026
+ if dec_text and dec_text not in seen_decisions:
1027
+ decisions.append(dec_text)
1028
+ seen_decisions.add(dec_text)
1029
+
1030
+ # If no decisions found, try to extract from key_sentences
1031
+ if not decisions:
1032
+ for ks in key_sentences:
1033
+ if any(kw in ks.lower() for kw in self.config.decision_keywords):
1034
+ if ks not in seen_decisions:
1035
+ decisions.append(ks)
1036
+ seen_decisions.add(ks)
1037
+
1038
+ # Apply semantic deduplication to decisions
1039
+ try:
1040
+ decisions = self._semantic_deduplicate(decisions)
1041
+ except Exception:
1042
+ pass
1043
+
1044
+ # Extract action items at sentence level with speaker inference
1045
+ action_items = []
1046
+ seen_tasks = set()
1047
+ action_kw_re = re.compile(
1048
+ r"\b(" + "|".join([re.escape(k) for k in self.config.action_keywords]) + r")\b",
1049
+ flags=re.IGNORECASE,
1050
+ )
1051
+
1052
+ # verbs that indicate an actionable commitment (used to validate generic keyword matches)
1053
+ action_verbs_re = re.compile(r"\b(akan|harus|siapkan|bikin|buat|selesaikan|dikerjakan|tolong|mohon|harap)\b", flags=re.IGNORECASE)
1054
+
1055
+ for i, s in enumerate(sentences):
1056
+ text = re.sub(r"\[OVERLAP\]|\[NOISE\]|<.*?>", "", s).strip()
1057
+ if not text:
1058
+ continue
1059
+
1060
+ # explicit commit patterns
1061
+ commit_re = re.compile(
1062
+ r"\b(aku|saya|kami|kita|kamu)\b.*\b(bertanggung jawab|akan|saya akan|aku akan|aku akan membuat|kamu tolong|tolong|siapkan|bikin|harus|selesaikan|dikerjakan)\b",
1063
+ flags=re.IGNORECASE,
1064
+ )
1065
+
1066
+ owner = None
1067
+ task = None
1068
+
1069
+ if commit_re.search(text):
1070
+ owner = sent_meta[i]["speaker_id"]
1071
+ # try to isolate the actionable clause
1072
+ task = re.sub(
1073
+ r"^.*?\b(bertanggung jawab|akan|saya akan|aku akan|kamu tolong|tolong|siapkan|bikin|harus|selesaikan|dikerjakan)\b",
1074
+ "",
1075
+ text,
1076
+ flags=re.IGNORECASE,
1077
+ )
1078
+ task = task.strip(" .,:;-")
1079
+ if not task:
1080
+ task = text
1081
+
1082
+ elif action_kw_re.search(text):
1083
+ # Validate generic matches for actionability using helper
1084
+ if not self._is_actionable_text(text):
1085
+ continue
1086
+ owner = sent_meta[i]["speaker_id"]
1087
+ task = text
1088
+
1089
+ if task:
1090
+ # Normalize task text
1091
+ task = re.sub(
1092
+ r"^\s*(aku|saya|kami|kita|kamu)\b[:,\s]*", "", task, flags=re.IGNORECASE
1093
+ ).strip()
1094
+ task = re.sub(r"\s+", " ", task).strip(" .,:;-")
1095
+ if len(task.split()) < 3:
1096
+ continue
1097
+ filler_short = {"setuju", "oke", "ya", "nah", "betul"}
1098
+ if task.lower() in filler_short:
1099
+ continue
1100
+ key = task.lower()[:120]
1101
+ if key in seen_tasks:
1102
+ continue
1103
+ seen_tasks.add(key)
1104
+ action_items.append(
1105
+ {
1106
+ "owner": owner or "TBD",
1107
+ "task": task,
1108
+ "timestamp": f"{sent_meta[i]['start']:.1f}s",
1109
+ "due": "",
1110
+ }
1111
+ )
1112
+
1113
+ # Fall back to segment-level action extraction if none found
1114
+ if not action_items:
1115
+ action_items = self._extract_action_items(transcript_segments)
1116
+
1117
+ # Apply semantic deduplication to action items (merge owners when possible)
1118
+ try:
1119
+ action_items = self._semantic_dedup_action_items(action_items)
1120
+ except Exception:
1121
+ pass
1122
+
1123
+ # Extract topics (frequency-based) from cleaned full_text
1124
+ topics = self._extract_topics(full_text)
1125
+
1126
+ # Optionally produce a comprehensive overview (may use abstractive pipeline)
1127
+ if getattr(self.config, "comprehensive_overview", False):
1128
+ try:
1129
+ abs_s = AbstractiveSummarizer(self.config)
1130
+ comp_overview, keywords = abs_s.generate_comprehensive_summary(full_text, key_points, decisions, action_items, topics)
1131
+ overview = comp_overview
1132
+ except Exception:
1133
+ keywords = []
1134
+
1135
+ # Return comprehensive MeetingSummary
1136
+ ms = MeetingSummary(
1137
+ overview=overview,
1138
+ key_points=key_points,
1139
+ decisions=decisions,
1140
+ action_items=action_items,
1141
+ topics=topics,
1142
+ )
1143
+ if 'keywords' in locals():
1144
+ setattr(ms, 'keywords', keywords)
1145
+ return ms
1146
+
1147
+ def _split_sentences(self, text: str) -> List[str]:
1148
+ """Split text into sentences"""
1149
+ # Indonesian sentence splitting
1150
+ # Handle common abbreviations
1151
+ text = re.sub(r"([Dd]r|[Pp]rof|[Bb]pk|[Ii]bu|[Ss]dr|[Nn]o|[Hh]al)\.", r"\1<PERIOD>", text)
1152
+
1153
+ # Split on sentence-ending punctuation
1154
+ sentences = re.split(r"[.!?]+\s*", text)
1155
+
1156
+ # Restore periods in abbreviations
1157
+ sentences = [s.replace("<PERIOD>", ".") for s in sentences]
1158
+
1159
+ # Clean and filter
1160
+ cleaned = []
1161
+ for s in sentences:
1162
+ s = s.strip()
1163
+
1164
+ # Filter by length
1165
+ if len(s) < self.config.min_sentence_length:
1166
+ continue
1167
+ if len(s) > self.config.max_sentence_length:
1168
+ # Truncate very long sentences
1169
+ s = s[: self.config.max_sentence_length] + "..."
1170
+
1171
+ # Collapse trivial repeated fragments inside sentence
1172
+ s = self._collapse_repeated_phrases(s)
1173
+
1174
+ cleaned.append(s)
1175
+
1176
+ return cleaned
1177
+
1178
+ def _merge_speaker_turns(self, segments: List[TranscriptSegment]) -> List[Dict[str, Any]]:
1179
+ """Merge consecutive segments by the same speaker into 'turns' to provide more context.
1180
+
1181
+ Returns a list of dicts: {speaker_id, start, end, text, indices}
1182
+ """
1183
+ turns: List[Dict[str, Any]] = []
1184
+ for i, seg in enumerate(segments):
1185
+ if not seg.text or not seg.text.strip():
1186
+ continue
1187
+ # Clean common ASR artifacts and leading fillers
1188
+ text = re.sub(r"\[OVERLAP\]|\[NOISE\]|<.*?>", "", seg.text)
1189
+ text = re.sub(
1190
+ r"^\s*(oke|ya|nah|oke,|baik|sekarang|jadi)\b[\s,:-]*", "", text, flags=re.IGNORECASE
1191
+ )
1192
+ text = re.sub(r"\s+", " ", text).strip()
1193
+
1194
+ if not text:
1195
+ continue
1196
+
1197
+ if turns and turns[-1]["speaker_id"] == seg.speaker_id:
1198
+ turns[-1]["end"] = seg.end
1199
+ turns[-1]["text"] += " " + text
1200
+ turns[-1]["indices"].append(i)
1201
+ else:
1202
+ turns.append(
1203
+ {
1204
+ "speaker_id": seg.speaker_id,
1205
+ "start": seg.start,
1206
+ "end": seg.end,
1207
+ "text": text,
1208
+ "indices": [i],
1209
+ }
1210
+ )
1211
+ return turns
1212
+
1213
+ def _get_sentences_with_meta(self, segments: List[TranscriptSegment]) -> List[Dict[str, Any]]:
1214
+ """Split merged speaker turns into sentences and keep metadata."""
1215
+ turns = self._merge_speaker_turns(segments)
1216
+ sent_meta: List[Dict[str, Any]] = []
1217
+ for t in turns:
1218
+ sents = self._split_sentences(t["text"])
1219
+ for j, s in enumerate(sents):
1220
+ sent_meta.append(
1221
+ {
1222
+ "text": s,
1223
+ "speaker_id": t["speaker_id"],
1224
+ "start": t["start"],
1225
+ "end": t["end"],
1226
+ "turn_indices": t["indices"],
1227
+ "sent_idx_in_turn": j,
1228
+ }
1229
+ )
1230
+ return sent_meta
1231
+
1232
+ def _compute_embeddings(self, sentences: List[str]):
1233
+ """Compute sentence embeddings using sentence-transformers (lazy load)."""
1234
+ if not sentences:
1235
+ return None
1236
+ try:
1237
+ from sentence_transformers import SentenceTransformer
1238
+
1239
+ model = SentenceTransformer(self.config.sentence_model_id)
1240
+ embs = model.encode(sentences, show_progress_bar=False)
1241
+ return embs
1242
+ except Exception as e:
1243
+ print(f"[Summarizer] Embedding model error: {e}")
1244
+ return None
1245
+
1246
+ def _mmr_selection(
1247
+ self, sentences: List[str], embeddings, k: int = 5, lambda_param: float = 0.6
1248
+ ) -> List[int]:
1249
+ """Maximal Marginal Relevance (MMR) selection for diversity and coverage.
1250
+
1251
+ Returns list of selected sentence indices in original order.
1252
+ """
1253
+ import numpy as _np
1254
+
1255
+ if embeddings is None or len(sentences) <= k:
1256
+ return list(range(min(len(sentences), k)))
1257
+
1258
+ centroid = _np.mean(embeddings, axis=0)
1259
+ # similarity to centroid
1260
+ sim_to_centroid = _np.dot(embeddings, centroid) / (
1261
+ _np.linalg.norm(embeddings, axis=1) * (_np.linalg.norm(centroid) + 1e-8)
1262
+ )
1263
+
1264
+ selected = []
1265
+ candidate_indices = list(range(len(sentences)))
1266
+
1267
+ # pick the top similarity as first
1268
+ first = int(_np.argmax(sim_to_centroid))
1269
+ selected.append(first)
1270
+ candidate_indices.remove(first)
1271
+
1272
+ while len(selected) < k and candidate_indices:
1273
+ mmr_scores = []
1274
+ for idx in candidate_indices:
1275
+ sim_to_sel = max(
1276
+ [
1277
+ _np.dot(embeddings[idx], embeddings[s])
1278
+ / (_np.linalg.norm(embeddings[idx]) * _np.linalg.norm(embeddings[s]) + 1e-8)
1279
+ for s in selected
1280
+ ]
1281
+ )
1282
+ score = lambda_param * sim_to_centroid[idx] - (1 - lambda_param) * sim_to_sel
1283
+ mmr_scores.append((idx, score))
1284
+
1285
+ idx_best, _ = max(mmr_scores, key=lambda x: x[1])
1286
+ selected.append(idx_best)
1287
+ candidate_indices.remove(idx_best)
1288
+
1289
+ # return in original order
1290
+ selected_sorted = sorted(selected)
1291
+ return selected_sorted
1292
+
1293
+ def _expand_context_for_sentence(
1294
+ self, sent_meta: List[Dict[str, Any]], idx: int, window: int = 1
1295
+ ) -> str:
1296
+ """Return concatenated sentence with neighboring contextual sentences for better decision/action extraction."""
1297
+ start = max(0, idx - window)
1298
+ end = min(len(sent_meta), idx + window + 1)
1299
+ return " ".join([s["text"] for s in sent_meta[start:end]])
1300
+
1301
+ def _infer_owner_for_action(self, seg_index: int, sent_meta: List[Dict[str, Any]]) -> str:
1302
+ """Infer owner for an action by looking at the sentence speaker and recent explicit mentions."""
1303
+ # Prefer sentence speaker
1304
+ if 0 <= seg_index < len(sent_meta):
1305
+ return sent_meta[seg_index]["speaker_id"]
1306
+ return "TBD"
1307
+
1308
+ def _extract_key_sentences(self, sentences: List[str]) -> List[str]:
1309
+ """Extract most important sentences using BERT embeddings"""
1310
+ if not sentences:
1311
+ return []
1312
+
1313
+ # Fallback mode: simple heuristics
1314
+ if self._model == "FALLBACK" or len(sentences) <= self.config.num_sentences:
1315
+ return sentences[: self.config.num_sentences]
1316
+
1317
+ try:
1318
+ # Get sentence embeddings
1319
+ embeddings = self._model.encode(sentences, show_progress_bar=False)
1320
+
1321
+ # Calculate document centroid
1322
+ centroid = np.mean(embeddings, axis=0)
1323
+
1324
+ # Calculate importance scores for each sentence
1325
+ scores = []
1326
+
1327
+ for i, (sent, emb) in enumerate(zip(sentences, embeddings)):
1328
+ score = self._calculate_sentence_score(
1329
+ sentence=sent,
1330
+ embedding=emb,
1331
+ centroid=centroid,
1332
+ position=i,
1333
+ total_sentences=len(sentences),
1334
+ )
1335
+ scores.append((i, score, sent))
1336
+
1337
+ # Sort by score
1338
+ scores.sort(key=lambda x: x[1], reverse=True)
1339
+
1340
+ # Get top-k sentences (maintain original order)
1341
+ top_indices = sorted([s[0] for s in scores[: self.config.num_sentences]])
1342
+
1343
+ return [sentences[i] for i in top_indices]
1344
+
1345
+ except Exception as e:
1346
+ print(f"[Summarizer] Embedding extraction failed: {e}")
1347
+ return sentences[: self.config.num_sentences]
1348
+
1349
+ def _calculate_sentence_score(
1350
+ self,
1351
+ sentence: str,
1352
+ embedding: np.ndarray,
1353
+ centroid: np.ndarray,
1354
+ position: int,
1355
+ total_sentences: int,
1356
+ ) -> float:
1357
+ """Calculate importance score for a sentence"""
1358
+
1359
+ # 1. Cosine similarity to centroid
1360
+ similarity = np.dot(embedding, centroid) / (
1361
+ np.linalg.norm(embedding) * np.linalg.norm(centroid) + 1e-8
1362
+ )
1363
+
1364
+ # 2. Position score (favor beginning and end)
1365
+ if total_sentences > 1:
1366
+ normalized_pos = position / (total_sentences - 1)
1367
+ # U-shaped curve: high at start and end
1368
+ position_score = 1.0 - 0.6 * np.sin(np.pi * normalized_pos)
1369
+ else:
1370
+ position_score = 1.0
1371
+
1372
+ # 3. Length score (favor medium-length sentences)
1373
+ word_count = len(sentence.split())
1374
+ optimal_length = 20
1375
+ length_score = 1.0 - min(abs(word_count - optimal_length) / 30, 1.0)
1376
+
1377
+ # 4. Keyword bonus
1378
+ keyword_score = 0.0
1379
+ sentence_lower = sentence.lower()
1380
+
1381
+ for kw in self.config.decision_keywords + self.config.action_keywords:
1382
+ if kw in sentence_lower:
1383
+ keyword_score += 0.1
1384
+
1385
+ keyword_score = min(keyword_score, 0.3) # Cap bonus
1386
+
1387
+ # Combined score
1388
+ score = (
1389
+ self.config.similarity_weight * similarity
1390
+ + self.config.position_weight * position_score
1391
+ + self.config.length_weight * length_score
1392
+ + keyword_score
1393
+ )
1394
+
1395
+ return score
1396
+
1397
+ def _generate_overview(self, key_sentences: List[str]) -> str:
1398
+ """Generate overview from key sentences"""
1399
+ if not key_sentences:
1400
+ return "Tidak ada ringkasan yang dapat dibuat."
1401
+
1402
+ # Use top 2-3 sentences for overview
1403
+ overview_sentences = key_sentences[: min(3, len(key_sentences))]
1404
+ overview = " ".join(overview_sentences)
1405
+
1406
+ # Clean up
1407
+ overview = re.sub(r"\s+", " ", overview).strip()
1408
+
1409
+ return overview
1410
+
1411
+ def _extract_decisions(self, sentences: List[str]) -> List[str]:
1412
+ """Extract decision-related sentences and synthesize enumerated decisions.
1413
+
1414
+ This method collects sentence-level decision mentions, attempts to synthesize
1415
+ clauses from enumerated statements (e.g., "Pertama..., Kedua..."),
1416
+ and performs semantic deduplication to avoid repeated/near-duplicate items.
1417
+ """
1418
+ raw = []
1419
+
1420
+ for sent in sentences:
1421
+ sent_lower = sent.lower()
1422
+
1423
+ # Check for decision keywords
1424
+ if any(kw in sent_lower for kw in self.config.decision_keywords):
1425
+ # Clean the sentence
1426
+ clean_sent = re.sub(r"\s+", " ", sent).strip()
1427
+ if clean_sent and clean_sent not in raw:
1428
+ raw.append(clean_sent)
1429
+
1430
+ # Try to synthesize enumerated decisions from sentences
1431
+ synthesized = self._synthesize_enumerated_decisions(sentences)
1432
+
1433
+ all_decisions = raw + synthesized
1434
+
1435
+ # Deduplicate semantically (Jaccard over tokens)
1436
+ deduped = self._deduplicate_strings(all_decisions)
1437
+
1438
+ # Limit number of decisions returned
1439
+ return deduped[:7]
1440
+
1441
+ def _synthesize_enumerated_decisions(self, sentences: List[str]) -> List[str]:
1442
+ """Extract clauses following enumerations like 'Pertama..., Kedua...' and return list.
1443
+
1444
+ Handles both ordinal words (pertama, kedua, ...) and numbered lists (1., 2.)
1445
+ by splitting and returning non-trivial clauses.
1446
+ """
1447
+ synth: List[str] = []
1448
+ enum_words_re = re.compile(r"\b(pertama|kedua|ketiga|keempat|kelima)\b", flags=re.IGNORECASE)
1449
+
1450
+ for s in sentences:
1451
+ s_clean = s.strip()
1452
+ if enum_words_re.search(s_clean.lower()):
1453
+ # Split by Indonesian ordinal words
1454
+ parts = re.split(r"\bpertama\b|\bkedua\b|\bketiga\b|\bkeempat\b|\bkelima\b", s_clean, flags=re.IGNORECASE)
1455
+ for p in parts:
1456
+ p = p.strip(" .,:;\n-–—")
1457
+ if len(p.split()) >= 3 and p not in synth:
1458
+ synth.append(p)
1459
+
1460
+ # Also handle simple numbered enumerations like '1. ... 2. ...'
1461
+ if re.search(r"\d+\.\s*", s_clean):
1462
+ parts = re.split(r"\d+\.\s*", s_clean)
1463
+ for p in parts:
1464
+ p = p.strip(" .,:;\n-–—")
1465
+ if len(p.split()) >= 3 and p not in synth:
1466
+ synth.append(p)
1467
+
1468
+ return synth
1469
+
1470
+ def _normalize_text_for_dedup(self, text: str) -> str:
1471
+ """Normalize text for lightweight semantic deduplication."""
1472
+ t = text.lower()
1473
+ # remove punctuation, keep alphanumerics and spaces
1474
+ t = re.sub(r"[^a-z0-9\s]+", "", t)
1475
+ t = re.sub(r"\s+", " ", t).strip()
1476
+ return t
1477
+
1478
+ def _deduplicate_strings(self, items: List[str], threshold: float = 0.5) -> List[str]:
1479
+ """Deduplicate items using token Jaccard similarity threshold."""
1480
+ kept: List[str] = []
1481
+ norms: List[str] = []
1482
+
1483
+ for it in items:
1484
+ n = self._normalize_text_for_dedup(it)
1485
+ if not n:
1486
+ continue
1487
+ toks1 = set(n.split())
1488
+ is_dup = False
1489
+ for other in norms:
1490
+ toks2 = set(other.split())
1491
+ if not toks1 or not toks2:
1492
+ continue
1493
+ inter = len(toks1 & toks2)
1494
+ union = len(toks1 | toks2)
1495
+ if union > 0 and (inter / union) >= threshold:
1496
+ is_dup = True
1497
+ break
1498
+ if not is_dup:
1499
+ kept.append(it)
1500
+ norms.append(n)
1501
+
1502
+ return kept
1503
+
1504
+ def _extract_action_items(self, segments: List[TranscriptSegment]) -> List[Dict[str, str]]:
1505
+ """Extract action items with speaker attribution (improved heuristics)
1506
+
1507
+ Heuristics:
1508
+ - Detect explicit commitments like "aku akan", "saya bertanggung jawab", "kamu siapkan" and assign owner
1509
+ - Fallback to keyword-based detection
1510
+ - Normalize duplicate tasks and detect simple due-date mentions like "minggu depan", "besok"
1511
+ - Try to infer explicit owner names mentioned in the clause
1512
+ """
1513
+ action_items: List[Dict[str, str]] = []
1514
+ seen_tasks = set()
1515
+
1516
+ # Try to use AdvancedNLPExtractor (NER + dependency parse) for higher-quality extraction
1517
+ try:
1518
+ from src.nlp_utils import AdvancedNLPExtractor
1519
+
1520
+ extractor = AdvancedNLPExtractor()
1521
+ sent_meta = self._get_sentences_with_meta(segments)
1522
+ nlp_actions = extractor.extract_actions_from_sentences(sent_meta)
1523
+ for item in nlp_actions:
1524
+ task_key = item.get("task", "").lower()[:120]
1525
+ if task_key in seen_tasks:
1526
+ continue
1527
+ seen_tasks.add(task_key)
1528
+ action_items.append(
1529
+ {
1530
+ "owner": item.get("owner", "TBD"),
1531
+ "task": item.get("task", "").strip(),
1532
+ "timestamp": f"{sent_meta[item.get('sentence_idx', 0)]['start']:.1f}s",
1533
+ "due": self._detect_due_from_text(item.get("task", "")),
1534
+ }
1535
+ )
1536
+ except Exception:
1537
+ extractor = None
1538
+
1539
+ commit_re = re.compile(
1540
+ r"\b(aku|saya|kami|kita|kamu)\b.*\b(bertanggung jawab|akan|saya akan|aku akan|aku akan membuat|kamu tolong|tolong|siapkan|bikin|harus|selesaikan|dikerjakan)\b",
1541
+ flags=re.IGNORECASE,
1542
+ )
1543
+
1544
+ # Actionable verbs/phrases to validate generic keyword matches
1545
+ _action_verbs_re = re.compile(r"\b(akan|harus|siapkan|bikin|buat|selesaikan|dikerjakan|tolong|mohon|harap)\b", flags=re.IGNORECASE)
1546
+
1547
+ for seg in segments:
1548
+ if not seg.text:
1549
+ continue
1550
+
1551
+ text = re.sub(r"\[OVERLAP\]|\[NOISE\]|<.*?>", "", seg.text).strip()
1552
+ text_lower = text.lower()
1553
+
1554
+ # 1) explicit commitment patterns
1555
+ if commit_re.search(text_lower):
1556
+ # Try to extract short actionable clause
1557
+ task = re.sub(
1558
+ r"^.*?(bertanggung jawab|akan|membuat|siapkan|tolong|saya akan|aku akan|kamu tolong)\b",
1559
+ "",
1560
+ text,
1561
+ flags=re.IGNORECASE,
1562
+ )
1563
+ task = task.strip(" .,:;-")
1564
+ if not task:
1565
+ # fallback to whole segment
1566
+ task = text
1567
+
1568
+ # Try to detect explicit owner name within the clause (e.g., "Budi akan ...")
1569
+ owner = self._extract_name_as_owner(text) or seg.speaker_id
1570
+
1571
+ task_key = task.lower()[:120]
1572
+ if task_key not in seen_tasks:
1573
+ seen_tasks.add(task_key)
1574
+ action_items.append(
1575
+ {
1576
+ "owner": owner,
1577
+ "task": task,
1578
+ "timestamp": f"{seg.start:.1f}s",
1579
+ "due": self._detect_due_from_text(task),
1580
+ }
1581
+ )
1582
+ continue
1583
+
1584
+ # 2) keyword-based detection
1585
+ if any(kw in text_lower for kw in self.config.action_keywords):
1586
+ # Validate that the segment is actionable (has verbs like 'akan'/'perlu' or explicit name)
1587
+ if not self._is_actionable_text(text):
1588
+ continue
1589
+
1590
+ owner = self._extract_name_as_owner(text) or seg.speaker_id
1591
+ task = text.strip()
1592
+ task_key = task.lower()[:120]
1593
+ if task_key in seen_tasks:
1594
+ continue
1595
+ seen_tasks.add(task_key)
1596
+ action_items.append(
1597
+ {
1598
+ "owner": owner,
1599
+ "task": task,
1600
+ "timestamp": f"{seg.start:.1f}s",
1601
+ "due": self._detect_due_from_text(task),
1602
+ }
1603
+ )
1604
+
1605
+ # Post-process: deduplicate semantically and filter tiny filler tasks
1606
+ processed: List[Dict[str, str]] = []
1607
+ seen_norms = set()
1608
+
1609
+ # Filter out filler / non-actionable phrases (e.g., meeting start/thanks)
1610
+ filler_patterns = [
1611
+ r"\bkita mulai rapat",
1612
+ r"\bitu yang mau kita bahas",
1613
+ r"\bterima kasih",
1614
+ r"\bok(e|ey)?\b",
1615
+ r"\bsip\b",
1616
+ r"\bcukup(kan)? sampai",
1617
+ r"\btidak ada( yang)?\b",
1618
+ r"\biya\b",
1619
+ r"\bsetuju\b",
1620
+ ]
1621
+ filler_re = re.compile("|".join(filler_patterns), flags=re.IGNORECASE)
1622
+
1623
+ for it in action_items:
1624
+ task_text = it.get("task", "")
1625
+
1626
+ # Skip common non-actionable conversational lines
1627
+ if filler_re.search(task_text):
1628
+ continue
1629
+
1630
+ # Ensure the sentence is actionable (has a commitment verb or explicit owner/name)
1631
+ if not self._is_actionable_text(task_text):
1632
+ continue
1633
+
1634
+ norm = self._normalize_text_for_dedup(task_text)[:200]
1635
+ # skip if too short
1636
+ if len(task_text.split()) < 3:
1637
+ continue
1638
+ if norm in seen_norms:
1639
+ continue
1640
+ seen_norms.add(norm)
1641
+ processed.append(it)
1642
+
1643
+ # Limit number of action items
1644
+ return processed[:15]
1645
+
1646
+ def _detect_due_from_text(self, text: str) -> str:
1647
+ """Detect simple due-date hints from text and return a short normalized due string."""
1648
+ t = text.lower()
1649
+ if "besok" in t:
1650
+ return "besok"
1651
+ if "segera" in t or "secepat" in t or "sekarang" in t:
1652
+ return "segera"
1653
+ if "minggu depan" in t:
1654
+ return "1 minggu"
1655
+ m = re.search(r"(\d+)\s*minggu", t)
1656
+ if m:
1657
+ return f"{m.group(1)} minggu"
1658
+ if "2 minggu" in t or "dua minggu" in t:
1659
+ return "2 minggu"
1660
+ if "deadline" in t:
1661
+ # try to capture a following date/token
1662
+ m2 = re.search(r"deadline\s*[:\-\s]*([\w\-\./]+)", t)
1663
+ return m2.group(1) if m2 else "TBD"
1664
+ return ""
1665
+
1666
+ def _extract_name_as_owner(self, text: str) -> Optional[str]:
1667
+ """Return a candidate owner name if a capitalized proper name is explicitly present in the clause.
1668
+
1669
+ Simple heuristic: look for capitalized words (not at sentence start if it's a pronoun) followed by 'akan' or similar.
1670
+ """
1671
+ m = re.search(r"\b([A-Z][a-z]{2,})\b(?=\s+akan|\s+siapkan|\s+tolong|\s+bisa|\s+bertanggung)", text)
1672
+ if m:
1673
+ return m.group(1)
1674
+ return None
1675
+
1676
+ def _is_actionable_text(self, text: str) -> bool:
1677
+ """Return True if text contains indicators of an actionable commitment.
1678
+
1679
+ Indicators:
1680
+ - Commitment verbs (akan, harus, perlu, siapkan, etc.)
1681
+ - Explicit owner mention (capitalized name)
1682
+ - Time indicators / deadlines (besok, minggu depan, deadline)
1683
+ """
1684
+ t = text or ""
1685
+ tl = t.lower()
1686
+ if re.search(r"\b(akan|harus|siapkan|bikin|buat|selesaikan|dikerjakan|tolong|mohon|harap|perlu)\b", tl):
1687
+ return True
1688
+ # Only consider capitalized names as indicators if followed by an action verb
1689
+ if re.search(r"\b([A-Z][a-z]{2,})\b(?=\s+(akan|siapkan|tolong|mohon|harus|selesaikan|buat|bikin))", t):
1690
+ return True
1691
+ if any(k in tl for k in ("deadline", "minggu depan", "besok")):
1692
+ return True
1693
+ return False
1694
+
1695
+ def _extract_topics(self, text: str, num_topics: int = 5) -> List[str]:
1696
+ """Extract main topics from text using simple frequency analysis"""
1697
+ # Simple word frequency approach
1698
+ # Remove common Indonesian stopwords
1699
+ stopwords = {
1700
+ "yang",
1701
+ "dan",
1702
+ "di",
1703
+ "ke",
1704
+ "dari",
1705
+ "ini",
1706
+ "itu",
1707
+ "dengan",
1708
+ "untuk",
1709
+ "pada",
1710
+ "adalah",
1711
+ "dalam",
1712
+ "tidak",
1713
+ "akan",
1714
+ "sudah",
1715
+ "juga",
1716
+ "saya",
1717
+ "kita",
1718
+ "kami",
1719
+ "mereka",
1720
+ "ada",
1721
+ "bisa",
1722
+ "atau",
1723
+ "seperti",
1724
+ "jadi",
1725
+ "kalau",
1726
+ "karena",
1727
+ "tapi",
1728
+ "ya",
1729
+ "apa",
1730
+ "bagaimana",
1731
+ "kenapa",
1732
+ "siapa",
1733
+ "kapan",
1734
+ "dimana",
1735
+ "nya",
1736
+ "kan",
1737
+ "dong",
1738
+ "sih",
1739
+ "kok",
1740
+ "deh",
1741
+ "loh",
1742
+ "lah",
1743
+ }
1744
+
1745
+ # Tokenize and count
1746
+ words = re.findall(r"\b[a-zA-Z]{4,}\b", text.lower())
1747
+ word_counts = {}
1748
+
1749
+ for word in words:
1750
+ if word not in stopwords:
1751
+ word_counts[word] = word_counts.get(word, 0) + 1
1752
+
1753
+ # Sort by frequency
1754
+ sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
1755
+
1756
+ # Return top topics
1757
+ return [word for word, count in sorted_words[:num_topics]]
1758
+
1759
+ def summarize_by_speaker(self, segments: List[TranscriptSegment]) -> Dict[str, str]:
1760
+ """Generate per-speaker summary"""
1761
+ # Group segments by speaker
1762
+ speaker_texts = {}
1763
+
1764
+ for seg in segments:
1765
+ if seg.speaker_id not in speaker_texts:
1766
+ speaker_texts[seg.speaker_id] = []
1767
+ speaker_texts[seg.speaker_id].append(seg.text)
1768
+
1769
+ # Summarize each speaker's contribution
1770
+ speaker_summaries = {}
1771
+
1772
+ for speaker_id, texts in speaker_texts.items():
1773
+ full_text = " ".join(texts)
1774
+ sentences = self._split_sentences(full_text)
1775
+
1776
+ if sentences:
1777
+ # Get top 2 sentences for each speaker
1778
+ key_sentences = self._extract_key_sentences(sentences)[:2]
1779
+ speaker_summaries[speaker_id] = " ".join(key_sentences)
1780
+ else:
1781
+ speaker_summaries[speaker_id] = "Tidak ada kontribusi yang dapat diringkas."
1782
+
1783
+ return speaker_summaries
src/transcriber.py ADDED
@@ -0,0 +1,1108 @@
1
+ """
2
+ ASR Transcription Module
3
+ ========================
4
+ Implements speech-to-text with configurable backends (Whisper, Wav2Vec2).
5
+ Default is Whisper-small for multilingual support; supports beam CTC decoding for CTC models.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import re
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Dict, List, Optional
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ from src.diarization import SpeakerSegment
20
+ from src.utils import setup_logger
21
+
22
+
23
+ @dataclass
24
+ class ASRConfig:
25
+ """Configuration for ASR"""
26
+
27
+ model_id: str = "openai/whisper-small"
28
+ chunk_length_s: float = 30.0
29
+ stride_length_s: float = 5.0
30
+ batch_size: int = 4
31
+ return_timestamps: Optional[str] = None # None or 'char'/'word'
32
+
33
+ # Approximate Continuous Speech Tokenizer token rate in Hz (e.g., 7.5). When set,
34
+ # the transcriber will apply a fast lossy compression preprocessor for speed.
35
+ # Default: disabled (None). Use --cst-hz to enable.
36
+ cst_hz: Optional[float] = None
37
+
38
+ # Backend options:
39
+ # - 'whisper': HuggingFace transformers ASR pipeline (seq2seq whisper)
40
+ # - 'transformers': HuggingFace transformers ASR pipeline (CTC wav2vec2, etc)
41
+ # - 'whisperx': WhisperX (faster-whisper + optional alignment; we use transcription + segments)
42
+ # - 'speechbrain': SpeechBrain adapter
43
+ backend: str = "whisper"
44
+
45
+ # Preferred language for whisper (use 'id' for Indonesian)
46
+ language: str = "id"
47
+
48
+ # WhisperX options
49
+ # compute_type common values: "float16" (GPU), "int8" / "int8_float16" (lower VRAM)
50
+ whisperx_compute_type: str = "auto"
51
+ whisperx_vad_filter: bool = True
52
+
53
+ # Use full-audio ASR and align timestamps to diarization segments if available
54
+ use_full_audio_for_segments: bool = False
55
+
56
+ # Quick mode (single-pass full audio + reduced precision) and parallelism
57
+ quick_mode: bool = False
58
+ parallel_workers: int = 4
59
+
60
+ # When not using full-audio timestamps, include a small context window around short segments
61
+ context_window_s: float = 0.5
62
+
63
+ # Decoder options: 'greedy' or 'beam' (beam can use pyctcdecode + kenlm)
64
+ decoder: str = "greedy"
65
+ beam_width: int = 10
66
+ use_lm: bool = False
67
+ lm_path: Optional[str] = None
68
+
69
+ # Text post-processing
70
+ capitalize_sentences: bool = True
71
+ normalize_whitespace: bool = True
72
+ add_punctuation: bool = False
73
+
74
+ # Device
75
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
76
+
77
+
78
+ @dataclass
79
+ class TranscriptSegment:
80
+ """Transcript segment with speaker and timing information"""
81
+
82
+ speaker_id: str
83
+ start: float
84
+ end: float
85
+ text: str
86
+ confidence: float = 1.0
87
+ is_overlap: bool = False
88
+ language: str = "id"
89
+ metadata: Dict[str, Any] = field(default_factory=dict)
90
+
91
+ @property
92
+ def duration(self) -> float:
93
+ """Get segment duration in seconds"""
94
+ return self.end - self.start
95
+
96
+ @property
97
+ def word_count(self) -> int:
98
+ """Get number of words in text"""
99
+ return len(self.text.split()) if self.text else 0
100
+
101
+ def to_dict(self) -> Dict[str, Any]:
102
+ """Convert to dictionary"""
103
+ return {
104
+ "speaker_id": self.speaker_id,
105
+ "start": self.start,
106
+ "end": self.end,
107
+ "text": self.text,
108
+ "confidence": self.confidence,
109
+ "is_overlap": self.is_overlap,
110
+ "duration": self.duration,
111
+ "word_count": self.word_count,
112
+ }
113
+
114
+
115
+ class ASRTranscriber:
116
+ """
117
+ Automatic Speech Recognition using Wav2Vec2-XLSR. Supports multiple backends including
118
+ HuggingFace `transformers` pipeline and optional SpeechBrain adapter.
119
+
120
+ Transcribes audio segments with speaker information.
121
+ Optimized for Indonesian language with code-switching support.
122
+
123
+ Attributes:
124
+ config: ASRConfig object
125
+
126
+ Example:
127
+ >>> transcriber = ASRTranscriber()
128
+ >>> segments = transcriber.transcribe_segments(waveform, diarization_segments)
129
+ >>> for seg in segments:
130
+ ... print(f"{seg.speaker_id}: {seg.text}")
131
+ """
132
+
133
+ def __init__(self, config: Optional[ASRConfig] = None, models_dir: str = "./models"):
134
+ """
135
+ Initialize ASRTranscriber.
136
+
137
+ Args:
138
+ config: ASRConfig object
139
+ models_dir: Directory to cache downloaded models
140
+ """
141
+ self.config = config or ASRConfig()
142
+ self.models_dir = Path(models_dir)
143
+ self.models_dir.mkdir(parents=True, exist_ok=True)
144
+
145
+ self.device = self.config.device
146
+
147
+ # Setup logger
148
+ self.logger = setup_logger("ASRTranscriber")
149
+ # Log configured CST value for diagnostics
150
+ try:
151
+ self.logger.info(f"ASRTranscriber configured cst_hz: {getattr(self.config, 'cst_hz', None)} Hz")
152
+ except Exception:
153
+ pass
154
+
155
+ # Model placeholders (lazy loading)
156
+ self._pipeline = None
157
+ self._processor = None
158
+ self._model = None
159
+ self._speechbrain_adapter = None
160
+ self._whisperx_model = None
161
+
162
+ def _load_model(self):
163
+ """Lazy load ASR model and pipeline"""
164
+ # If user configured SpeechBrain backend, prefer it
165
+ if getattr(self.config, "backend", "whisper") == "speechbrain":
166
+ if self._speechbrain_adapter is None:
167
+ try:
168
+ from .transcriber_speechbrain import (
169
+ SpeechBrainASRConfig,
170
+ SpeechBrainTranscriber,
171
+ )
172
+
173
+ sb_cfg = SpeechBrainASRConfig(model_id=self.config.model_id, device=self.device)
174
+ self._speechbrain_adapter = SpeechBrainTranscriber(
175
+ sb_cfg, models_dir=str(self.models_dir)
176
+ )
177
+ self.logger.info(
178
+ f"SpeechBrain adapter initialized with model: {self.config.model_id}"
179
+ )
180
+ except Exception as e:
181
+ self.logger.warning(f"Could not initialize SpeechBrain adapter: {e}")
182
+ self._speechbrain_adapter = None
183
+ return
184
+
185
+ # WhisperX backend
186
+ if getattr(self.config, "backend", None) == "whisperx":
187
+ if self._whisperx_model is None:
188
+ try:
189
+ # WhisperX imports torchaudio.AudioMetaData (not present in some builds, e.g., torchaudio 2.8 CPU on Windows)
190
+ import torchaudio
191
+
192
+ if not hasattr(torchaudio, "AudioMetaData"):
193
+ from typing import NamedTuple
194
+
195
+ class AudioMetaData(NamedTuple):
196
+ sample_rate: int
197
+ num_frames: int
198
+ num_channels: int
199
+ bits_per_sample: int = 16
200
+ encoding: str = "PCM_S"
201
+
202
+ # Provide stub to satisfy downstream imports; uses safe defaults
203
+ torchaudio.AudioMetaData = AudioMetaData # type: ignore
204
+
205
+ import whisperx # type: ignore
206
+
207
+ # Allowlist OmegaConf ListConfig for torch.load (needed since PyTorch 2.6 weights_only=True)
208
+ try:
209
+ import typing
210
+
211
+ import torch.serialization as ts
212
+ from omegaconf.base import ContainerMetadata # type: ignore
213
+ from omegaconf.listconfig import ListConfig # type: ignore
214
+
215
+ # Allow torch.load with weights_only=True to unpickle HF configs that store plain list
216
+ # Allowlist common builtin types and container types used inside HF checkpoints
217
+ ts.add_safe_globals([dict, list, int, float, str, tuple, set])
218
+
219
+ # Add collections.defaultdict (needed by some HF checkpoints under newer PyTorch)
220
+ import collections
221
+
222
+ ts.add_safe_globals([collections.defaultdict])
223
+
224
+ # Ensure OmegaConf ListConfig is allowlisted (common in HF configs)
225
+ ts.add_safe_globals([ListConfig])
226
+
227
+ # Allow AnyNode from OmegaConf which some HF configs embed
228
+ try:
229
+ from omegaconf.nodes import AnyNode # type: ignore
230
+
231
+ ts.add_safe_globals([AnyNode])
232
+ except Exception:
233
+ # Not strictly fatal; continue if import fails
234
+ pass
235
+
236
+ # Some checkpoints include TorchVersion objects
237
+ try:
238
+ import torch
239
+
240
+ ts.add_safe_globals([torch.torch_version.TorchVersion])
241
+ except Exception:
242
+ pass
243
+
244
+ # Add ContainerMetadata and Metadata from OmegaConf if present
245
+ try:
246
+ from omegaconf.base import Metadata # type: ignore
247
+
248
+ ts.add_safe_globals([ContainerMetadata, Metadata, typing.Any])
249
+ except Exception:
250
+ ts.add_safe_globals([ContainerMetadata, typing.Any])
251
+ except Exception as e:
252
+ self.logger.warning(f"Could not add ListConfig to torch safe globals: {e}")
253
+
254
+ model_name_or_path = self.config.model_id
255
+ p = Path(str(model_name_or_path))
256
+ if p.exists() and p.is_dir():
257
+ # WhisperX (faster-whisper / CTranslate2) expects a CT2-converted model directory
258
+ # containing model.bin + config files. A folder with only *.safetensors is a
259
+ # HuggingFace Transformers checkpoint and cannot be loaded directly by WhisperX.
260
+ has_model_bin = (p / "model.bin").exists()
261
+ has_safetensors = any(p.glob("*.safetensors"))
262
+ if not has_model_bin and has_safetensors:
263
+ raise RuntimeError(
264
+ "WhisperX backend membutuhkan model format CTranslate2 (ada file 'model.bin'). "
265
+ f"Folder '{p.as_posix()}' hanya berisi *.safetensors (format Transformers), jadi "
266
+ "tidak bisa dipakai langsung oleh WhisperX. "
267
+ "Solusi: pakai nama model WhisperX seperti 'large-v3-turbo' agar auto-download, "
268
+ "atau convert model Transformers -> CTranslate2 memakai ctranslate2 converter."
269
+ )
270
+
271
+ compute_type = getattr(self.config, "whisperx_compute_type", "auto")
272
+ if compute_type == "auto":
273
+ # Sensible default: float16 on CUDA, int8 on CPU
274
+ compute_type = "float16" if self.device == "cuda" else "int8"
275
+
276
+ # WhisperX uses faster-whisper under the hood; model can be a name ("large-v3", "large-v3-turbo")
277
+ # or a local directory containing model weights (e.g. safetensors).
278
+ self.logger.info(
279
+ f"Loading WhisperX model: {model_name_or_path} (device={self.device}, compute_type={compute_type})"
280
+ )
281
+
282
+ # Robust loading: try to parse WeightsUnpickler errors and auto-allowlist missing globals
283
+ def _load_model_with_retry():
284
+ import importlib
285
+ import re
286
+
287
+ import torch.serialization as ts
288
+
289
+ max_attempts = 8
290
+ attempt = 0
291
+ while True:
292
+ try:
293
+ return whisperx.load_model(
294
+ model_name_or_path,
295
+ device=self.device,
296
+ compute_type=compute_type,
297
+ download_root=str(self.models_dir),
298
+ )
299
+ except Exception as e:
300
+ attempt += 1
301
+ if attempt >= max_attempts:
302
+ # Give up and re-raise the original exception
303
+ raise
304
+ msg = str(e)
305
+ # Find module.Class patterns in the error message
306
+ missing = set(
307
+ re.findall(
308
+ r"GLOBAL\s+([\w\.]+)\s+was not an allowed global", msg
309
+ )
310
+ )
311
+ # Also catch suggestions in the message
312
+ more = set(re.findall(r"add_safe_globals\(\[([^\]]+)\]\)", msg))
313
+ for m in more:
314
+ # split comma-separated list like 'collections.defaultdict' or 'omegaconf.nodes.AnyNode'
315
+ parts = [
316
+ p.strip().strip("\"''") for p in m.split(",") if p.strip()
317
+ ]
318
+ missing.update(parts)
319
+
320
+ if not missing:
321
+ # nothing we can do programmatically
322
+ raise
323
+
324
+ for fullname in missing:
325
+ try:
326
+ module_name, cls_name = fullname.rsplit(".", 1)
327
+ mod = importlib.import_module(module_name)
328
+ cls = getattr(mod, cls_name)
329
+ ts.add_safe_globals([cls])
330
+ self.logger.info(
331
+ f"Auto-added {fullname} to torch safe globals"
332
+ )
333
+ except Exception as ie:
334
+ self.logger.warning(
335
+ f"Could not auto-add {fullname} to safe globals: {ie}"
336
+ )
337
+ # retry loop
338
+
339
+ self._whisperx_model = _load_model_with_retry()
340
+ self.logger.info("WhisperX model loaded successfully")
341
+ except Exception as e:
342
+ # When user explicitly requests WhisperX backend, fail loudly with a helpful message.
343
+ self._whisperx_model = None
344
+ raise RuntimeError(f"Failed to load WhisperX model: {e}") from e
345
+
346
+ if self._pipeline is None:
347
+ # If user explicitly selected WhisperX and the WhisperX model loaded OK,
348
+ # prefer WhisperX and skip attempting the Transformers pipeline which may
349
+ # not recognize model names like 'large-v3-turbo' and produce confusing errors.
350
+ if (
351
+ getattr(self.config, "backend", None) == "whisperx"
352
+ and self._whisperx_model is not None
353
+ ):
354
+ self._pipeline = "WHISPERX"
355
+ self.logger.info("WhisperX backend active; skipping Transformers pipeline load")
356
+ else:
357
+ try:
358
+ from transformers import pipeline
359
+
360
+ self.logger.info(f"Loading model: {self.config.model_id}")
361
+
362
+ # Try to use pipeline first (simpler)
363
+ self._pipeline = pipeline(
364
+ "automatic-speech-recognition",
365
+ model=self.config.model_id,
366
+ device=0 if self.device == "cuda" and torch.cuda.is_available() else -1,
367
+ chunk_length_s=self.config.chunk_length_s,
368
+ stride_length_s=(self.config.stride_length_s, self.config.stride_length_s),
369
+ )
370
+
371
+ self.logger.info("Model loaded successfully via pipeline")
372
+
373
+ except Exception as e:
374
+ self.logger.warning(f"Pipeline loading failed: {e}")
375
+ self.logger.info("Attempting direct model loading...")
376
+
377
+ # Attempt direct transformers model loading (Wav2Vec2)
378
+ try:
379
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
380
+
381
+ self._processor = Wav2Vec2Processor.from_pretrained(
382
+ self.config.model_id, cache_dir=str(self.models_dir)
383
+ )
384
+ self._model = Wav2Vec2ForCTC.from_pretrained(
385
+ self.config.model_id, cache_dir=str(self.models_dir)
386
+ )
387
+
388
+ if self.device == "cuda" and torch.cuda.is_available():
389
+ self._model = self._model.cuda()
390
+
391
+ self._model.eval()
392
+ self.logger.info("Model loaded successfully via direct loading")
393
+
394
+ # If user requested beam decoding, try to prepare a CTC beam decoder (pyctcdecode)
395
+ self._ctc_decoder = None
396
+ try:
397
+ if self.config.decoder == "beam":
398
+ from pyctcdecode import build_ctcdecoder
399
+
400
+ # Build label list from tokenizer vocab ordered by id
401
+ vocab = self._processor.tokenizer.get_vocab()
402
+ labels = [t for t, _ in sorted(vocab.items(), key=lambda x: x[1])]
403
+
404
+ if self.config.use_lm and self.config.lm_path:
405
+ self.logger.info("Building CTC decoder with LM...")
406
+ self._ctc_decoder = build_ctcdecoder(
407
+ labels, self.config.lm_path
408
+ )
409
+ else:
410
+ self.logger.info("Building CTC decoder (no LM)")
411
+ self._ctc_decoder = build_ctcdecoder(labels)
412
+
413
+ self.logger.info("CTC decoder ready")
414
+ except Exception as e:
415
+ self.logger.warning(
416
+ f"Could not build CTC decoder (pyctcdecode/kenlm missing or failed): {e}"
417
+ )
418
+ self._ctc_decoder = None
419
+
420
+ except Exception as e2:
421
+ self.logger.error(f"Direct loading also failed: {e2}")
422
+ self.logger.warning("Using fallback placeholder mode")
423
+ self._pipeline = "FALLBACK"
424
+
425
+ def transcribe_segments(
426
+ self,
427
+ waveform: torch.Tensor,
428
+ segments: List[SpeakerSegment],
429
+ sample_rate: int = 16000,
430
+ progress_callback: Optional[Callable[[int, int], None]] = None,
431
+ ) -> List[TranscriptSegment]:
432
+ """
433
+ Transcribe each speaker segment. If `use_full_audio_for_segments` is enabled,
434
+ run ASR once on the full audio and map word/segment timestamps back to
435
+ the diarization segments when the ASR pipeline returns timestamps.
436
+ Falls back to context-augmented per-segment transcription when timestamps
437
+ are not available.
438
+ """
439
+ try:
440
+ self._load_model()
441
+ except Exception as e:
442
+ # If loading the configured ASR backend fails (common when deployment preset
443
+ # forced WhisperX but model_id is a Transformers repo), attempt a safe
444
+ # runtime fallback to a lightweight Whisper model so interactive UI flows
445
+ # remain responsive instead of crashing.
446
+ self.logger.error(
447
+ f"ASR model load failed: {e}. Attempting fallback to 'whisper' backend with 'openai/whisper-small'."
448
+ )
449
+ try:
450
+ self.config.backend = "whisper"
451
+ self.config.model_id = "openai/whisper-small"
452
+ # Clear any partially-initialized model state
453
+ self._pipeline = None
454
+ self._model = None
455
+ self._processor = None
456
+ self._whisperx_model = None
457
+ self._load_model()
458
+ self.logger.info("Fallback ASR model loaded successfully (openai/whisper-small)")
459
+ except Exception as e2:
460
+ self.logger.error(f"Fallback ASR model load also failed: {e2}")
461
+ # Re-raise to let caller handle/report the error
462
+ raise
463
+
464
+ # If SpeechBrain backend adapter is configured, delegate to it
465
+ if (
466
+ getattr(self.config, "backend", None) == "speechbrain"
467
+ and getattr(self, "_speechbrain_adapter", None) is not None
468
+ ):
469
+ try:
470
+ sb_res = self._speechbrain_adapter.transcribe_segments(
471
+ waveform, segments, sample_rate
472
+ )
473
+ for s in sb_res:
474
+ s.text = self._postprocess_text(s.text)
475
+ return sb_res
476
+ except Exception as e:
477
+ self.logger.error(f"SpeechBrain adapter transcription failed: {e}")
478
+
479
+ transcripts = []
480
+ total_segments = len(segments)
481
+
482
+ # If using full-audio mapping, run pipeline once on entire audio and try to align
483
+ full_asr_result = None
484
+ audio_np_full = waveform.squeeze().cpu().numpy()
485
+
486
+ if self.config.use_full_audio_for_segments:
487
+ # If SpeechBrain backend is used, ask adapter to produce full transcription
488
+ if (
489
+ getattr(self.config, "backend", "whisper") == "speechbrain"
490
+ and self._speechbrain_adapter is not None
491
+ ):
492
+ try:
493
+ self.logger.info(
494
+ "Running full-audio ASR via SpeechBrain adapter for alignment to segments"
495
+ )
496
+ full_text = self._speechbrain_adapter.transcribe_full_audio(
497
+ waveform, sample_rate
498
+ )
499
+ # SpeechBrain adapter currently returns plain text; we can't map timestamps, so store as simple str
500
+ full_asr_result = {"text": full_text}
501
+ except Exception as e:
502
+ self.logger.error(f"SpeechBrain full-audio ASR failed: {e}")
503
+ full_asr_result = None
504
+
505
+ elif self._pipeline not in (None, "FALLBACK"):
506
+ try:
507
+ # Whisper (seq2seq) pipelines don't accept 'sampling_rate' kwarg; omit it and set language
508
+ if getattr(self.config, "backend", "transformers") == "whisper":
509
+ kwargs = {}
510
+ # prefer explicit language if configured (e.g., Indonesian 'id')
511
+ kwargs["language"] = self.config.language
512
+ else:
513
+ kwargs = {"sampling_rate": sample_rate}
514
+
515
+ rt = self.config.return_timestamps
516
+ if rt in ("char", "word"):
517
+ kwargs["return_timestamps"] = rt
518
+
519
+ self.logger.info("Running full-audio ASR for alignment to segments")
520
+ full_asr_result = self._pipeline(audio_np_full, **kwargs)
521
+ except Exception as e:
522
+ self.logger.error(f"Full-audio ASR failed: {e}")
523
+ full_asr_result = None
524
+
525
+ # Build list of segment tasks that need per-segment ASR
526
+ tasks = []
527
+ for idx, seg in enumerate(segments):
528
+ # Skip very short segments
529
+ duration = seg.end - seg.start
530
+ if duration < 0.3:
531
+ continue
532
+ tasks.append((idx, seg))
533
+
534
+ # If we have a full-audio ASR result that includes timestamps, map once and avoid per-segment ASR
535
+ if full_asr_result is not None:
536
+ for idx, seg in tasks:
537
+ text = self._map_full_asr_to_segment(full_asr_result, seg)
538
+ if text:
539
+ text = self._postprocess_text(text)
540
+ if text:
541
+ transcripts.append(
542
+ TranscriptSegment(
543
+ speaker_id=seg.speaker_id,
544
+ start=seg.start,
545
+ end=seg.end,
546
+ text=text,
547
+ confidence=seg.confidence,
548
+ is_overlap=seg.is_overlap,
549
+ metadata={
550
+ "embedding": (
551
+ seg.embedding if hasattr(seg, "embedding") else None
552
+ ),
553
+ "asr_model": self.config.model_id,
554
+ },
555
+ )
556
+ )
557
+ # Filter out tasks that were handled by mapping
558
+ tasks = [
559
+ (i, s)
560
+ for (i, s) in tasks
561
+ if not any(t.start == s.start and t.end == s.end for t in transcripts)
562
+ ]
563
+
564
+ # If parallel_workers > 1, run per-segment ASR concurrently; otherwise fall back to the serial loop below
565
+ workers = int(getattr(self.config, "parallel_workers", 1))
566
+ if workers > 1 and tasks:
567
+ import concurrent.futures
568
+
569
+ def _transcribe_task(item):
570
+ idx, seg = item
571
+ # Progress updates, if any, are handled by the caller; errors are logged when results are collected
572
+ # Use context window if available
573
+ if self.config.context_window_s and self._pipeline not in (None, "FALLBACK"):
574
+ ctx_start = max(0.0, seg.start - self.config.context_window_s)
575
+ ctx_end = seg.end + self.config.context_window_s
576
+ cs = int(ctx_start * sample_rate)
577
+ ce = int(min(ctx_end * sample_rate, waveform.shape[-1]))
578
+ audio_np = waveform[:, cs:ce].squeeze().cpu().numpy()
579
+ text = self._transcribe_audio(
580
+ torch.from_numpy(audio_np).unsqueeze(0), sample_rate
581
+ )
582
+ else:
583
+ start_sample = int(seg.start * sample_rate)
584
+ end_sample = int(seg.end * sample_rate)
585
+ audio_segment = waveform[:, start_sample:end_sample]
586
+ text = self._transcribe_audio(audio_segment, sample_rate)
587
+
588
+ text = self._postprocess_text(text)
589
+ return idx, seg, text
590
+
591
+ with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex:
592
+ futures = {ex.submit(_transcribe_task, t): t for t in tasks}
593
+ for fut in concurrent.futures.as_completed(futures):
594
+ try:
595
+ idx, seg, text = fut.result()
596
+ if not text or not text.strip():
597
+ continue
598
+ transcripts.append(
599
+ TranscriptSegment(
600
+ speaker_id=seg.speaker_id,
601
+ start=seg.start,
602
+ end=seg.end,
603
+ text=text,
604
+ confidence=seg.confidence,
605
+ is_overlap=seg.is_overlap,
606
+ metadata={
607
+ "embedding": (
608
+ seg.embedding if hasattr(seg, "embedding") else None
609
+ ),
610
+ "asr_model": self.config.model_id,
611
+ },
612
+ )
613
+ )
614
+ except Exception as e:
615
+ self.logger.error(f"Segment transcription failed: {e}")
616
+ else:
617
+ # Serial fallback
618
+ for idx, seg in tasks:
619
+ # create context window
620
+ if self.config.context_window_s and self._pipeline not in (None, "FALLBACK"):
621
+ ctx_start = max(0.0, seg.start - self.config.context_window_s)
622
+ ctx_end = seg.end + self.config.context_window_s
623
+ cs = int(ctx_start * sample_rate)
624
+ ce = int(min(ctx_end * sample_rate, waveform.shape[-1]))
625
+ audio_np = waveform[:, cs:ce].squeeze().cpu().numpy()
626
+ text = self._transcribe_audio(
627
+ torch.from_numpy(audio_np).unsqueeze(0), sample_rate
628
+ )
629
+ else:
630
+ start_sample = int(seg.start * sample_rate)
631
+ end_sample = int(seg.end * sample_rate)
632
+ audio_segment = waveform[:, start_sample:end_sample]
633
+ text = self._transcribe_audio(audio_segment, sample_rate)
634
+
635
+ # Post-process text
636
+ text = self._postprocess_text(text)
637
+
638
+ # Skip empty transcriptions
639
+ if not text or not text.strip():
640
+ continue
641
+
642
+ transcripts.append(
643
+ TranscriptSegment(
644
+ speaker_id=seg.speaker_id,
645
+ start=seg.start,
646
+ end=seg.end,
647
+ text=text,
648
+ confidence=seg.confidence,
649
+ is_overlap=seg.is_overlap,
650
+ metadata={
651
+ "embedding": seg.embedding if hasattr(seg, "embedding") else None,
652
+ "asr_model": self.config.model_id,
653
+ },
654
+ )
655
+ )
656
+
657
+ return transcripts
658
+
659
+ def _detect_language_from_text(self, text: str) -> Optional[str]:
660
+ """Detect top language code from text using langdetect. Returns ISO code or None."""
661
+ try:
662
+ from langdetect import detect_langs
663
+
664
+ if not text or not text.strip():
665
+ return None
666
+ probs = detect_langs(text)
667
+ if not probs:
668
+ return None
669
+ return probs[0].lang
670
+ except Exception:
671
+ return None
672
+
673
+ def _transcribe_audio(self, audio_segment: torch.Tensor, sample_rate: int) -> str:
674
+ """Transcribe a single audio segment
675
+
676
+ Supports `language='auto'` for the Whisper backend, which will perform a quick
677
+ pre-pass (no language hint) and use a text-based language detector to
678
+ choose the language for the final transcription pass.
679
+
680
+ If `self.config.cst_hz` is set, an aggressive lossy preprocessor (approximation
681
+ of a low-rate Continuous Speech Tokenizer) is applied before sending audio to
682
+ the ASR backend. This significantly reduces compute at the cost of precision
683
+ and should be used only when speed is critical.
684
+ """
685
+
686
+ # Fallback mode: only return placeholders when no working ASR backend is available.
687
+ # If user requested WhisperX backend and model is loaded, prefer using WhisperX.
688
+ if self._pipeline == "FALLBACK":
689
+ backend = getattr(self.config, "backend", None)
690
+ if not (backend == "whisperx" and self._whisperx_model is not None):
691
+ duration = audio_segment.shape[-1] / sample_rate
692
+ return f"[Transkripsi placeholder - durasi {duration:.1f}s]"
693
+
694
+ # Convert to numpy
695
+ audio_np = audio_segment.squeeze().cpu().numpy()
696
+
697
+ # Apply CST approximation preprocessor if requested (lossy, speed-optimized)
698
+ if getattr(self.config, "cst_hz", None) is not None:
699
+ try:
700
+ audio_np = self._apply_cst_approximation(audio_np, sample_rate, float(self.config.cst_hz))
701
+ # After approximation we keep the original sample_rate for downstream callers
702
+ self.logger.info(f"Applied CST approximation: {self.config.cst_hz} Hz (lossy)")
703
+ except Exception as e:
704
+ self.logger.warning(f"CST approximation failed, continuing with original audio: {e}")
705
+
706
+ # Ensure float32
707
+ if audio_np.dtype != np.float32:
708
+ audio_np = audio_np.astype(np.float32)
709
+
710
+
711
+ # WhisperX backend
712
+ if getattr(self.config, "backend", None) == "whisperx":
713
+ try:
714
+ if self._whisperx_model is None:
715
+ self._load_model()
716
+ if self._whisperx_model is None:
717
+ return ""
718
+
719
+ language = getattr(self.config, "language", "id")
720
+ # whisperx expects None for auto language
721
+ language_arg = None if language == "auto" else language
722
+
723
+ vad_filter = bool(getattr(self.config, "whisperx_vad_filter", True))
724
+
725
+ # Build kwargs and only pass vad_filter if the transcribe signature accepts it
726
+ from inspect import signature
727
+
728
+ kwargs = {"batch_size": self.config.batch_size}
729
+ if language_arg is not None:
730
+ kwargs["language"] = language_arg
731
+
732
+ try:
733
+ sig = signature(self._whisperx_model.transcribe)
734
+ if "vad_filter" in sig.parameters:
735
+ kwargs["vad_filter"] = vad_filter
736
+ except Exception:
737
+ # If introspection fails, do not pass vad_filter
738
+ pass
739
+
740
+ # First attempt
741
+ try:
742
+ result = self._whisperx_model.transcribe(audio_np, **kwargs)
743
+ except Exception as e_inner:
744
+ self.logger.warning(f"WhisperX transcription failed on first attempt: {e_inner}. Retrying with `vad_filter=False, batch_size=1`")
745
+ # retry with safer options
746
+ try:
747
+ retry_kwargs = kwargs.copy()
748
+ retry_kwargs["batch_size"] = 1
749
+ if "vad_filter" in retry_kwargs:
750
+ retry_kwargs["vad_filter"] = False
751
+ result = self._whisperx_model.transcribe(audio_np, **retry_kwargs)
752
+ except Exception as e_retry:
753
+ self.logger.error(f"WhisperX transcription retry failed: {e_retry}. Falling back to lightweight Whisper model.")
754
+ # Fallback: switch backend to 'whisper' with small model and attempt to load it
755
+ try:
756
+ self.config.backend = "whisper"
757
+ self.config.model_id = "openai/whisper-small"
758
+ # Clear whisperx state
759
+ self._whisperx_model = None
760
+ self._pipeline = None
761
+ self._model = None
762
+ self._processor = None
763
+ self._load_model()
764
+ # attempt pipeline-based transcription
765
+ return self._transcribe_audio(audio_segment, sample_rate)
766
+ except Exception as e_fb:
767
+ self.logger.error(f"Fallback ASR model load/transcription failed: {e_fb}")
768
+ return ""
769
+
770
+ # Normalize result into plain text.
771
+ if isinstance(result, dict):
772
+ # 'text' is common, but some ASR returns 'segments' list
773
+ if "text" in result and result.get("text"):
774
+ return result.get("text", "")
775
+ if "segments" in result and isinstance(result["segments"], list):
776
+ seg_texts = [
777
+ s.get("text", "") for s in result["segments"] if isinstance(s, dict)
778
+ ]
779
+ joined = " ".join(t.strip() for t in seg_texts if t and t.strip())
780
+ return joined or ""
781
+ # fallback to empty
782
+ return ""
783
+ return str(result)
784
+ except Exception as e:
785
+ self.logger.error(f"WhisperX transcription failed: {e}")
786
+ return ""
787
+ # Use pipeline if available
788
+ if self._pipeline is not None and self._pipeline != "FALLBACK":
789
+ try:
790
+ # Whisper backend: handle language auto-detection
791
+ if getattr(self.config, "backend", "transformers") == "whisper":
792
+ if getattr(self.config, "language", "id") == "auto":
793
+ # quick pre-pass to get candidate text
794
+ try:
795
+ quick_kwargs = {}
796
+ rt = self.config.return_timestamps
797
+ if rt in ("char", "word"):
798
+ quick_kwargs["return_timestamps"] = rt
799
+ quick_res = self._pipeline(audio_np, **quick_kwargs)
800
+ quick_text = (
801
+ quick_res.get("text", "")
802
+ if isinstance(quick_res, dict)
803
+ else str(quick_res)
804
+ )
805
+ detected = self._detect_language_from_text(quick_text)
806
+ chosen_lang = detected if detected else "id"
807
+ except Exception:
808
+ chosen_lang = "id"
809
+ else:
810
+ chosen_lang = getattr(self.config, "language", "id")
811
+
812
+ kwargs = {"language": chosen_lang}
813
+ else:
814
+ kwargs = {"sampling_rate": sample_rate}
815
+
816
+ rt = self.config.return_timestamps
817
+ if rt in ("char", "word"):
818
+ kwargs["return_timestamps"] = rt
819
+
820
+ result = self._pipeline(audio_np, **kwargs)
821
+
822
+ # If result is a dict with text
823
+ if isinstance(result, dict):
824
+ # If pipeline returns a list of word/segment timestamps, user may want that via full-audio flow
825
+ if isinstance(result.get("chunks", None), list) or isinstance(
826
+ result.get("segments", None), list
827
+ ):
828
+ return result.get("text", "")
829
+ return result.get("text", "")
830
+ return str(result)
831
+
832
+ except Exception as e:
833
+ self.logger.warning(f"Pipeline transcription failed: {e}")
834
+ # Try to fall back to direct model path (if available)
835
+ self._pipeline = None
836
+ # continue to attempt direct model below
837
+
838
+ # Use direct model if pipeline not available
839
+ if self._model is not None and self._processor is not None:
840
+ try:
841
+ # Process input
842
+ inputs = self._processor(
843
+ audio_np, sampling_rate=sample_rate, return_tensors="pt", padding=True
844
+ )
845
+
846
+ # Move to device
847
+ if self.device == "cuda" and torch.cuda.is_available():
848
+ inputs = {k: v.cuda() for k, v in inputs.items()}
849
+
850
+ # Run inference
851
+ with torch.no_grad():
852
+ logits = self._model(**inputs).logits
853
+
854
+ # If CTC beam decoder available and requested, use it
855
+ if (
856
+ getattr(self, "_ctc_decoder", None) is not None
857
+ and self.config.decoder == "beam"
858
+ ):
859
+ try:
860
+ # Convert logits to probabilities (T, C)
861
+ probs = torch.softmax(logits, dim=-1).cpu().numpy()
862
+ # some models return batch dimension; take first batch
863
+ emissions = probs[0]
864
+
865
+ try:
866
+ # Try simple decode
867
+ transcription = self._ctc_decoder.decode(
868
+ emissions, beam_width=self.config.beam_width
869
+ )
870
+ except Exception:
871
+ # Try beam candidates and pick top
872
+ beams = self._ctc_decoder.decode_beams(
873
+ emissions, beam_width=self.config.beam_width
874
+ )
875
+ transcription = beams[0][0] if beams else ""
876
+
877
+ return transcription if transcription else ""
878
+ except Exception as e:
879
+ self.logger.warning(f"CTC beam decode failed: {e}")
880
+ # fallback to greedy
881
+
882
+ # Fallback: greedy argmax decode
883
+ predicted_ids = torch.argmax(logits, dim=-1)
884
+ transcription = self._processor.batch_decode(predicted_ids)
885
+
886
+ return transcription[0] if transcription else ""
887
+
888
+ except Exception as e:
889
+ self.logger.error(f"Direct model transcription failed: {e}")
890
+ return ""
891
+
892
+ return ""
893
+
894
+ def transcribe_full_audio(self, waveform: torch.Tensor, sample_rate: int = 16000) -> str:
895
+ """
896
+ Transcribe full audio without diarization.
897
+ Useful for baseline comparison.
898
+ """
899
+ self._load_model()
900
+
901
+ # WhisperX: call directly to keep consistency
902
+ if getattr(self.config, "backend", None) == "whisperx":
903
+ audio_np = waveform.squeeze().cpu().numpy().astype(np.float32, copy=False)
904
+ if self._whisperx_model is None:
905
+ return ""
906
+ language = getattr(self.config, "language", "id")
907
+ language_arg = None if language == "auto" else language
908
+ vad_filter = bool(getattr(self.config, "whisperx_vad_filter", True))
909
+ try:
910
+ res = self._whisperx_model.transcribe(
911
+ audio_np,
912
+ batch_size=self.config.batch_size,
913
+ language=language_arg,
914
+ vad_filter=vad_filter,
915
+ )
916
+ text = res.get("text", "") if isinstance(res, dict) else str(res)
917
+ return self._postprocess_text(text)
918
+ except Exception as e:
919
+ self.logger.warning(f"WhisperX full-audio transcription failed: {e}. Retrying with vad_filter=False, batch_size=1")
920
+ try:
921
+ res = self._whisperx_model.transcribe(
922
+ audio_np,
923
+ batch_size=1,
924
+ language=language_arg,
925
+ vad_filter=False,
926
+ )
927
+ text = res.get("text", "") if isinstance(res, dict) else str(res)
928
+ return self._postprocess_text(text)
929
+ except Exception as e2:
930
+ self.logger.error(f"WhisperX full-audio retry failed: {e2}. Falling back to 'whisper-small'.")
931
+ # Fallback to whisper-small pipeline
932
+ try:
933
+ self.config.backend = "whisper"
934
+ self.config.model_id = "openai/whisper-small"
935
+ self._whisperx_model = None
936
+ self._pipeline = None
937
+ self._model = None
938
+ self._processor = None
939
+ self._load_model()
940
+ text = self._transcribe_audio(waveform, sample_rate)
941
+ return self._postprocess_text(text)
942
+ except Exception as e_fb:
943
+ self.logger.error(f"Fallback full-audio ASR failed: {e_fb}")
944
+ return ""
945
+
946
+ text = self._transcribe_audio(waveform, sample_rate)
947
+ return self._postprocess_text(text)
948
+
949
+ def _apply_cst_approximation(self, audio_np: np.ndarray, sample_rate: int, cst_hz: float) -> np.ndarray:
950
+ """Approximate a Continuous Speech Tokenizer by block-averaging audio frames
951
+
952
+ This method is intentionally conservative and reversible only in the sense
953
+ that it produces a downsample-like version of the waveform which is then
954
+ expanded back to the original rate (by repeating block values). This is
955
+ extremely lossy but can reduce model runtime for long audio when you
956
+ accept lower ASR fidelity.
957
+
958
+ Implementation details:
959
+ - token_duration = 1.0 / cst_hz
960
+ - compute mean amplitude per token window
961
+ - expand each token mean to the window length (constant value) to produce
962
+ a waveform of the original sample length
963
+
964
+ Note: This is an approximation to the user's requested ultralow-rate tokenizer
965
+ (7.5 Hz). For best accuracy, tune `cst_hz` and verify results on your data.
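+
+ Example (illustrative): with sample_rate=16000 and cst_hz=7.5, each token
+ window spans round(16000 / 7.5) = 2133 samples (~133 ms), so a 60-second clip
+ is reduced to roughly 450 block means before being expanded back to the
+ original length; the output array always has the same shape as the input.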
966
+ """
967
+ if cst_hz <= 0 or np.isnan(cst_hz):
968
+ return audio_np
969
+
970
+ token_dur = 1.0 / float(cst_hz)
971
+ window_samp = max(1, int(round(token_dur * sample_rate)))
972
+ # Partition audio and compute mean for each window
973
+ n = len(audio_np)
974
+ n_windows = int(np.ceil(n / window_samp))
975
+ means = []
976
+ for i in range(n_windows):
977
+ s = i * window_samp
978
+ e = min(n, s + window_samp)
979
+ if e <= s:
980
+ means.append(0.0)
981
+ else:
982
+ means.append(float(np.mean(audio_np[s:e])))
983
+
984
+ # Reconstruct waveform by repeating means per window
985
+ out = np.zeros(n, dtype=np.float32)
986
+ for i, m in enumerate(means):
987
+ s = i * window_samp
988
+ e = min(n, s + window_samp)
989
+ out[s:e] = m
990
+
991
+ return out
992
+
993
+ def _postprocess_text(self, text: str) -> str:
994
+ """Clean and format transcribed text"""
995
+ if not text:
996
+ return ""
997
+
998
+ # Basic cleaning
999
+ text = text.strip()
1000
+
1001
+ # Remove special tokens and math/code blocks bounded by $$...$$
1002
+ text = re.sub(r"<[^>]+>", "", text)
1003
+ text = re.sub(r"\$\$.*?\$\$", "", text, flags=re.DOTALL)
1004
+
1005
+ # Normalize whitespace
1006
+ if self.config.normalize_whitespace:
1007
+ text = " ".join(text.split())
1008
+
1009
+ # Capitalize first letter of sentences
1010
+ if self.config.capitalize_sentences and text:
1011
+ # Capitalize first character
1012
+ text = text[0].upper() + text[1:] if len(text) > 1 else text.upper()
1013
+
1014
+ # Capitalize after sentence-ending punctuation
1015
+ text = re.sub(r"([.!?]\s+)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text)
1016
+
1017
+ # Add period if missing
1018
+ if text and text[-1] not in ".!?,:;":
1019
+ text += "."
1020
+
1021
+ return text
1022
+
1023
+ def _map_full_asr_to_segment(self, full_result: Any, seg: SpeakerSegment) -> str:
1024
+ """Attempt to extract text for a given segment from a full-audio ASR result.
1025
+
1026
+ Supports multiple result shapes returned by different ASR pipelines:
1027
+ - result['chunks'] or result['segments']: list of dicts with 'start','end','text'
1028
+ - result may also include 'words' lists with per-word timestamps
1029
+ If no timestamped structure is present, returns empty string so caller can fallback.
1030
+ """
1031
+ try:
1032
+ # Prefer 'chunks' (some pipelines) then 'segments'
1033
+ blocks = None
1034
+ if isinstance(full_result, dict):
1035
+ if isinstance(full_result.get("chunks"), list):
1036
+ blocks = full_result["chunks"]
1037
+ elif isinstance(full_result.get("segments"), list):
1038
+ blocks = full_result["segments"]
1039
+ # some pipelines return word-level timestamps
1040
+ elif isinstance(full_result.get("words"), list):
1041
+ words = full_result["words"]
1042
+ text_parts = [
1043
+ w["word"]
1044
+ for w in words
1045
+ if w.get("start") is not None
1046
+ and w.get("end") is not None
1047
+ and (w["start"] >= seg.start and w["end"] <= seg.end)
1048
+ ]
1049
+ return " ".join(text_parts)
1050
+
1051
+ if blocks is None:
1052
+ return ""
1053
+
1054
+ # Concatenate blocks that overlap with seg time window
1055
+ collected = []
1056
+ for b in blocks:
1057
+ bstart = float(b.get("start", 0.0))
1058
+ bend = float(b.get("end", 0.0))
1059
+ if bstart < seg.end and bend > seg.start:
1060
+ collected.append(b.get("text", ""))
1061
+
1062
+ return " ".join([c.strip() for c in collected]).strip()
1063
+ except Exception:
1064
+ return ""
1065
+
1066
+ def get_transcription_stats(self, segments: List[TranscriptSegment]) -> Dict[str, Any]:
1067
+ """
1068
+ Get transcription statistics.
1069
+
1070
+ Args:
1071
+ segments: List of transcript segments
1072
+
1073
+ Returns:
1074
+ Dictionary with statistics
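+
+ Example shape (illustrative numbers):
+ {"total_segments": 12, "total_words": 480, "total_duration": 310.5,
+ "words_per_minute": 92.8, "speakers": {"SPEAKER_00": {"word_count": 260,
+ "duration": 170.0, "segment_count": 7}, ...}}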
1075
+ """
1076
+ if not segments:
1077
+ return {
1078
+ "total_segments": 0,
1079
+ "total_words": 0,
1080
+ "total_duration": 0.0,
1081
+ "words_per_minute": 0.0,
1082
+ "speakers": {},
1083
+ }
1084
+
1085
+ total_words = sum(seg.word_count for seg in segments)
1086
+ total_duration = sum(seg.duration for seg in segments)
1087
+
1088
+ # Per-speaker stats
1089
+ speaker_stats = {}
1090
+ for seg in segments:
1091
+ if seg.speaker_id not in speaker_stats:
1092
+ speaker_stats[seg.speaker_id] = {
1093
+ "word_count": 0,
1094
+ "duration": 0.0,
1095
+ "segment_count": 0,
1096
+ }
1097
+
1098
+ speaker_stats[seg.speaker_id]["word_count"] += seg.word_count
1099
+ speaker_stats[seg.speaker_id]["duration"] += seg.duration
1100
+ speaker_stats[seg.speaker_id]["segment_count"] += 1
1101
+
1102
+ return {
1103
+ "total_segments": len(segments),
1104
+ "total_words": total_words,
1105
+ "total_duration": total_duration,
1106
+ "words_per_minute": (total_words / total_duration * 60) if total_duration > 0 else 0,
1107
+ "speakers": speaker_stats,
1108
+ }
src/transcriber_speechbrain.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ SpeechBrain ASR wrapper (optional)
3
+ Provides a lightweight adapter around SpeechBrain's EncoderASR/EncoderDecoderASR to be used
4
+ as an optional backend in `meeting_transcriber`.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Any, List, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ from src.diarization import SpeakerSegment
17
+ from src.transcriber import TranscriptSegment
18
+
19
+
20
+ @dataclass
21
+ class SpeechBrainASRConfig:
22
+ model_id: str = "speechbrain/asr-crdnn-rnnlm-librispeech"
23
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
24
+ chunk_length_s: float = 30.0
25
+
26
+
27
+ class SpeechBrainTranscriber:
28
+ """Adapter for SpeechBrain ASR models.
29
+
30
+ Usage:
31
+ t = SpeechBrainTranscriber(config)
32
+ t.transcribe_segments(waveform, segments, sample_rate)
33
+ """
34
+
35
+ def __init__(self, config: Optional[SpeechBrainASRConfig] = None, models_dir: str = "./models"):
36
+ self.config = config or SpeechBrainASRConfig()
37
+ self.models_dir = Path(models_dir)
38
+ self.models_dir.mkdir(parents=True, exist_ok=True)
39
+ self._model = None
40
+
41
+ def _load_model(self):
42
+ if self._model is not None:
43
+ return
44
+
45
+ try:
46
+ # Prefer the new import path to avoid deprecation warnings in SpeechBrain >=1.0
47
+ try:
48
+ from speechbrain.inference import ( # type: ignore
49
+ EncoderASR,
50
+ EncoderDecoderASR,
51
+ )
52
+ except Exception:
53
+ from speechbrain.pretrained import ( # type: ignore
54
+ EncoderASR,
55
+ EncoderDecoderASR,
56
+ )
57
+
58
+ # Try EncoderDecoderASR first (seq2seq), fall back to EncoderASR
59
+ try:
60
+ self._model = EncoderDecoderASR.from_hparams(
61
+ source=self.config.model_id, savedir=str(self.models_dir)
62
+ )
63
+ except Exception:
64
+ self._model = EncoderASR.from_hparams(
65
+ source=self.config.model_id, savedir=str(self.models_dir)
66
+ )
67
+
68
+ except Exception as e:
69
+ print(f"[SpeechBrain] Could not load model: {e}")
70
+ self._model = None
71
+
72
+ def transcribe_full_audio(self, waveform: torch.Tensor, sample_rate: int = 16000) -> str:
73
+ """Transcribe full audio waveform. Returns post-processed text (raw)."""
74
+ self._load_model()
75
+ if self._model is None:
76
+ return ""
77
+
78
+ # SpeechBrain typically expects a file path for convenience; some models accept numpy arrays
79
+ try:
80
+ audio_np = waveform.squeeze().cpu().numpy()
81
+ # Many SpeechBrain models accept numpy arrays for `transcribe_batch`/`transcribe_file`
82
+ # Use transcribe_batch for in-memory audio
83
+ try:
84
+ res = self._model.transcribe_batch([audio_np])
85
+ if isinstance(res, list):
86
+ return str(res[0])
87
+ return str(res)
88
+ except Exception:
89
+ # Fallback: write temporary file
90
+ import tempfile
91
+
92
+ import soundfile as sf
93
+
94
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
95
+ sf.write(tmp.name, audio_np.astype("float32"), sample_rate)
96
+ return str(self._model.transcribe_file(tmp.name))
97
+ except Exception as e:
98
+ print(f"[SpeechBrain] Full audio transcription failed: {e}")
99
+ return ""
100
+
101
+ def transcribe_segments(
102
+ self, waveform: torch.Tensor, segments: List[SpeakerSegment], sample_rate: int = 16000
103
+ ) -> List[TranscriptSegment]:
104
+ """Transcribe each segment and return list of TranscriptSegment objects."""
105
+ self._load_model()
106
+ transcripts: List[TranscriptSegment] = []
107
+
108
+ if self._model is None:
109
+ return transcripts
110
+
111
+ for seg in segments:
112
+ start = int(seg.start * sample_rate)
113
+ end = int(seg.end * sample_rate)
114
+ segment_np = waveform[:, start:end].squeeze().cpu().numpy()
115
+
116
+ if segment_np.size == 0:
117
+ continue
118
+
119
+ # Skip extremely short segments
120
+ if seg.end - seg.start < 0.2:
121
+ continue
122
+
123
+ try:
124
+ # prefer in-memory transcribe_batch
125
+ res = self._model.transcribe_batch([segment_np])
126
+ text = str(res[0]) if isinstance(res, list) else str(res)
127
+ except Exception:
128
+ # fallback to temporary file path
129
+ try:
130
+ import tempfile
131
+
132
+ import soundfile as sf
133
+
134
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
135
+ sf.write(tmp.name, segment_np.astype("float32"), sample_rate)
136
+ text = str(self._model.transcribe_file(tmp.name))
137
+ except Exception as e:
138
+ print(f"[SpeechBrain] Segment transcription failed: {e}")
139
+ text = ""
140
+
141
+ if not text or not text.strip():
142
+ continue
143
+
144
+ transcripts.append(
145
+ TranscriptSegment(
146
+ speaker_id=seg.speaker_id,
147
+ start=seg.start,
148
+ end=seg.end,
149
+ text=text.strip(),
150
+ confidence=getattr(seg, "confidence", 1.0),
151
+ is_overlap=getattr(seg, "is_overlap", False),
152
+ )
153
+ )
154
+
155
+ return transcripts
src/utils.py ADDED
@@ -0,0 +1,555 @@
1
+ """
2
+ Utility Functions Module
3
+ ========================
4
+ Helper functions used across the system.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ import logging
12
+ import os
13
+ import re
14
+ import time
15
+ from functools import wraps
16
+ from pathlib import Path
17
+ from typing import Any, List, Optional, Tuple, Union
18
+
19
+ # =============================================================================
20
+ # Logging Setup
21
+ # =============================================================================
22
+
23
+
24
+ def setup_logger(
25
+ name: str = "MeetingTranscriber", level: int = logging.INFO, log_file: Optional[str] = None
26
+ ) -> logging.Logger:
27
+ """
28
+ Setup and return a logger instance.
29
+
30
+ Args:
31
+ name: Logger name
32
+ level: Logging level
33
+ log_file: Optional file path for logging
34
+
35
+ Returns:
36
+ Configured logger instance
37
+ """
38
+ logger = logging.getLogger(name)
39
+ logger.setLevel(level)
40
+
41
+ # Console handler
42
+ console_handler = logging.StreamHandler()
43
+ console_handler.setLevel(level)
44
+
45
+ # Formatter
46
+ formatter = logging.Formatter(
47
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
48
+ )
49
+ console_handler.setFormatter(formatter)
50
+ logger.addHandler(console_handler)
51
+
52
+ # File handler (optional)
53
+ if log_file:
54
+ os.makedirs(os.path.dirname(log_file), exist_ok=True)
55
+ file_handler = logging.FileHandler(log_file, encoding="utf-8")
56
+ file_handler.setLevel(level)
57
+ file_handler.setFormatter(formatter)
58
+ logger.addHandler(file_handler)
59
+
60
+ return logger
61
+
62
+
63
+ # =============================================================================
64
+ # Timing Utilities
65
+ # =============================================================================
66
+
67
+
68
+ def timer(func):
69
+ """Decorator to measure function execution time"""
70
+
71
+ @wraps(func)
72
+ def wrapper(*args, **kwargs):
73
+ start_time = time.time()
74
+ result = func(*args, **kwargs)
75
+ end_time = time.time()
76
+ print(f"[Timer] {func.__name__} took {end_time - start_time:.2f} seconds")
77
+ return result
78
+
79
+ return wrapper
80
+
81
+
82
+ class Timer:
83
+ """Context manager for timing code blocks"""
84
+
85
+ def __init__(self, name: str = "Block"):
86
+ self.name = name
87
+ self.start_time = None
88
+ self.end_time = None
89
+
90
+ def __enter__(self):
91
+ self.start_time = time.time()
92
+ return self
93
+
94
+ def __exit__(self, *args):
95
+ self.end_time = time.time()
96
+ self.elapsed = self.end_time - self.start_time
97
+ print(f"[Timer] {self.name} took {self.elapsed:.2f} seconds")
98
+
99
+
100
+ # =============================================================================
101
+ # File Utilities
102
+ # =============================================================================
103
+
104
+
105
+ def get_file_hash(filepath: Union[str, Path], algorithm: str = "md5") -> str:
106
+ """
107
+ Calculate hash of a file.
108
+
109
+ Args:
110
+ filepath: Path to file
111
+ algorithm: Hash algorithm ('md5', 'sha256')
112
+
113
+ Returns:
114
+ Hex digest of file hash
115
+ """
116
+ hash_func = hashlib.new(algorithm)
117
+
118
+ with open(filepath, "rb") as f:
119
+ for chunk in iter(lambda: f.read(8192), b""):
120
+ hash_func.update(chunk)
121
+
122
+ return hash_func.hexdigest()
123
+
124
+
125
+ def ensure_dir(path: Union[str, Path]) -> Path:
126
+ """Ensure directory exists, create if not"""
127
+ path = Path(path)
128
+ path.mkdir(parents=True, exist_ok=True)
129
+ return path
130
+
131
+
132
+ def list_audio_files(
133
+ directory: Union[str, Path], extensions: Optional[List[str]] = None
134
+ ) -> List[Path]:
135
+ """
136
+ List all audio files in directory.
137
+
138
+ Args:
139
+ directory: Directory to search
140
+ extensions: List of extensions to include (default: common audio formats)
141
+
142
+ Returns:
143
+ List of audio file paths
144
+ """
145
+ if extensions is None:
146
+ extensions = [".wav", ".mp3", ".flac", ".ogg", ".m4a", ".wma", ".aac"]
147
+
148
+ directory = Path(directory)
149
+ audio_files = []
150
+
151
+ for ext in extensions:
152
+ audio_files.extend(directory.glob(f"*{ext}"))
153
+ audio_files.extend(directory.glob(f"*{ext.upper()}"))
154
+
155
+ return sorted(audio_files)
156
+
157
+
158
+ def sanitize_filename(filename: str) -> str:
159
+ """Remove invalid characters from filename"""
160
+ # Remove invalid characters
161
+ sanitized = re.sub(r'[<>:"/\\|?*]', "", filename)
162
+ # Replace spaces with underscores
163
+ sanitized = sanitized.replace(" ", "_")
164
+ # Remove multiple underscores
165
+ sanitized = re.sub(r"_+", "_", sanitized)
166
+ return sanitized.strip("_")
167
+
168
+
169
+ # =============================================================================
170
+ # JSON Utilities
171
+ # =============================================================================
172
+
173
+
174
+ def save_json(data: Any, filepath: Union[str, Path], indent: int = 2):
175
+ """Save data to JSON file"""
176
+ filepath = Path(filepath)
177
+ filepath.parent.mkdir(parents=True, exist_ok=True)
178
+
179
+ with open(filepath, "w", encoding="utf-8") as f:
180
+ json.dump(data, f, ensure_ascii=False, indent=indent, default=str)
181
+
182
+
183
+ def load_json(filepath: Union[str, Path]) -> Any:
184
+ """Load data from JSON file"""
185
+ with open(filepath, "r", encoding="utf-8") as f:
186
+ return json.load(f)
187
+
188
+
189
+ # =============================================================================
190
+ # Text Utilities
191
+ # =============================================================================
192
+
193
+
194
+ def format_duration(seconds: float) -> str:
195
+ """Format duration in seconds to human-readable string"""
196
+ if seconds < 0:
197
+ return "0:00"
198
+
199
+ hours = int(seconds // 3600)
200
+ minutes = int((seconds % 3600) // 60)
201
+ secs = int(seconds % 60)
202
+
203
+ if hours > 0:
204
+ return f"{hours}:{minutes:02d}:{secs:02d}"
205
+ return f"{minutes}:{secs:02d}"
206
+
207
+
208
+ def format_timestamp(seconds: float) -> str:
209
+ """Format timestamp for document display"""
210
+ seconds = max(0, seconds)
211
+ hours = int(seconds // 3600)
212
+ minutes = int((seconds % 3600) // 60)
213
+ secs = int(seconds % 60)
214
+
215
+ if hours > 0:
216
+ return f"{hours:02d}:{minutes:02d}:{secs:02d}"
217
+ return f"{minutes:02d}:{secs:02d}"
218
+
219
+
220
+ def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
221
+ """Truncate text to maximum length"""
222
+ if len(text) <= max_length:
223
+ return text
224
+ return text[: max_length - len(suffix)] + suffix
225
+
226
+
227
+ def clean_text(text: str) -> str:
228
+ """Clean text: normalize whitespace, remove special chars"""
229
+ if not text:
230
+ return ""
231
+
232
+ # Normalize whitespace
233
+ text = " ".join(text.split())
234
+
235
+ # Remove control characters
236
+ text = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", text)
237
+
238
+ return text.strip()
239
+
240
+
241
+ # =============================================================================
242
+ # Progress Utilities
243
+ # =============================================================================
244
+
245
+
246
+ class ProgressTracker:
247
+ """Simple progress tracker for long operations"""
248
+
249
+ def __init__(self, total: int, description: str = "Processing"):
250
+ self.total = total
251
+ self.current = 0
252
+ self.description = description
253
+ self.start_time = time.time()
254
+
255
+ def update(self, n: int = 1):
256
+ """Update progress by n steps"""
257
+ self.current += n
258
+ self._print_progress()
259
+
260
+ def _print_progress(self):
261
+ """Print progress bar"""
262
+ percent = self.current / self.total * 100 if self.total > 0 else 0
263
+ elapsed = time.time() - self.start_time
264
+
265
+ # Estimate remaining time
266
+ if self.current > 0:
267
+ eta = elapsed / self.current * (self.total - self.current)
268
+ eta_str = format_duration(eta)
269
+ else:
270
+ eta_str = "?"
271
+
272
+ bar_length = 30
273
+ filled = int(bar_length * self.current / self.total) if self.total > 0 else 0
274
+ bar = "█" * filled + "░" * (bar_length - filled)
275
+
276
+ print(
277
+ f"\r[{bar}] {percent:5.1f}% ({self.current}/{self.total}) ETA: {eta_str} ",
278
+ end="",
279
+ flush=True,
280
+ )
281
+
282
+ if self.current >= self.total:
283
+ print() # New line at completion
284
+
285
+ def finish(self):
286
+ """Mark progress as complete"""
287
+ self.current = self.total
288
+ self._print_progress()
289
+
290
+ elapsed = time.time() - self.start_time
291
+ print(f"[{self.description}] Completed in {format_duration(elapsed)}")
292
+
293
+
294
+ # =============================================================================
295
+ # Validation Utilities
296
+ # =============================================================================
297
+
298
+
299
+ def validate_audio_file(filepath: Union[str, Path]) -> bool:
300
+ """
301
+ Validate that file exists and is a supported audio format.
302
+
303
+ Args:
304
+ filepath: Path to audio file
305
+
306
+ Returns:
307
+ True if valid, raises exception otherwise
308
+ """
309
+ filepath = Path(filepath)
310
+
311
+ if not filepath.exists():
312
+ raise FileNotFoundError(f"Audio file not found: {filepath}")
313
+
314
+ supported_formats = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".wma", ".aac"}
315
+
316
+ if filepath.suffix.lower() not in supported_formats:
317
+ raise ValueError(
318
+ f"Unsupported audio format: {filepath.suffix}. "
319
+ f"Supported: {', '.join(supported_formats)}"
320
+ )
321
+
322
+ return True
323
+
324
+
325
+ def validate_ground_truth_file(filepath: Union[str, Path]) -> bool:
326
+ """
327
+ Validate ground truth file format.
328
+
329
+ Args:
330
+ filepath: Path to ground truth file
331
+
332
+ Returns:
333
+ True if valid
334
+ """
335
+ filepath = Path(filepath)
336
+
337
+ if not filepath.exists():
338
+ raise FileNotFoundError(f"Ground truth file not found: {filepath}")
339
+
340
+ supported_formats = {".txt", ".json", ".rttm"}
341
+
342
+ if filepath.suffix.lower() not in supported_formats:
343
+ raise ValueError(
344
+ f"Unsupported ground truth format: {filepath.suffix}. "
345
+ f"Supported: {', '.join(supported_formats)}"
346
+ )
347
+
348
+ return True
349
+
350
+
351
+ # =============================================================================
352
+ # Ground Truth Parsing
353
+ # =============================================================================
354
+
355
+
356
+ def parse_transcript_file(filepath: Union[str, Path]) -> str:
357
+ """
358
+ Parse transcript file (plain text).
359
+
360
+ Args:
361
+ filepath: Path to transcript file
362
+
363
+ Returns:
364
+ Transcript text
365
+ """
366
+ with open(filepath, "r", encoding="utf-8") as f:
367
+ return f.read().strip()
368
+
369
+
370
+ def parse_rttm_file(filepath: Union[str, Path]) -> List[tuple]:
371
+ """
372
+ Parse RTTM (Rich Transcription Time Marked) file for diarization ground truth.
373
+
374
+ RTTM format:
375
+ SPEAKER <file_id> <channel> <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
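+
+ Example (illustrative): the line
+ SPEAKER audio 1 0.00 5.50 <NA> <NA> SPEAKER_00 <NA> <NA>
+ is parsed into the tuple ("SPEAKER_00", 0.0, 5.5), where the end time is
+ start plus duration.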
376
+
377
+ Args:
378
+ filepath: Path to RTTM file
379
+
380
+ Returns:
381
+ List of (speaker_id, start, end) tuples
382
+ """
383
+ segments = []
384
+
385
+ with open(filepath, "r", encoding="utf-8") as f:
386
+ for line in f:
387
+ line = line.strip()
388
+ if not line or line.startswith("#"):
389
+ continue
390
+
391
+ parts = line.split()
392
+ if len(parts) >= 8 and parts[0] == "SPEAKER":
393
+ start = float(parts[3])
394
+ duration = float(parts[4])
395
+ speaker_id = parts[7]
396
+
397
+ segments.append((speaker_id, start, start + duration))
398
+
399
+ return segments
400
+
401
+
402
+ # -----------------------------------------------------------------------------
403
+ # Helpers for building RTTM from speaker-labeled transcripts
404
+ # -----------------------------------------------------------------------------
405
+
406
+
407
+ def parse_speaker_labeled_text(text: str) -> List[Tuple[str, str]]:
408
+ """Parse speaker-labeled transcript text into a list of (speaker, text).
409
+
410
+ Recognizes lines that start with `Name:` (case-insensitive) as speaker labels.
411
+ Consecutive non-label lines are appended to the current speaker utterance.
412
+
413
+ Returns empty list if input is empty.
414
+ """
415
+ label_re = re.compile(r"^\s*([^:\n\r]{1,80}):\s*(.*)$")
416
+
417
+ items: List[Tuple[str, str]] = []
418
+
419
+ cur_speaker = None
420
+ cur_lines: List[str] = []
421
+
422
+ for raw in text.splitlines():
423
+ line = raw.rstrip("\n\r")
424
+ m = label_re.match(line)
425
+ if m:
426
+ if cur_speaker is not None:
427
+ items.append((cur_speaker, " ".join(l.strip() for l in cur_lines if l.strip())))
428
+ cur_speaker = m.group(1).strip()
429
+ first = m.group(2).strip()
430
+ cur_lines = [first] if first else []
431
+ else:
432
+ if line.strip():
433
+ cur_lines.append(line.strip())
434
+
435
+ if cur_speaker is not None:
436
+ items.append((cur_speaker, " ".join(l.strip() for l in cur_lines if l.strip())))
437
+
438
+ return items
439
+
440
+
441
+ def align_reference_to_segments(
442
+ utterances: List[Tuple[str, str]],
443
+ hyp_segments: List[object],
444
+ min_score: float = 0.20,
445
+ ) -> List[Tuple[str, float, float]]:
446
+ """Align reference speaker utterances to hypothesis transcript segments.
447
+
448
+ Strategy (simple heuristic):
449
+ - Iterate utterances in order and try to find the best contiguous window of
450
+ hypothesis segments (starting from last matched index) whose combined
451
+ words have maximal overlap with the reference utterance words.
452
+ - Overlap score = intersection_words / reference_word_count.
453
+ - Accept match if score >= min_score; assign start/end from matched segments.
454
+
455
+ Returns list of (speaker_id, start, end).
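+
+ Worked example (illustrative): if a reference utterance has 10 distinct
+ words and the best window of hypothesis segments contains 6 of them, the
+ score is 6/10 = 0.6 >= min_score (0.20), so the utterance is assigned the
+ start of the first and the end of the last segment in that window.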
456
+ """
457
+ if not utterances or not hyp_segments:
458
+ return []
459
+
460
+ # Precompute normalized words for hypothesis segments
461
+ hyp_words = []
462
+ for seg in hyp_segments:
463
+ txt = getattr(seg, "text", "") or ""
464
+ words = [w.lower() for w in re.findall(r"\w+", txt)]
465
+ hyp_words.append(words)
466
+
467
+ results: List[Tuple[str, float, float]] = []
468
+ cur_idx = 0
469
+
470
+ for speaker, ref_text in utterances:
471
+ ref_tokens = [w.lower() for w in re.findall(r"\w+", ref_text)]
472
+ if not ref_tokens:
473
+ continue
474
+ ref_set = set(ref_tokens)
475
+
476
+ best_score = 0.0
477
+ best_j = None
478
+ best_k = None
479
+
480
+ # Search windows starting at cur_idx
481
+ for j in range(cur_idx, len(hyp_segments)):
482
+ combined = []
483
+ for k in range(j, len(hyp_segments)):
484
+ combined.extend(hyp_words[k])
485
+ if not combined:
486
+ continue
487
+ comb_set = set(combined)
488
+ score = len(ref_set & comb_set) / max(1, len(ref_set))
489
+
490
+ if score > best_score:
491
+ best_score = score
492
+ best_j = j
493
+ best_k = k
494
+
495
+ # early break if we reach high confidence
496
+ if score >= 0.75:
497
+ break
498
+
499
+ if best_j is not None and best_score >= min_score:
500
+ start = float(getattr(hyp_segments[best_j], "start", 0.0))
501
+ end = float(getattr(hyp_segments[best_k], "end", start))
502
+ spk = re.sub(r"[^0-9A-Za-z_\-]", "_", speaker)
503
+ results.append((spk, start, end))
504
+ cur_idx = best_k + 1
505
+ else:
506
+ # If no match found, skip (could be silence/non-speech)
507
+ continue
508
+
509
+ return results
510
+
511
+
512
+ def create_ground_truth_template(
513
+ output_path: Union[str, Path], audio_duration: float, num_speakers: int = 2
514
+ ):
515
+ """
516
+ Create template ground truth files for annotation.
517
+
518
+ Args:
519
+ output_path: Output directory
520
+ audio_duration: Duration of audio in seconds
521
+ num_speakers: Expected number of speakers
522
+ """
523
+ output_path = Path(output_path)
524
+ output_path.mkdir(parents=True, exist_ok=True)
525
+
526
+ # Create transcript template
527
+ transcript_template = """# Ground Truth Transcript
528
+ # Instruksi: Tulis transkripsi lengkap audio di bawah ini
529
+ # Hapus baris komentar (yang dimulai dengan #) sebelum evaluasi
530
+
531
+ [Tulis transkripsi di sini...]
532
+ """
533
+
534
+ with open(output_path / "transcript.txt", "w", encoding="utf-8") as f:
535
+ f.write(transcript_template)
536
+
537
+ # Create RTTM template
538
+ rttm_template = f"""# Ground Truth Diarization (RTTM Format)
539
+ # Format: SPEAKER <file_id> <channel> <start_time> <duration> <NA> <NA> <speaker_id> <NA> <NA>
540
+ #
541
+ # Contoh:
542
+ # SPEAKER audio 1 0.0 5.5 <NA> <NA> SPEAKER_00 <NA> <NA>
543
+ # SPEAKER audio 1 5.5 3.2 <NA> <NA> SPEAKER_01 <NA> <NA>
544
+ #
545
+ # Audio duration: {audio_duration:.2f} seconds
546
+ # Expected speakers: {num_speakers}
547
+ #
548
+ # Tambahkan baris SPEAKER di bawah:
549
+
550
+ """
551
+
552
+ with open(output_path / "diarization.rttm", "w", encoding="utf-8") as f:
553
+ f.write(rttm_template)
554
+
555
+ print(f"Ground truth templates created in: {output_path}")