File size: 6,260 Bytes
b924e1d
 
 
 
4193ca8
 
 
b924e1d
4193ca8
b924e1d
 
 
 
 
 
56dc9d8
b924e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4193ca8
 
 
 
 
 
b924e1d
 
 
 
 
 
 
4193ca8
 
 
b924e1d
4193ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56dc9d8
 
 
 
 
 
 
 
 
4193ca8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from __future__ import annotations

import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Iterable, List, Tuple


def read_hf_token(token: str | None = None, key_path: str | Path = "hugging_face_key.txt") -> str:
    """Ưu tiên token truyền vào, nếu không thì đọc từ biến môi trường hoặc file."""
    candidates = [
        token,
        os.getenv("HF_TOKEN"),  # HF Spaces standard secret name
        os.getenv("HUGGINGFACE_TOKEN"),
        os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
    ]
    for value in candidates:
        if value and value.strip():
            return value.strip()
    path = Path(key_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Không tìm thấy token. Truyền biến --hf-token hoặc đặt file {path}."
        )
    content = path.read_text(encoding="utf-8").strip()
    if not content:
        raise ValueError(f"File {path} trống, hãy dán Hugging Face access token vào.")
    return content


def ensure_audio_path(audio_path: str | Path) -> Path:
    """Kiểm tra đường dẫn audio hợp lệ."""
    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Không tìm thấy file âm thanh: {path}")
    if not path.is_file():
        raise ValueError(f"Đường dẫn không phải file: {path}")
    return path


def export_segments_json(segments: Iterable[dict], output_path: str | Path) -> Path:
    """Lưu danh sách segment thành JSON."""
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    data: List[dict] = list(segments)
    path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
    return path


def seconds_to_mmss(seconds: float) -> str:
    total_seconds = int(round(seconds))
    minutes, sec = divmod(total_seconds, 60)
    return f"{minutes:02d}:{sec:02d}"


def format_segments_table(segments: Iterable[dict]) -> str:
    """Trả về chuỗi bảng đơn giản để in ra terminal."""
    lines = []
    for idx, seg in enumerate(segments, start=1):
        start = seg.get("start", 0.0)
        end = seg.get("end", 0.0)
        speaker = seg.get("speaker", "unknown")
        lines.append(
            f"{idx:02d} | {seconds_to_mmss(start)} -> {seconds_to_mmss(end)} | speaker {speaker}"
        )
    return "\n".join(lines)


def merge_adjacent_segments(
    segments: list[dict],
    max_gap: float = 0.5,
    min_duration: float = 1.0,
) -> list[dict]:
    """
    Ghép các đoạn liên tiếp cùng speaker nếu khoảng trống <= max_gap (giây).
    Đồng thời lọc bỏ đoạn quá ngắn (< min_duration).
    """
    if not segments:
        return []

    merged: list[dict] = []
    # đảm bảo sắp xếp theo thời gian
    segs = sorted(segments, key=lambda s: s.get("start", 0.0))
    current = segs[0].copy()

    for seg in segs[1:]:
        if (
            seg.get("speaker") == current.get("speaker")
            and seg.get("start", 0.0) - current.get("end", 0.0) <= max_gap
        ):
            current["end"] = max(current.get("end", 0.0), seg.get("end", 0.0))
        else:
            if current.get("end", 0.0) - current.get("start", 0.0) >= min_duration:
                merged.append(current)
            current = seg.copy()

    if current.get("end", 0.0) - current.get("start", 0.0) >= min_duration:
        merged.append(current)

    return merged


def convert_to_wav_16k(audio_path: Path) -> Tuple[Path, Path | None]:
    """
    Chuyển audio về WAV mono 16 kHz bằng ffmpeg.
    Trả về (đường dẫn dùng để suy luận, thư mục tạm để dọn dẹp hoặc None nếu không cần).
    """
    safe_stem = audio_path.stem.replace(" ", "_")
    tmpdir = Path(tempfile.mkdtemp(prefix="diarization_audio_"))

    if not shutil.which("ffmpeg"):
        # Sao chép tệp hiện có vào thư mục tạm với tên không dấu cách
        output = tmpdir / audio_path.name.replace(" ", "_")
        shutil.copy2(audio_path, output)
        return output, tmpdir

    output = tmpdir / f"{safe_stem}_16k.wav"
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(audio_path),
        "-ac",
        "1",
        "-ar",
        "16000",
        "-vn",
        "-f",
        "wav",
        str(output),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        stderr = result.stderr.strip()
        raise RuntimeError(f"ffmpeg convert thất bại: {stderr}")
    if not output.exists():
        raise RuntimeError("ffmpeg không tạo được file WAV.")
    return output, tmpdir


def download_audio_from_url(url: str) -> Tuple[Path, Path]:
    """
    Tải audio từ YouTube/TikTok/... dùng yt-dlp, xuất WAV để xử lý tiếp.
    Trả về (đường dẫn file wav, thư mục tạm chứa file).
    """
    if not shutil.which("yt-dlp"):
        raise RuntimeError("Cần cài yt-dlp để tải liên kết (pip install yt-dlp).")
    if not shutil.which("ffmpeg"):
        raise RuntimeError("Cần cài ffmpeg để chuyển đổi audio.")

    tmpdir = Path(tempfile.mkdtemp(prefix="download_media_"))
    out_tmpl = tmpdir / "%(title)s.%(ext)s"
    cmd = [
        "yt-dlp",
        "-x",
        "--audio-format",
        "wav",
        "--audio-quality",
        "0",
        "-o",
        str(out_tmpl),
        url,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        stderr = result.stderr.strip()
        # Check if it's a network error
        if "Failed to resolve" in stderr or "No address associated with hostname" in stderr:
            raise RuntimeError(
                "Không thể kết nối internet để tải video. "
                "Hugging Face Spaces (free tier) không hỗ trợ download từ URL. "
                "Vui lòng tải file audio trực tiếp thay vì dùng URL."
            )
        raise RuntimeError(f"Tải audio thất bại: {stderr}")

    wav_files = list(tmpdir.glob("*.wav"))
    if not wav_files:
        raise RuntimeError("Không tìm thấy file WAV sau khi tải.")
    return wav_files[0], tmpdir