"""Audio utility functions for the diarization pipeline."""

import io
import numpy as np
import torch
import torchaudio
from pathlib import Path
from typing import Union, Tuple, Iterator
from loguru import logger

SUPPORTED_FORMATS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"}
TARGET_SAMPLE_RATE = 16000


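def load_audio(
    source: Union[str, Path, bytes, io.IOBase],
    target_sr: int = TARGET_SAMPLE_RATE,
) -> Tuple[torch.Tensor, int]:
    """Load audio as a mono 1-D tensor resampled to ``target_sr``.

    ``source`` may be a file path, encoded audio bytes, or a file-like object.
    """
    if isinstance(source, bytes):
        # Wrap in-memory bytes in a file-like object so they can be decoded.
        source = io.BytesIO(source)
    waveform, sr = torchaudio.load(source)
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono by averaging the channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
        sr = target_sr
    return waveform.squeeze(0), sr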


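def pcm_bytes_to_tensor(data: bytes, dtype=np.float32) -> torch.Tensor:
    """Interpret raw PCM bytes as a 1-D tensor of ``dtype`` samples."""
    # np.frombuffer returns a read-only view; copy so the tensor is writable.
    arr = np.frombuffer(data, dtype=dtype).copy()
    return torch.from_numpy(arr)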


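def chunk_audio(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    chunk_duration: float = 30.0,
    overlap: float = 1.0,
) -> Iterator[Tuple[Union[torch.Tensor, np.ndarray], float]]:
    """Yield ``(chunk, start_time_in_seconds)`` pairs of overlapping windows."""
    chunk_samples = int(chunk_duration * sample_rate)
    step_samples = int((chunk_duration - overlap) * sample_rate)
    if step_samples <= 0:
        # A zero step makes range() raise; a negative step yields no chunks.
        raise ValueError("overlap must be smaller than chunk_duration")
    n = len(audio)
    for start in range(0, n, step_samples):
        end = min(start + chunk_samples, n)
        yield audio[start:end], start / sample_rate
        if end == n:
            break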


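def format_timestamp(seconds: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS.mmm``."""
    # Round to whole milliseconds first so values such as 59.9996 carry into
    # the next minute instead of rendering as "00:00:60.000".
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{ms:03d}"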


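def segments_to_rttm(segments, audio_name: str = "audio") -> str:
    """Render segments as RTTM ``SPEAKER`` lines, one per segment.

    Each segment must expose ``start``, ``end`` (seconds) and ``speaker``.
    """
    lines = []
    for seg in segments:
        duration = seg.end - seg.start
        # Fields: type, file ID, channel, onset, duration, then speaker name
        # with the unused fields filled with <NA>.
        lines.append(
            f"SPEAKER {audio_name} 1 {seg.start:.3f} {duration:.3f} "
            f"<NA> <NA> {seg.speaker} <NA> <NA>"
        )
    return "\n".join(lines)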


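def segments_to_srt(segments) -> str:
    """Render segments as SRT cues whose text is the bracketed speaker label."""
    lines = []
    for i, seg in enumerate(segments, 1):
        # SRT uses a comma, not a period, as the millisecond separator.
        start = format_timestamp(seg.start).replace(".", ",")
        end = format_timestamp(seg.end).replace(".", ",")
        lines.append(f"{i}\n{start} --> {end}\n[{seg.speaker}]\n")
    # Each cue ends with a newline, so joining on "\n" leaves a blank line between cues.
    return "\n".join(lines)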