File size: 3,469 Bytes
20e9692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""VAD inference utilities."""

from typing import List, Tuple

import numpy as np
import torch

from .segmenter_aoti import is_aoti_applied
from .segmenter_model import load_segmenter, _log_env_once


def detect_speech_segments(
    audio: np.ndarray,
    sample_rate: int,
    min_silence_ms: int,
    min_speech_ms: int,
    pad_ms: int
) -> tuple[List[Tuple[float, float]], dict, object, object]:
    """
    Detect speech segments in audio using VAD.

    Args:
        audio: Audio waveform (mono, float32)
        sample_rate: Sample rate of audio
        min_silence_ms: Minimum silence duration to split segments
        min_speech_ms: Minimum speech duration for a valid segment
        pad_ms: Padding around speech segments

    Returns:
        Tuple of (intervals, profiling_dict, raw_speech_intervals, raw_is_complete) where:
        - intervals: List of (start_time, end_time) tuples in seconds
        - profiling_dict: {"model_load_time": float, "inference_time": float}
        - raw_speech_intervals: Raw VAD intervals before cleaning (for resegmentation),
          or None when the model is unavailable or produced no output
        - raw_is_complete: Raw VAD completeness flags (for resegmentation), or None
          in the same fallback cases

    Raises:
        Exception: Re-raises any error from segmentation so the caller
            (gpu_with_fallback) can retry on CPU.
    """
    # Imported lazily so the module stays importable without timing deps at top level.
    import time

    model, processor, model_load_time = load_segmenter()
    if model is None:
        # Fallback: treat whole audio as one segment.
        # Report the measured load time (it was spent even though loading failed)
        # so profiling stays consistent with the success paths.
        return [(0, len(audio) / sample_rate)], {"model_load_time": model_load_time, "inference_time": 0.0}, None, None

    inference_start = time.time()
    _log_env_once()

    try:
        from recitations_segmenter import segment_recitations, clean_speech_intervals

        audio_tensor = torch.from_numpy(audio).float()

        # Run inference on whatever device/dtype the loaded model lives on.
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        # Log AoTI status
        if is_aoti_applied():
            print("[VAD] Using AOTInductor-compiled model")

        # Run segmentation
        outputs = segment_recitations(
            [audio_tensor], model, processor,
            device=device, dtype=dtype, batch_size=1,
        )

        if not outputs:
            # No model output: fall back to one whole-audio segment, but keep timings.
            inference_time = time.time() - inference_start
            return [(0, len(audio) / sample_rate)], {"model_load_time": model_load_time, "inference_time": inference_time}, None, None

        # Clean speech intervals with user parameters
        clean_out = clean_speech_intervals(
            outputs[0].speech_intervals,
            outputs[0].is_complete,
            min_silence_duration_ms=min_silence_ms,
            min_speech_duration_ms=min_speech_ms,
            pad_duration_ms=pad_ms,
            return_seconds=True,
        )

        inference_time = time.time() - inference_start
        intervals = clean_out.clean_speech_intervals.tolist()

        raw_count = len(outputs[0].speech_intervals)
        final_count = len(intervals)
        removed = raw_count - final_count
        print(f"[VAD] Raw model intervals: {raw_count}, after cleaning: {final_count} "
              f"({removed} removed by silence merge + min_speech={min_speech_ms}ms filter)")

        # Keep the raw (pre-cleaning) model output so callers can resegment later.
        raw_speech_intervals = outputs[0].speech_intervals
        raw_is_complete = outputs[0].is_complete

        return [(start, end) for start, end in intervals], {"model_load_time": model_load_time, "inference_time": inference_time}, raw_speech_intervals, raw_is_complete

    except Exception as e:
        print(f"VAD error: {e}")
        import traceback
        traceback.print_exc()
        # Let gpu_with_fallback handle retries on CPU
        raise