File size: 3,920 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Load audio files, extract timed segments, and save them as 16 kHz mono WAV.

Includes RMS-based silence filtering so that near-silent segments (music
intros, gaps, applause) are dropped before they pollute the training set.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Tuple

import librosa
import numpy as np
import soundfile as sf

from .parse_transcripts import TranscriptSegment

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Default sample rate in Hz; used as the default `sr` of load_audio() and the
# default `sample_rate` of process_pair().
TARGET_SR = 16_000          # Whisper expects 16 kHz
# Default `min_amplitude` of process_pair(): chunks whose RMS falls below this
# value are skipped instead of being written out.
_DEFAULT_MIN_AMPLITUDE = 0.001  # RMS below this → treat segment as silent


def load_audio(path: Path | str, sr: int = TARGET_SR) -> Tuple[np.ndarray, int]:
    """Decode the audio file at `path` as mono audio resampled to `sr` Hz.

    Returns:
        A ``(samples, sr)`` tuple: the decoded samples cast to float32 and
        the sample rate that was requested.
    """
    # librosa.load returns (samples, rate); the rate it reports is discarded
    # and the requested `sr` is handed back to the caller instead.
    samples = librosa.load(str(path), sr=sr, mono=True)[0]
    mono_f32 = samples.astype(np.float32)
    return mono_f32, sr


def get_audio_duration(path: Path | str) -> float:
    """Return the duration (seconds) of the audio file at `path`.

    The file is handed to ``librosa.get_duration`` by path rather than as a
    decoded sample array, so this function never decodes the samples itself.
    """
    audio_file = str(path)
    return librosa.get_duration(path=audio_file)


def extract_segment(
    audio: np.ndarray,
    sr: int,
    start: float,
    end: float,
) -> np.ndarray:
    """Slice `audio` between `start` and `end` seconds.

    Args:
        audio: Sample array; the slice is taken along its first axis.
        sr: Sample rate in Hz, used to convert seconds to sample indices.
        start: Segment start time in seconds.
        end: Segment end time in seconds.

    Returns:
        The samples in the half-open window ``[start, end)``. Times before
        the beginning of the audio are clamped to 0 and times past the end
        are clipped by the slice, so the result is empty when the window
        lies entirely outside the audio or when ``end <= start``.
    """
    # Clamp to 0: NumPy would otherwise interpret a negative index as an
    # offset from the END of the array and silently return the wrong audio.
    start_idx = max(0, int(start * sr))
    end_idx = max(0, int(end * sr))
    # Indices beyond len(audio) need no handling; slicing clips them.
    return audio[start_idx:end_idx]


def save_wav(array: np.ndarray, sr: int, path: Path | str) -> None:
    """Write `array` to `path` as 16-bit PCM audio at sample rate `sr`.

    Missing parent directories of `path` are created first.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(destination), array, sr, subtype="PCM_16")


def _rms(array: np.ndarray) -> float:
    """Return the root-mean-square amplitude of `array` (0.0 when empty)."""
    if not array.size:
        # Avoid taking the mean of an empty array.
        return 0.0
    # Accumulate in float64 regardless of the input dtype.
    samples = array.astype(np.float64)
    mean_square = np.mean(np.square(samples))
    return float(np.sqrt(mean_square))


def process_pair(
    audio_path: Path | str,
    transcript_segments: List[TranscriptSegment],
    output_dir: Path | str,
    sample_rate: int = TARGET_SR,
    min_amplitude: float = _DEFAULT_MIN_AMPLITUDE,
) -> List[dict]:
    """
    Split one audio file into WAV segments aligned to transcript_segments.

    The audio is loaded once at `sample_rate`, then one chunk is cut per
    transcript segment (using its `start` / `end` times) and validated:
    - Empty chunks (zero samples) are skipped with a warning.
    - Chunks whose RMS amplitude is below `min_amplitude` are skipped as
      near-silent (logged at debug level).

    Every surviving chunk is written to
    ``<output_dir>/<audio stem>_seg<segment_id, zero-padded to 4>.wav``;
    `output_dir` is created if it does not exist.

    Returns a list of metadata dicts ready to be added to the dataset manifest.
    Each dict has keys: audio_path, sentence, duration, source_audio, segment_id.
    `duration` is the length in seconds of the audio actually written, which
    is shorter than ``seg.end - seg.start`` when the segment runs past the
    end of the source file.
    """
    audio_path = Path(audio_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Loading %s ...", audio_path.name)
    audio, sr = load_audio(audio_path, sr=sample_rate)

    records: List[dict] = []
    skipped_empty = 0
    skipped_silent = 0

    for seg in transcript_segments:
        chunk = extract_segment(audio, sr, seg.start, seg.end)

        if chunk.size == 0:
            logger.warning("Empty chunk for segment %d — skipping", seg.segment_id)
            skipped_empty += 1
            continue

        amp = _rms(chunk)
        if amp < min_amplitude:
            logger.debug(
                "Segment %d is near-silent (RMS=%.5f < %.5f) — skipping",
                seg.segment_id, amp, min_amplitude,
            )
            skipped_silent += 1
            continue

        wav_name = f"{audio_path.stem}_seg{seg.segment_id:04d}.wav"
        wav_path = output_dir / wav_name
        save_wav(chunk, sr, wav_path)

        records.append({
            "audio_path":   str(wav_path),
            "sentence":     seg.text,
            # Measure the chunk that was written rather than trusting the
            # transcript times: slicing silently clips a segment that extends
            # beyond the end of the audio, so `seg.end - seg.start` could
            # overstate the length of the saved WAV.
            "duration":     len(chunk) / sr,
            "source_audio": seg.source_audio,
            "segment_id":   seg.segment_id,
        })

    logger.info(
        "Saved %d segments from %s  (skipped: %d empty, %d silent)",
        len(records), audio_path.name, skipped_empty, skipped_silent,
    )
    return records