Thanh-Lam commited on
Commit
b924e1d
·
1 Parent(s): 0741d8d
README.md CHANGED
@@ -1 +1,32 @@
1
- # Vietnamese_Diarization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vietnamese_Diarization
2
+
3
+ Kho mã mẫu diarization tiếng Việt dùng pyannote/speaker-diarization-community-1.
4
+
5
+ ## Yêu cầu
6
+ - Python 3.10+
7
+ - ffmpeg (bắt buộc cho torchcodec audio decoding)
8
+ - Đã chấp nhận điều khoản model tại https://huggingface.co/pyannote/speaker-diarization-community-1
9
+ - Hugging Face access token (dán vào hugging_face_key.txt hoặc đặt biến môi trường HUGGINGFACE_TOKEN/HUGGINGFACE_ACCESS_TOKEN)
10
+
11
+ ## Cài đặt nhanh
12
+ - Cài thư viện: `pip install pyannote.audio` hoặc `uv add pyannote.audio`
13
+ - Đảm bảo ffmpeg đã có trong PATH
14
+
15
+ ## Chạy mẫu
16
+ - Diarization và in kết quả: `python infer.py path/to/audio.wav`
17
+ - Lưu thêm RTTM: `python infer.py path/to/audio.wav --rttm outputs/audio.rttm`
18
+ - Lưu JSON: `python infer.py path/to/audio.wav --json outputs/audio.json`
19
+ - Chọn thiết bị: thêm `--device cpu` hoặc `--device cuda` (mặc định auto)
20
+
21
+ ## API Python
22
+ ```
23
+ from app import diarize_file
24
+ segments = diarize_file("audio.wav", device="auto")
25
+ ```
26
+
27
+ ## Cấu trúc
28
+ - app.py: API Python đơn giản
29
+ - infer.py: CLI chạy diarization
30
+ - src/models.py: Bao gói pipeline pyannote
31
+ - src/utils.py: Hỗ trợ đọc token, định dạng kết quả
32
+ - hugging_face_key.txt: nơi dán Hugging Face access token (không commit token thật)
__pycache__/app.cpython-312.pyc ADDED
Binary file (1.82 kB). View file
 
__pycache__/infer.cpython-312.pyc ADDED
Binary file (3 kB). View file
 
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import List
5
+
6
+ from src.models import DiarizationEngine, Segment
7
+
8
+
9
+ def diarize_file(
10
+ audio_path: str | Path,
11
+ hf_token: str | None = None,
12
+ device: str = "auto",
13
+ show_progress: bool = True,
14
+ ) -> List[Segment]:
15
+ """API đơn giản để dùng trực tiếp trong Python."""
16
+ engine = DiarizationEngine(token=hf_token, device=device)
17
+ return engine.run(audio_path, show_progress=show_progress)
18
+
19
+
20
+ if __name__ == "__main__":
21
+ # Ví dụ nhanh: python app.py audio.wav
22
+ import argparse
23
+
24
+ parser = argparse.ArgumentParser(description="Ví dụ chạy diarization qua hàm Python.")
25
+ parser.add_argument("audio", help="Đường dẫn tới file âm thanh")
26
+ parser.add_argument(
27
+ "--device",
28
+ choices=["auto", "cpu", "cuda"],
29
+ default="auto",
30
+ help="Thiết bị ưu tiên khi khởi tạo pipeline",
31
+ )
32
+ args = parser.parse_args()
33
+
34
+ segments = diarize_file(args.audio, device=args.device)
35
+ for idx, seg in enumerate(segments, start=1):
36
+ print(f"{idx:02d} | {seg.start:7.2f}s -> {seg.end:7.2f}s | speaker {seg.speaker}")
eval.py ADDED
File without changes
finetune.py ADDED
File without changes
infer.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from src.models import DiarizationEngine
7
+ from src.utils import export_segments_json, format_segments_table
8
+
9
+
10
+ def parse_args() -> argparse.Namespace:
11
+ parser = argparse.ArgumentParser(
12
+ description="Chạy diarization bằng pyannote/speaker-diarization-community-1"
13
+ )
14
+ parser.add_argument("audio", help="Đường dẫn file âm thanh (wav, mp3, flac...)")
15
+ parser.add_argument(
16
+ "--hf-token",
17
+ dest="hf_token",
18
+ default=None,
19
+ help="Hugging Face access token, nếu bỏ trống sẽ đọc từ hugging_face_key.txt",
20
+ )
21
+ parser.add_argument(
22
+ "--device",
23
+ choices=["auto", "cpu", "cuda"],
24
+ default="auto",
25
+ help="Ưu tiên thiết bị chạy pipeline",
26
+ )
27
+ parser.add_argument(
28
+ "--no-progress",
29
+ action="store_true",
30
+ help="Tắt hiển thị tiến trình tải model/feature",
31
+ )
32
+ parser.add_argument(
33
+ "--rttm",
34
+ default=None,
35
+ help="Đường dẫn lưu file RTTM (tùy chọn)",
36
+ )
37
+ parser.add_argument(
38
+ "--json",
39
+ dest="json_out",
40
+ default=None,
41
+ help="Đường dẫn lưu kết quả dạng JSON (tùy chọn)",
42
+ )
43
+ return parser.parse_args()
44
+
45
+
46
+ def main() -> None:
47
+ args = parse_args()
48
+ engine = DiarizationEngine(token=args.hf_token, device=args.device)
49
+ diarization = engine.diarize(args.audio, show_progress=not args.no_progress)
50
+ segments = engine.to_segments(diarization)
51
+
52
+ print("Kết quả phân đoạn:")
53
+ print(format_segments_table([seg.__dict__ for seg in segments]))
54
+
55
+ if args.rttm:
56
+ rttm_path = engine.save_rttm(diarization, args.rttm)
57
+ print(f"Đã lưu RTTM tại: {rttm_path}")
58
+
59
+ if args.json_out:
60
+ json_path = export_segments_json([seg.__dict__ for seg in segments], args.json_out)
61
+ print(f"Đã lưu JSON tại: {json_path}")
62
+
63
+
64
+ if __name__ == "__main__":
65
+ main()
params/eval.yaml ADDED
File without changes
params/finetune.yaml ADDED
File without changes
params/infer.yaml ADDED
File without changes
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (112 Bytes). View file
 
src/__pycache__/models.cpython-312.pyc ADDED
Binary file (4.52 kB). View file
 
src/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.29 kB). View file
 
src/models.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Iterable, List
6
+
7
+ import torch
8
+ from pyannote.audio import Pipeline
9
+ from pyannote.audio.pipelines.utils.hook import ProgressHook
10
+
11
+ from .utils import ensure_audio_path, read_hf_token
12
+
13
+
14
+ @dataclass
15
+ class Segment:
16
+ start: float
17
+ end: float
18
+ speaker: str
19
+
20
+
21
+ class DiarizationEngine:
22
+ """Bao gói pipeline diarization của pyannote."""
23
+
24
+ def __init__(
25
+ self,
26
+ model_id: str = "pyannote/speaker-diarization-community-1",
27
+ token: str | None = None,
28
+ key_path: str | Path = "hugging_face_key.txt",
29
+ device: str = "auto",
30
+ ) -> None:
31
+ self.device = self._resolve_device(device)
32
+ auth_token = read_hf_token(token, key_path)
33
+ self.pipeline = Pipeline.from_pretrained(model_id, token=auth_token)
34
+ self.pipeline.to(self.device)
35
+
36
+ @staticmethod
37
+ def _resolve_device(device: str) -> torch.device:
38
+ if device == "cpu":
39
+ return torch.device("cpu")
40
+ if device == "cuda":
41
+ if not torch.cuda.is_available():
42
+ raise RuntimeError("Yêu cầu CUDA nhưng không phát hiện GPU khả dụng.")
43
+ return torch.device("cuda")
44
+ if device == "auto":
45
+ return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
46
+ raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")
47
+
48
+ def diarize(self, audio_path: str | Path, show_progress: bool = True):
49
+ audio_path = ensure_audio_path(audio_path)
50
+ if show_progress:
51
+ with ProgressHook() as hook:
52
+ return self.pipeline(str(audio_path), hook=hook)
53
+ return self.pipeline(str(audio_path))
54
+
55
+ @staticmethod
56
+ def to_segments(diarization) -> List[Segment]:
57
+ segments: List[Segment] = []
58
+ for segment, _, speaker in diarization.itertracks(yield_label=True):
59
+ segments.append(
60
+ Segment(
61
+ start=float(segment.start),
62
+ end=float(segment.end),
63
+ speaker=str(speaker),
64
+ )
65
+ )
66
+ return segments
67
+
68
+ @staticmethod
69
+ def save_rttm(diarization, output_path: str | Path) -> Path:
70
+ path = Path(output_path)
71
+ path.parent.mkdir(parents=True, exist_ok=True)
72
+ diarization.write_rttm(path)
73
+ return path
74
+
75
+ def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
76
+ """Chạy pipeline và trả về danh sách segment."""
77
+ diarization = self.diarize(audio_path, show_progress=show_progress)
78
+ return self.to_segments(diarization)
src/utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Iterable, List
7
+
8
+
9
+ def read_hf_token(token: str | None = None, key_path: str | Path = "hugging_face_key.txt") -> str:
10
+ """Ưu tiên token truyền vào, nếu không thì đọc từ biến môi trường hoặc file."""
11
+ candidates = [
12
+ token,
13
+ os.getenv("HUGGINGFACE_TOKEN"),
14
+ os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
15
+ ]
16
+ for value in candidates:
17
+ if value and value.strip():
18
+ return value.strip()
19
+ path = Path(key_path)
20
+ if not path.exists():
21
+ raise FileNotFoundError(
22
+ f"Không tìm thấy token. Truyền biến --hf-token hoặc đặt file {path}."
23
+ )
24
+ content = path.read_text(encoding="utf-8").strip()
25
+ if not content:
26
+ raise ValueError(f"File {path} trống, hãy dán Hugging Face access token vào.")
27
+ return content
28
+
29
+
30
+ def ensure_audio_path(audio_path: str | Path) -> Path:
31
+ """Kiểm tra đường dẫn audio hợp lệ."""
32
+ path = Path(audio_path)
33
+ if not path.exists():
34
+ raise FileNotFoundError(f"Không tìm thấy file âm thanh: {path}")
35
+ if not path.is_file():
36
+ raise ValueError(f"Đường dẫn không phải file: {path}")
37
+ return path
38
+
39
+
40
+ def export_segments_json(segments: Iterable[dict], output_path: str | Path) -> Path:
41
+ """Lưu danh sách segment thành JSON."""
42
+ path = Path(output_path)
43
+ path.parent.mkdir(parents=True, exist_ok=True)
44
+ data: List[dict] = list(segments)
45
+ path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
46
+ return path
47
+
48
+
49
+ def format_segments_table(segments: Iterable[dict]) -> str:
50
+ """Trả về chuỗi bảng đơn giản để in ra terminal."""
51
+ lines = []
52
+ for idx, seg in enumerate(segments, start=1):
53
+ start = seg.get("start", 0.0)
54
+ end = seg.get("end", 0.0)
55
+ speaker = seg.get("speaker", "unknown")
56
+ lines.append(f"{idx:02d} | {start:7.2f}s -> {end:7.2f}s | speaker {speaker}")
57
+ return "\n".join(lines)