File size: 5,608 Bytes
b924e1d
 
 
 
4193ca8
 
b924e1d
 
 
 
 
4193ca8
b924e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
e169a58
b924e1d
 
 
4193ca8
 
b924e1d
c2b3996
b924e1d
 
d2997d8
e169a58
 
88a86dd
d2997d8
e169a58
 
 
b503b1d
 
 
 
 
e169a58
 
71155dc
 
 
 
 
 
 
 
 
 
 
 
 
4193ca8
71155dc
81c336f
71155dc
81c336f
 
71155dc
81c336f
 
b924e1d
71155dc
b924e1d
 
 
 
 
 
 
 
 
 
 
 
 
4193ca8
 
 
b924e1d
4193ca8
 
 
 
 
 
 
 
 
 
 
 
 
b924e1d
 
4193ca8
 
 
 
 
 
 
 
 
 
b924e1d
4193ca8
b924e1d
 
 
 
 
 
 
 
 
4193ca8
 
b924e1d
 
4193ca8
 
b924e1d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from __future__ import annotations

import shutil
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

from .utils import ensure_audio_path, read_hf_token, convert_to_wav_16k


@dataclass
class Segment:
    """One diarized speech turn: start/end times in seconds plus a speaker label."""

    # Turn start time in seconds (taken from the pipeline's segment.start).
    start: float
    # Turn end time in seconds.
    end: float
    # Speaker label string as emitted by the diarization pipeline.
    speaker: str


class DiarizationEngine:
    """Wrapper around a pyannote speaker-diarization pipeline.

    Loads a pretrained pipeline from the Hugging Face hub, optionally merges
    user overrides into its segmentation/clustering parameters, moves it to
    the requested device, and provides helpers to run diarization and to
    convert or export its output.
    """

    def __init__(
        self,
        model_id: str = "pyannote/speaker-diarization-3.1",
        token: str | None = None,
        key_path: str | Path = "hugging_face_key.txt",
        device: str = "auto",
        segmentation_params: Optional[Dict[str, float]] = None,
        clustering_params: Optional[Dict[str, float]] = None,
    ) -> None:
        """Load and instantiate the diarization pipeline.

        Args:
            model_id: Hugging Face model id of the pipeline to load.
            token: Explicit HF access token; when None, ``key_path`` is used.
            key_path: File consulted by ``read_hf_token`` for the token.
            device: One of "auto", "cpu", "cuda".
            segmentation_params: Overrides merged into the pipeline's default
                "segmentation" parameters when that group exists.
            clustering_params: Overrides merged into the pipeline's default
                "clustering" parameters when that group exists.

        Raises:
            RuntimeError: If the pipeline cannot be loaded (typically missing
                model-terms acceptance or a bad token) or CUDA is requested
                but unavailable.
            ValueError: If ``device`` is not a recognized value.
        """
        self.device = self._resolve_device(device)
        auth_token = read_hf_token(token, key_path)

        # Load pipeline with authentication; never print the token itself.
        print(f"DEBUG: Loading model {model_id} with token={'***' if auth_token else 'None'}", file=sys.stderr)
        pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)

        # from_pretrained returns None (rather than raising) on auth/terms
        # failures, so turn that into an actionable error here.
        if pipeline is None:
            raise RuntimeError(
                f"Failed to load pipeline '{model_id}'. "
                f"IMPORTANT: You need to accept terms for ALL these models:\n"
                f"  1. https://hf.co/pyannote/speaker-diarization-3.1\n"
                f"  2. https://hf.co/pyannote/segmentation-3.0\n"
                f"  3. https://hf.co/pyannote/embedding\n"
                f"After accepting, add HF_TOKEN to Space secrets with your token."
            )

        # Start from the pipeline's default hyper-parameters, if it has any.
        try:
            params = pipeline.default_parameters()
        except NotImplementedError:
            # No defaults exposed; instantiate with an empty parameter dict.
            params = {}

        print(f"DEBUG: Pipeline params: {params}", file=sys.stderr)

        # Merge user overrides into the matching sub-parameter groups only
        # when those groups exist in the defaults.
        if "segmentation" in params and segmentation_params:
            params["segmentation"].update(segmentation_params)
        if "clustering" in params and clustering_params:
            params["clustering"].update(clustering_params)

        # instantiate() modifies the pipeline in place and returns self.
        print("DEBUG: Instantiating pipeline...", file=sys.stderr)
        pipeline.instantiate(params)
        print("DEBUG: Pipeline instantiated successfully", file=sys.stderr)

        # Store and move to the selected device.
        self.pipeline = pipeline
        self.pipeline.to(self.device)
        print(f"DEBUG: Pipeline moved to device: {self.device}", file=sys.stderr)

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Map a device string ("auto"/"cpu"/"cuda") to a torch.device.

        Raises:
            RuntimeError: If "cuda" is requested but no GPU is available.
            ValueError: For any unrecognized device string.
        """
        if device == "cpu":
            return torch.device("cpu")
        if device == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("Yêu cầu CUDA nhưng không phát hiện GPU khả dụng.")
            return torch.device("cuda")
        if device == "auto":
            return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")

    def diarize(
        self, audio_path: str | Path, show_progress: bool = True, keep_audio: bool = False
    ):
        """Run the diarization pipeline on an audio file.

        The input is first converted to WAV (16 kHz) in a temporary directory
        by ``convert_to_wav_16k``.

        Args:
            audio_path: Path to the input audio file.
            show_progress: Show pyannote's ProgressHook while running.
            keep_audio: When True, return ``(result, prepared_path, tmpdir)``
                and leave the converted audio on disk — the caller is then
                responsible for removing ``tmpdir``. Otherwise return only
                the pipeline result and remove the temp directory.

        Returns:
            The pipeline output, or the 3-tuple described above.
        """
        audio_path = ensure_audio_path(audio_path)
        prepared_path, tmpdir = convert_to_wav_16k(audio_path)
        success = False
        try:
            if show_progress:
                with ProgressHook() as hook:
                    result = self.pipeline(str(prepared_path), hook=hook)
            else:
                result = self.pipeline(str(prepared_path))
            success = True
            if keep_audio:
                return result, prepared_path, tmpdir
            return result
        finally:
            # Remove the converted audio unless the caller keeps it. Bug fix:
            # also clean up when the pipeline raised even if keep_audio=True —
            # previously the temp directory leaked on failure in that case.
            if tmpdir and (not keep_audio or not success):
                shutil.rmtree(tmpdir, ignore_errors=True)

    @staticmethod
    def _get_annotation(diarization: Any):
        """Extract the Annotation from either output flavor.

        Supports both the legacy return value (an Annotation exposing
        ``itertracks``) and the newer wrapper object that carries it on a
        ``speaker_diarization`` attribute.

        Raises:
            TypeError: If neither form is recognized.
        """
        if hasattr(diarization, "itertracks"):
            return diarization
        if hasattr(diarization, "speaker_diarization"):
            return diarization.speaker_diarization
        raise TypeError("Output pipeline không có Annotation hoặc speaker_diarization.")

    def to_segments(self, diarization: Any) -> List[Segment]:
        """Flatten a diarization result into a list of ``Segment`` records."""
        annotation = self._get_annotation(diarization)
        segments: List[Segment] = []
        for segment, _, speaker in annotation.itertracks(yield_label=True):
            segments.append(
                Segment(
                    start=float(segment.start),
                    end=float(segment.end),
                    speaker=str(speaker),
                )
            )
        return segments

    def save_rttm(self, diarization: Any, output_path: str | Path) -> Path:
        """Write the diarization result to ``output_path`` in RTTM format.

        Parent directories are created as needed. Returns the written path.
        """
        annotation = self._get_annotation(diarization)
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            annotation.write_rttm(f)
        return path

    def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
        """Run the pipeline and return the resulting list of segments."""
        diarization = self.diarize(audio_path, show_progress=show_progress)
        return self.to_segments(diarization)