File size: 15,741 Bytes
752a399
4d6b6c4
 
 
 
df0c7b6
539bad7
4d48604
4d6b6c4
 
 
539bad7
4d6b6c4
 
c90efd5
45711db
 
 
c21ae2d
9bb731f
4d6b6c4
 
 
 
 
c90efd5
 
 
 
 
 
 
 
 
4d6b6c4
 
 
 
 
 
 
 
9bb731f
752a399
 
4d6b6c4
7658895
4d6b6c4
278caca
1c555c0
 
 
 
 
234f598
1c555c0
 
 
 
 
 
 
 
 
 
 
4d48604
9b7234d
4d48604
9bc34e0
278caca
9b7234d
4d48604
 
 
9bc34e0
4d48604
 
 
 
 
 
 
 
 
278caca
9b7234d
 
278caca
9b7234d
234f598
 
 
 
 
 
 
 
 
 
 
 
 
 
278caca
 
22c6367
4f7e89c
 
 
 
 
 
 
 
 
 
 
 
 
c90efd5
 
4f7e89c
 
 
 
539bad7
 
4d48604
 
234f598
 
4d48604
 
 
 
 
 
 
 
 
 
 
 
 
 
234f598
 
4d48604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2392b11
9bc34e0
 
 
 
 
 
 
 
 
4d6b6c4
 
 
 
 
f6b7ada
cfa4101
4d48604
 
4d6b6c4
539bad7
4d6b6c4
539bad7
9a8a554
539bad7
64efa14
4d6b6c4
539bad7
 
64efa14
df0c7b6
1c555c0
c90efd5
 
df0c7b6
c90efd5
3e50eb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d4a364
c90efd5
 
 
 
 
 
 
 
 
 
4d48604
1c555c0
 
 
 
 
 
 
 
234f598
 
 
4d48604
 
 
c90efd5
9bc34e0
 
c90efd5
 
 
 
 
 
 
 
3e50eb1
c90efd5
9bc34e0
c90efd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e50eb1
ce553db
c90efd5
 
 
ce553db
 
 
 
 
 
 
 
c90efd5
ce553db
4d48604
ce553db
cf8a520
ce553db
12cdd04
 
 
4d48604
 
 
 
234f598
4d48604
45711db
cf8a520
4d48604
 
 
 
 
 
 
 
 
 
 
45711db
 
12cdd04
9bc34e0
cf8a520
ce553db
4d48604
 
ce553db
4d48604
 
 
 
 
ce553db
4d48604
ce553db
bad4ead
4d48604
9b7234d
4d48604
9b7234d
4d48604
ce553db
4d48604
ce553db
9b7234d
bad4ead
 
ce553db
cf8a520
ce553db
 
 
 
 
 
4d48604
ce553db
45711db
9a8a554
1e1214b
 
 
 
c90efd5
45711db
8ce75a0
ce553db
4d4a364
c90efd5
4d4a364
9bb731f
c90efd5
 
64efa14
97bec3e
752a399
9bb731f
 
c90efd5
539bad7
4d6b6c4
c90efd5
 
4d6b6c4
 
9bb731f
 
4d6b6c4
9a8a554
4d6b6c4
02fb5b8
4f7e89c
1e1214b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f7e89c
 
 
234f598
 
4f7e89c
c90efd5
 
 
 
 
4f7e89c
 
 
c90efd5
4f7e89c
 
 
 
c90efd5
4f7e89c
234f598
4f7e89c
c90efd5
22c6367
c90efd5
234f598
9bc34e0
c90efd5
 
 
4f7e89c
 
 
 
 
 
02fb5b8
4d6b6c4
752a399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d6b6c4
 
 
752a399
 
4d6b6c4
 
 
752a399
4d6b6c4
752a399
 
 
 
efb3ddd
4d6b6c4
 
752a399
 
 
 
efb3ddd
752a399
 
4d6b6c4
539bad7
4d6b6c4
539bad7
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559

import logging
import subprocess
import time
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict, Counter

import numpy as np
import librosa
import torch

from app.core.config import get_settings
from app.services.transcription import TranscriptionService
from app.services.alignment import AlignmentService
from app.services.transcription import WordTimestamp


from app.services.diarization import DiarizationService, SpeakerSegment, DiarizationResult

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Application settings, loaded once at import time.
# NOTE(review): not referenced by any code in this section — confirm it is still needed here.
settings = get_settings()


@dataclass
class TranscriptSegment:
    """A transcribed segment with speaker info.

    Field order is part of the interface: instances are also built
    positionally as ``(start, end, speaker, role, text)``.
    """
    start: float  # segment start, in seconds from the beginning of the audio
    end: float  # segment end, in seconds from the beginning of the audio
    speaker: str  # display label such as "Speaker 1"
    role: Optional[str]  # "NV" (the speaker picked as agent) or "KH"; typed Optional
    text: str  # transcribed text for this span


@dataclass
class ProcessingResult:
    """Result of audio processing.

    Produced by ``Processor.process_audio``; bundles the final transcript
    segments, speaker/role metadata, timing, and pre-rendered TXT/CSV exports.
    """
    segments: List[TranscriptSegment]  # final transcript segments, sorted by start time
    speaker_count: int  # number of speakers reported (len(speakers))
    duration: float  # decoded audio duration in seconds
    processing_time: float  # wall-clock seconds spent in process_audio before rendering exports
    speakers: List[str]  # speaker display labels, e.g. "Speaker 1"
    roles: Dict[str, str]  # speaker display label -> role string ("NV" / "KH")

    txt_content: str = ""  # human-readable transcript with a "#"-prefixed header
    csv_content: str = ""  # CSV export with columns start, end, speaker, text


def pad_and_refine_tensor(
    waveform: torch.Tensor,
    sr: int,
    start_s: float,
    end_s: float,
    pad_ms: int = 250,
) -> Tuple[float, float]:
    """Widen ``[start_s, end_s]`` by ``pad_ms`` on each side, clamped to the audio.

    Only the sample count of ``waveform`` is used (its LAST dimension), so
    both ``(channels, samples)`` and 1-D ``(samples,)`` tensors are accepted;
    the sample values themselves are never read.

    Args:
        waveform: Audio tensor whose last dimension is the time axis.
        sr: Sample rate in Hz.
        start_s: Span start in seconds.
        end_s: Span end in seconds.
        pad_ms: Padding added before the start and after the end, in ms.

    Returns:
        ``(start, end)`` in seconds, snapped to whole samples. If clamping
        leaves an empty or inverted span (e.g. the span lies past the end of
        the audio), the original ``(start_s, end_s)`` is returned unchanged.
    """
    total_len = waveform.shape[-1]
    pad_s = pad_ms / 1000

    first_sample = max(int((start_s - pad_s) * sr), 0)
    last_sample = min(int((end_s + pad_s) * sr), total_len)

    if last_sample <= first_sample:
        return start_s, end_s

    return first_sample / sr, last_sample / sr


def normalize_asr_result(result: dict) -> Tuple[str, List[dict]]:
    """Flatten a raw ASR result into ``(text, words)``.

    Args:
        result: ASR output mapping. The optional ``"text"`` string and the
            optional ``"words"`` list are read; each word item is a mapping
            with ``"word"``, ``"start"``, ``"end"`` and optionally ``"speaker"``.

    Returns:
        ``(text, words)``: ``text`` is the stripped transcript (``""`` when
        absent or None); ``words`` is a list of dicts with keys ``word``
        (stripped, non-empty), ``start``/``end`` (floats) and ``speaker``
        (value as given, or None). Word entries with blank text or without
        both timestamps are skipped, because every downstream step needs a
        time span for each word.
    """
    words = []

    # `or` also covers keys that are present but explicitly None.
    for w in result.get("words") or []:
        word = (w.get("word") or "").strip()
        if not word:
            continue

        start = w.get("start")
        end = w.get("end")
        if start is None or end is None:
            # ASR backends can emit words they failed to timestamp.
            continue

        words.append(
            {
                "word": word,
                "start": float(start),
                "end": float(end),
                "speaker": w.get("speaker"),
            }
        )

    text = (result.get("text") or "").strip()
    return text, words


def guess_speaker_by_overlap(start, end, diar_segments):
    """Pick the speaker for the time span ``[start, end]`` from diarization segments.

    The segment with the largest positive time overlap wins. If no segment
    overlaps (e.g. the span falls in a silence gap, or has zero length), the
    speaker of the NEAREST segment is used instead of an arbitrary one. Ties
    go to the earlier segment in ``diar_segments``.

    Args:
        start: Span start in seconds.
        end: Span end in seconds.
        diar_segments: Objects with ``start``, ``end`` and ``speaker`` attributes.

    Returns:
        The chosen segment's ``speaker``, or None if ``diar_segments`` is empty.
    """
    if not diar_segments:
        return None

    best_spk = None
    best_overlap = 0.0
    nearest_spk = None
    nearest_gap = float("inf")

    for seg in diar_segments:
        overlap = min(end, seg.end) - max(start, seg.start)
        if overlap > best_overlap:
            best_overlap = overlap
            best_spk = seg.speaker

        # Distance between the span and the segment; 0 when they touch/overlap.
        gap = max(seg.start - end, start - seg.end, 0.0)
        if gap < nearest_gap:
            nearest_gap = gap
            nearest_spk = seg.speaker

    # Compare on the overlap, not on best_spk's truthiness, so falsy labels work.
    return best_spk if best_overlap > 0.0 else nearest_spk



def convert_audio_to_wav(audio_path: Path) -> Path:
    """Convert any audio to WAV 16kHz Mono using ffmpeg.

    Writes ``<stem>_processed.wav`` next to the input, replacing any existing
    file of that name. Conversion is best-effort: if ffmpeg is not installed
    or exits with an error, the failure is logged and the ORIGINAL
    ``audio_path`` is returned so the caller can still try to load it.

    Args:
        audio_path: Input audio file in any format ffmpeg can decode.

    Returns:
        Path of the converted WAV, or ``audio_path`` if conversion failed.
    """
    output_path = audio_path.parent / f"{audio_path.stem}_processed.wav"
    if output_path.exists():
        output_path.unlink()

    command = ["ffmpeg", "-i", str(audio_path), "-ar", "16000", "-ac", "1", "-y", str(output_path)]
    try:
        subprocess.run(
            command,
            check=True,
            # ffmpeg reads stdin for interactive commands; detach it so the
            # process can never block waiting on the server's stdin.
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            # Keep stderr so a failure can be diagnosed from the logs.
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError:
        logger.error(f"FFmpeg executable not found; using original audio: {audio_path}")
        return audio_path
    except subprocess.CalledProcessError as e:
        stderr_tail = (e.stderr or b"").decode("utf-8", errors="replace")[-2000:]
        logger.error(f"FFmpeg conversion failed: {e}\n{stderr_tail}")
        # Don't leave a truncated/corrupt output file behind.
        output_path.unlink(missing_ok=True)
        return audio_path

    logger.info(f"Converted audio to WAV: {output_path}")
    return output_path


def format_timestamp(seconds: float) -> str:
    """Format a time in seconds as ``MM:SS.mmm``.

    The value is rounded to whole milliseconds BEFORE being split into
    minutes and seconds, so e.g. 59.9996 becomes ``01:00.000`` instead of
    the invalid ``00:60.000``. Minutes are not wrapped into hours
    (3725.0 -> ``62:05.000``).
    """
    total_ms = int(round(seconds * 1000))
    minutes, ms = divmod(total_ms, 60_000)
    return f"{minutes:02d}:{ms / 1000:06.3f}"


def merge_consecutive_segments(
    segments: List[SpeakerSegment], 
    max_gap: float = 0.8,
    min_duration: float = 0.15,
) -> List[SpeakerSegment]:
    """Merge consecutive segments from same speaker.

    Segments are walked in the given order (callers pass them sorted by
    ``start``). A segment is folded into the running turn when it starts no
    more than ``max_gap`` seconds after the turn's current end AND either
    belongs to the same speaker or is shorter than ``min_duration`` seconds
    (a blip too short to count as a speaker change). Otherwise the running
    turn is closed and a new one starts.

    Args:
        segments: Diarization segments, expected sorted by start time.
        max_gap: Largest silence (seconds) that may be bridged by a merge.
        min_duration: Segments shorter than this are absorbed into the
            running turn regardless of speaker.

    Returns:
        New SpeakerSegment objects; the input segments are not mutated.
    """
    if not segments:
        return []

    first = segments[0]
    merged: List[SpeakerSegment] = []
    current = SpeakerSegment(
        start=first.start,
        end=first.end,
        speaker=first.speaker
    )

    for seg in segments[1:]:
        gap = seg.start - current.end
        same_speaker = seg.speaker == current.speaker
        is_blip = (seg.end - seg.start) < min_duration

        if gap <= max_gap and (same_speaker or is_blip):
            # Merge: extend current segment. max() so a segment nested inside
            # the current turn (overlapping/padded diarization) cannot shrink it.
            current.end = max(current.end, seg.end)
        else:
            # New speaker or gap too large
            merged.append(current)
            current = SpeakerSegment(
                start=seg.start,
                end=seg.end,
                speaker=seg.speaker
            )

    merged.append(current)
    return merged


def overlap_prefix(a: str, b: str, n: int = 12) -> bool:
    """Return True if the first ``n`` characters of either text occur in the other.

    Comparison is case-insensitive and ignores surrounding whitespace.
    ``Processor._merge_adjacent_segments`` refuses to merge two segments when
    this returns True. Empty, None or whitespace-only inputs never count as
    overlapping: an empty prefix would otherwise be "contained" in any string.
    """
    a = (a or "").strip().lower()
    b = (b or "").strip().lower()

    # Check emptiness AFTER stripping, so "   " is treated like "".
    if not a or not b:
        return False

    return a[:n] in b or b[:n] in a


class Processor:
    """Audio -> diarized, role-labelled transcript pipeline.

    ``process_audio`` is the entry point; the remaining class/static methods
    are the post-processing and rendering steps it calls. The class holds no
    state.
    """

    @classmethod
    async def process_audio(
        cls,
        audio_path: Path,
        model_name: str = "PhoWhisper Lora Finetuned",
        language="vi",
        merge_segments: bool = True,
    ) -> ProcessingResult:
        """Run the full pipeline on one audio file.

        Steps: convert to 16 kHz mono WAV, load it, run diarization and ASR
        concurrently, pad/sort/(optionally) merge diarization segments, assign
        an "NV" role to the speaker with the most segment time and "KH" to the
        rest, attach speakers to ASR words, apply conversational role
        heuristics, and render TXT/CSV exports.

        Args:
            audio_path: Input audio in any format ffmpeg can decode.
            model_name: ASR model name forwarded to TranscriptionService.
            language: Language code forwarded to TranscriptionService.
            merge_segments: Merge consecutive same-speaker diarization
                segments before alignment.

        Returns:
            A ProcessingResult whose ``speakers``, ``roles`` and
            ``speaker_count`` describe the labels that actually appear in
            ``segments``.

        Raises:
            ValueError: If the decoded audio is empty.
            Exception: Any diarization/ASR failure is logged and re-raised.
        """
        import asyncio

        t0 = time.time()
        loop = asyncio.get_running_loop()

        # 1: Convert to WAV. ffmpeg blocks, so run it in the default executor.
        logger.info("Step 1: Converting audio to WAV 16kHz...")
        wav_path = await loop.run_in_executor(None, convert_audio_to_wav, audio_path)

        # 2: Load audio. Decoding is blocking I/O + CPU work, so keep it off
        # the event loop as well.
        y, sr = await loop.run_in_executor(
            None, lambda: librosa.load(wav_path, sr=16000, mono=True)
        )
        if y.size == 0:
            raise ValueError("Empty audio")
        waveform = torch.from_numpy(y).unsqueeze(0)
        duration = len(y) / sr

        # 3: Run diarization and ASR concurrently.
        logger.info("Step 3+7: Running diarization and ASR in parallel...")

        diarization_task = asyncio.create_task(
            DiarizationService.diarize_async(wav_path)
        )
        asr_task = asyncio.create_task(
            TranscriptionService.transcribe_with_words_async(
                audio_array=y,
                model_name=model_name,
                language=language,
                vad_options=True,
            )
        )

        try:
            diarization, asr_result = await asyncio.gather(diarization_task, asr_task)
        except Exception:
            logger.exception("Parallel AI processing failed")
            # gather() leaves the sibling task running when one task fails;
            # cancel both (a no-op for a finished task) before propagating.
            for task in (diarization_task, asr_task):
                task.cancel()
            raise

        # Copy so the in-place sorts below don't mutate the diarization result.
        diarization_segments = list(diarization.segments or [])
        if not diarization_segments:
            # Nothing diarized: treat the whole recording as one speaker.
            diarization_segments = [SpeakerSegment(0.0, duration, "SPEAKER_0")]

        diarization_segments.sort(key=lambda x: x.start)

        # Pad each segment by pad_and_refine_tensor's default margin,
        # clamped to the audio bounds.
        diarization_segments = [
            SpeakerSegment(
                *pad_and_refine_tensor(waveform, sr, s.start, s.end),
                speaker=s.speaker,
            )
            for s in diarization_segments
        ]
        diarization_segments.sort(key=lambda x: x.start)

        if merge_segments and diarization_segments:
            logger.info("Step 4: Merging consecutive segments...")
            diarization_segments = merge_consecutive_segments(diarization_segments)

        # 4: Map raw diarization labels to "Speaker N" in sorted raw-label order.
        raw_speakers = sorted({seg.speaker for seg in diarization_segments})
        speaker_map = {
            spk: f"Speaker {i + 1}"
            for i, spk in enumerate(raw_speakers)
        }
        speakers = list(speaker_map.values())

        # 5: Roles — the raw speaker with the most total segment time is "NV",
        # everyone else "KH".
        speaker_duration = defaultdict(float)
        for seg in diarization_segments:
            speaker_duration[seg.speaker] += seg.end - seg.start

        logger.info(f"speaker_duration(raw) = {speaker_duration}")

        if speaker_duration:
            agent_raw = max(speaker_duration, key=speaker_duration.get)
            roles = {
                speaker_map[spk]: ("NV" if spk == agent_raw else "KH")
                for spk in speaker_duration
            }
        else:
            roles = {}

        # Default fallback
        for label in speakers:
            roles.setdefault(label, "KH")

        logger.info(f"roles(mapped) = {roles}")

        # 7: Normalize the ASR result and build speaker-labelled segments.
        text, raw_words = normalize_asr_result(asr_result)

        if not raw_words:
            processed_segments: List[TranscriptSegment] = [
                TranscriptSegment(
                    start=0.0,
                    end=duration,
                    speaker=speakers[0],
                    role=roles[speakers[0]],
                    text="(No speech detected)",
                )
            ]
        else:
            processed_segments = cls._segments_from_words(
                raw_words, text, duration, diarization_segments, speaker_map, roles
            )

        processed_segments = cls._conversation_correction(processed_segments)
        processed_segments = cls._sync_speaker_with_role(processed_segments)
        processed_segments = cls._merge_adjacent_segments(processed_segments)
        processed_segments.sort(key=lambda x: x.start)

        # _sync_speaker_with_role relabels every segment by role ("Speaker 1"
        # for NV, "Speaker 2" otherwise), so the diarization-level `speakers`
        # and `roles` above can contradict the final segments (e.g. roles
        # saying "Speaker 1" is KH while its lines say NV). Rebuild both from
        # what is actually emitted so the result and the TXT header agree.
        final_roles: Dict[str, str] = {}
        for seg in processed_segments:
            final_roles.setdefault(seg.speaker, seg.role or "KH")
        speakers = sorted(final_roles)
        roles = {spk: final_roles[spk] for spk in speakers}

        processing_time = time.time() - t0

        txt_content = cls._generate_txt(
            processed_segments,
            len(speakers),
            processing_time,
            duration,
            roles,
        )
        csv_content = cls._generate_csv(processed_segments)

        return ProcessingResult(
            segments=processed_segments,
            speaker_count=len(speakers),
            duration=duration,
            processing_time=processing_time,
            speakers=speakers,
            roles=roles,
            txt_content=txt_content,
            csv_content=csv_content,
        )

    @staticmethod
    def _segments_from_words(
        raw_words: List[dict],
        text: str,
        duration: float,
        diarization_segments: List[SpeakerSegment],
        speaker_map: Dict[str, str],
        roles: Dict[str, str],
    ) -> List[TranscriptSegment]:
        """Attach speakers to ASR words and align them into transcript segments.

        Words without an ASR-provided speaker get one from
        ``guess_speaker_by_overlap``. Words are then aligned with
        ``AlignmentService.align_precision``. If alignment yields nothing, a
        single segment covering ``[0, duration]`` with the full ``text`` is
        returned, labelled with the most frequent word speaker (or the first
        diarization speaker when no word has one).
        """
        word_objs: List[WordTimestamp] = []
        for w in raw_words:
            spk = w.get("speaker")
            if spk is None:
                spk = guess_speaker_by_overlap(
                    w["start"], w["end"], diarization_segments
                )
            word_objs.append(
                WordTimestamp(
                    word=w["word"],
                    start=w["start"],
                    end=w["end"],
                    speaker=spk,
                )
            )

        word_objs.sort(key=lambda x: x.start)

        aligned_segments = AlignmentService.align_precision(
            word_objs,
            diarization_segments,
        )

        if not aligned_segments:
            votes = [w.speaker for w in word_objs if w.speaker]
            if votes:
                raw_spk = Counter(votes).most_common(1)[0][0]
            else:
                raw_spk = diarization_segments[0].speaker
            label = speaker_map.get(raw_spk, "Speaker 1")
            return [TranscriptSegment(0, duration, label, roles[label], text)]

        segments: List[TranscriptSegment] = []
        for seg in aligned_segments:
            label = speaker_map.get(seg.speaker, "Speaker 1")
            segments.append(
                TranscriptSegment(
                    start=seg.start,
                    end=seg.end,
                    speaker=label,
                    role=roles.get(label, "KH"),
                    text=seg.text,
                )
            )
        return segments

    @staticmethod
    def _conversation_correction(
        segments: List[TranscriptSegment],
        ack_max_duration: float = 1.2,
        turn_gap: float = 0.6,
    ) -> List[TranscriptSegment]:
        """Fix the role of short segments sandwiched between two neighbours.

        Every segment with a neighbour on both sides is visited in order and
        its ``role`` is mutated in place (so a change is visible when the next
        segment looks back at it):

        1. An "NV" segment no longer than ``ack_max_duration`` whose text is a
           bare acknowledgement word, with "NV" on both sides, becomes "KH".
        2. A segment no longer than ``ack_max_duration``, within ``turn_gap``
           seconds of both neighbours, whose neighbours share a role that
           differs from its own, takes the neighbours' role — unless it is
           "KH", which is kept.

        Returns a new list holding the same (possibly mutated) segment
        objects; lists shorter than 3 are returned unchanged.
        """
        if len(segments) < 3:
            return segments

        ACK_WORDS = {
            "dạ", "vâng", "ừ", "ừm", "uh", "ok", "okay", "ạ", "dạ vâng"
        }

        corrected = segments.copy()

        for prev_seg, seg, next_seg in zip(corrected, corrected[1:], corrected[2:]):
            seg_dur = seg.end - seg.start
            gap_prev = seg.start - prev_seg.end
            gap_next = next_seg.start - seg.end

            # Lower-case and drop surrounding whitespace/punctuation so ASR
            # output such as "Dạ." or "Vâng!" still matches ACK_WORDS.
            text_clean = (seg.text or "").lower().strip().strip(".,!?…").strip()

            if (
                seg.role == "NV"
                and seg_dur <= ack_max_duration
                and text_clean in ACK_WORDS
                and prev_seg.role == "NV"
                and next_seg.role == "NV"
            ):
                seg.role = "KH"

            if (
                seg_dur <= ack_max_duration
                and gap_prev <= turn_gap
                and gap_next <= turn_gap
                and prev_seg.role == next_seg.role
                and seg.role != prev_seg.role
                # Keep KH interruptions; otherwise flip back to the
                # surrounding speaker's role.
                and seg.role != "KH"
            ):
                seg.role = prev_seg.role

        return corrected

    @staticmethod
    def _sync_speaker_with_role(
        segments: List[TranscriptSegment]
    ) -> List[TranscriptSegment]:
        """Relabel speakers from roles, in place: "NV" -> "Speaker 1", else "Speaker 2".

        Returns the same list object.
        """
        for seg in segments:
            if seg.role == "NV":
                seg.speaker = "Speaker 1"
            else:
                seg.speaker = "Speaker 2"

        return segments

    @staticmethod
    def _merge_adjacent_segments(
        segments: List[TranscriptSegment],
        max_gap_s: float = 0.8,
        max_segment_duration: float = 9.0
    ) -> List[TranscriptSegment]:
        """Merge neighbouring segments into longer turns.

        After sorting by start time, a segment is appended onto the previous
        kept segment when ALL of these hold:

        - same speaker and same role,
        - gap between them <= ``max_gap_s`` seconds,
        - merged span (previous start to this end) <= ``max_segment_duration``,
        - ``overlap_prefix`` finds no shared text prefix between them.

        Merging mutates the kept segment's ``text`` and ``end`` in place.
        Returns a new list.
        """
        if not segments:
            return segments

        segments = sorted(segments, key=lambda s: s.start)
        merged = [segments[0]]

        for seg in segments[1:]:
            prev = merged[-1]

            gap = seg.start - prev.end
            combined_duration = seg.end - prev.start

            if (
                seg.speaker == prev.speaker
                and seg.role == prev.role
                and gap <= max_gap_s
                and combined_duration <= max_segment_duration
                and not overlap_prefix(seg.text, prev.text)
            ):
                prev.text = f"{prev.text} {seg.text}".strip()
                prev.end = max(prev.end, seg.end)
            else:
                merged.append(seg)

        return merged

    @classmethod
    def _generate_txt(
            cls,
            segments: List[TranscriptSegment],
            speaker_count: int,
            processing_time: float,
            duration: float,
            roles: Dict[str, str],
        ) -> str:
        """Render a human-readable transcript.

        Output: a "#"-prefixed header (duration, speaker count, roles,
        processing time), a blank line, then one line per segment in start
        order: ``[MM:SS.mmm - MM:SS.mmm] <icon> [<speaker>|<role>] <text>``.
        Icons are assigned per speaker in order of first appearance.
        """
        segments = sorted(segments, key=lambda s: s.start)

        # Unique speakers in order of first appearance (dicts keep insertion order).
        speakers = list(dict.fromkeys(seg.speaker for seg in segments if seg.speaker))

        lines = [
            "# Transcription Result",
            f"# Duration: {format_timestamp(duration)}",
            f"# Speakers: {speaker_count}",
            f"# Roles: {roles}",
            f"# Processing time: {processing_time:.1f}s",
            "",
        ]
        icon_pool = ["🔵", "🟢", "🟡", "🟠", "🔴", "🟣"]
        speaker_icons = {
            spk: icon_pool[i % len(icon_pool)]
            for i, spk in enumerate(speakers)
        }

        for seg in segments:
            # Separator is required: without it start and end run together
            # ("[00:01.23400:02.500]").
            ts = f"[{format_timestamp(seg.start)} - {format_timestamp(seg.end)}]"
            role = seg.role or "UNKNOWN"

            speaker_icon = speaker_icons.get(seg.speaker, "⚪")
            lines.append(
                f"{ts} {speaker_icon} [{seg.speaker}|{role}] {seg.text}"
            )

        return "\n".join(lines)

    @classmethod
    def _generate_csv(cls, segments: List[TranscriptSegment]) -> str:
        """Render segments as CSV with columns start, end, speaker, text.

        Times are rounded to 3 decimals; quoting/escaping is left to ``csv``.
        """
        import csv
        from io import StringIO

        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(["start", "end", "speaker", "text"])
        for seg in segments:
            writer.writerow([round(seg.start, 3), round(seg.end, 3), seg.speaker, seg.text])
        return output.getvalue()