Spaces:
Running
Running
File size: 5,608 Bytes
b924e1d 4193ca8 b924e1d 4193ca8 b924e1d e169a58 b924e1d 4193ca8 b924e1d c2b3996 b924e1d d2997d8 e169a58 88a86dd d2997d8 e169a58 b503b1d e169a58 71155dc 4193ca8 71155dc 81c336f 71155dc 81c336f 71155dc 81c336f b924e1d 71155dc b924e1d 4193ca8 b924e1d 4193ca8 b924e1d 4193ca8 b924e1d 4193ca8 b924e1d 4193ca8 b924e1d 4193ca8 b924e1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Any, Dict, Optional
import shutil
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from .utils import ensure_audio_path, read_hf_token, convert_to_wav_16k
@dataclass
class Segment:
    """A single speaker turn: ``speaker`` is active from ``start`` to ``end``.

    Instances are built by ``DiarizationEngine.to_segments`` from the bounds
    of the diarization output (units presumably seconds, as pyannote uses —
    confirm against the pipeline version in use).
    """

    start: float  # beginning of the turn
    end: float  # end of the turn
    speaker: str  # speaker label, converted to str
class DiarizationEngine:
    """Wrap a pyannote speaker-diarization pipeline behind a small API.

    Typical use::

        engine = DiarizationEngine()
        segments = engine.run("meeting.mp3")
    """

    def __init__(
        self,
        model_id: str = "pyannote/speaker-diarization-3.1",
        token: str | None = None,
        key_path: str | Path = "hugging_face_key.txt",
        device: str = "auto",
        segmentation_params: Optional[Dict[str, float]] = None,
        clustering_params: Optional[Dict[str, float]] = None,
    ) -> None:
        """Load the pipeline, apply parameter overrides and move it to a device.

        Args:
            model_id: Hugging Face id of the pyannote pipeline to load.
            token: Explicit Hugging Face token; resolved together with
                ``key_path`` by ``read_hf_token``.
            key_path: Token file path, passed to ``read_hf_token``.
            device: ``"auto"``, ``"cpu"``, ``"cuda"`` or ``"cuda:N"``.
            segmentation_params: Values merged into the pipeline's
                ``"segmentation"`` parameter group, if it has one.
            clustering_params: Values merged into the pipeline's
                ``"clustering"`` parameter group, if it has one.

        Raises:
            RuntimeError: ``from_pretrained`` returned ``None``, or CUDA was
                requested but no GPU is available.
            ValueError: ``device`` is not a recognised value.
        """
        import sys

        self.device = self._resolve_device(device)
        auth_token = read_hf_token(token, key_path)

        # Load pipeline with authentication; only a mask of the token is printed.
        print(f"DEBUG: Loading model {model_id} with token={'***' if auth_token else 'None'}", file=sys.stderr)
        pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
        if pipeline is None:
            # A None result is treated as a load failure; the message lists the
            # gated models whose terms must be accepted on the Hub.
            raise RuntimeError(
                f"Failed to load pipeline '{model_id}'. "
                f"IMPORTANT: You need to accept terms for ALL these models:\n"
                f" 1. https://hf.co/pyannote/speaker-diarization-3.1\n"
                f" 2. https://hf.co/pyannote/segmentation-3.0\n"
                f" 3. https://hf.co/pyannote/embedding\n"
                f"After accepting, add HF_TOKEN to Space secrets with your token."
            )

        # Start from the pipeline's default parameters when it exposes them.
        try:
            params = pipeline.default_parameters()
        except NotImplementedError:
            # No defaults available: instantiate with an empty dict instead.
            params = {}
        print(f"DEBUG: Pipeline params: {params}", file=sys.stderr)

        # Merge user overrides into copies of the parameter groups.
        params = self._apply_overrides(params, segmentation_params, clustering_params)

        # Instantiate pipeline with parameters (modifies in-place and returns self).
        print("DEBUG: Instantiating pipeline...", file=sys.stderr)
        pipeline.instantiate(params)
        print("DEBUG: Pipeline instantiated successfully", file=sys.stderr)

        # Store and move to device.
        self.pipeline = pipeline
        self.pipeline.to(self.device)
        print(f"DEBUG: Pipeline moved to device: {self.device}", file=sys.stderr)

    @staticmethod
    def _apply_overrides(
        params: Dict[str, Any],
        segmentation_params: Optional[Dict[str, float]] = None,
        clustering_params: Optional[Dict[str, float]] = None,
    ) -> Dict[str, Any]:
        """Return a copy of ``params`` with per-group overrides merged in.

        Neither ``params`` nor its nested group dicts are modified, so a
        defaults dict the pipeline might reuse is never mutated. Overrides for
        a group that ``params`` does not contain are ignored, but a warning is
        printed to stderr instead of dropping them silently.
        """
        import sys

        merged: Dict[str, Any] = dict(params)
        for group, overrides in (
            ("segmentation", segmentation_params),
            ("clustering", clustering_params),
        ):
            if not overrides:
                continue
            if group in merged:
                # Build a new dict rather than calling .update() on the original.
                merged[group] = {**merged[group], **overrides}
            else:
                print(
                    f"WARNING: pipeline has no '{group}' parameters; ignoring {overrides}",
                    file=sys.stderr,
                )
        return merged

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Map a device string to a ``torch.device``.

        ``"auto"`` picks CUDA when available, otherwise CPU. ``"cuda"`` and
        ``"cuda:N"`` require an available GPU.

        Raises:
            RuntimeError: CUDA was requested but is not available.
            ValueError: ``device`` is not one of the accepted values.
        """
        if device == "cpu":
            return torch.device("cpu")
        if device == "cuda" or device.startswith("cuda:"):
            if not torch.cuda.is_available():
                raise RuntimeError("Yêu cầu CUDA nhưng không phát hiện GPU khả dụng.")
            return torch.device(device)
        if device == "auto":
            return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        raise ValueError("Giá trị device hợp lệ: auto, cpu, cuda.")

    def diarize(
        self, audio_path: str | Path, show_progress: bool = True, keep_audio: bool = False
    ) -> Any:
        """Run the pipeline on ``audio_path`` after converting it to 16 kHz WAV.

        Args:
            audio_path: Input audio file, validated by ``ensure_audio_path``.
            show_progress: Wrap the run in pyannote's ``ProgressHook``.
            keep_audio: If true, keep the converted audio and return it.

        Returns:
            The raw pipeline output. When ``keep_audio`` is true, this is a
            ``(result, prepared_path, tmpdir)`` tuple instead, and the caller
            becomes responsible for removing ``tmpdir``.
        """
        audio_path = ensure_audio_path(audio_path)
        prepared_path, tmpdir = convert_to_wav_16k(audio_path)
        succeeded = False
        try:
            if show_progress:
                with ProgressHook() as hook:
                    result = self.pipeline(str(prepared_path), hook=hook)
            else:
                result = self.pipeline(str(prepared_path))
            succeeded = True
        finally:
            # Keep the temp dir only when it is actually handed back to the
            # caller; on failure it would otherwise leak with no owner.
            if tmpdir and not (keep_audio and succeeded):
                shutil.rmtree(tmpdir, ignore_errors=True)
        if keep_audio:
            return result, prepared_path, tmpdir
        return result

    @staticmethod
    def _get_annotation(diarization: Any):
        """Return the Annotation from either pipeline output format.

        Older outputs expose ``itertracks`` directly. Newer ones wrap the
        annotation in a ``speaker_diarization`` attribute.

        Raises:
            TypeError: neither form is present.
        """
        if hasattr(diarization, "itertracks"):
            return diarization
        if hasattr(diarization, "speaker_diarization"):
            return diarization.speaker_diarization
        raise TypeError("Output pipeline không có Annotation hoặc speaker_diarization.")

    def to_segments(self, diarization: Any) -> List[Segment]:
        """Convert pipeline output into a list of ``Segment`` records, in track order."""
        annotation = self._get_annotation(diarization)
        return [
            Segment(start=float(turn.start), end=float(turn.end), speaker=str(speaker))
            for turn, _, speaker in annotation.itertracks(yield_label=True)
        ]

    def save_rttm(self, diarization: Any, output_path: str | Path) -> Path:
        """Write the diarization as RTTM to ``output_path``, creating parent dirs.

        Returns:
            The path that was written.
        """
        annotation = self._get_annotation(diarization)
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            annotation.write_rttm(f)
        return path

    def run(self, audio_path: str | Path, show_progress: bool = True) -> List[Segment]:
        """Run the pipeline on ``audio_path`` and return the speaker segments."""
        diarization = self.diarize(audio_path, show_progress=show_progress)
        return self.to_segments(diarization)
|