# SoulX-Singer: preprocess/tools/vocal_detection.py
import os
import time
from dataclasses import dataclass
from typing import List, Optional
import librosa
import numpy as np
from soundfile import write
@dataclass(frozen=True)
class VocalDetectionConfig:
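    """Tunables for f0-based vocal segment detection.

    All values are in milliseconds: ``hop_ms`` is the frame hop of the f0
    track; the remaining fields control smoothing, onset/offset hysteresis,
    segment padding, and the merge/split length limits.
    """
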
hop_ms: int = 20
smooth_ms: int = 200
start_ms: int = 120
end_ms: int = 200
prepad_ms: int = 80
postpad_ms: int = 120
min_len_ms: int = 1000
max_len_ms: int = 20000
short_seg_merge_gap_ms: int = 8000
small_gap_ms: int = 500
lookback_ms: int = 200
lookahead_ms: int = 100
def _moving_average(x: np.ndarray, win: int) -> np.ndarray:
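    """Smooth ``x`` with a length-``win`` box filter; output keeps x's length."""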
if win <= 1:
return x
kernel = np.ones(win, dtype=np.float32) / float(win)
return np.convolve(x, kernel, mode="same")
def _merge_short_segments(
segments_ms: List[List[int]],
*,
min_len_ms: int,
max_len_ms: int,
short_seg_merge_gap_ms: int,
small_gap_ms: int,
) -> List[List[int]]:
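    """Merge and filter raw segments by length and gap.

    A current segment shorter than ``min_len_ms`` is merged across gaps up
    to ``short_seg_merge_gap_ms``; independently, any gap below
    ``small_gap_ms`` is bridged as long as the merged segment stays under
    ``max_len_ms``. Segments still shorter than ``min_len_ms`` are then
    dropped, unless that would drop everything, in which case the input is
    returned unchanged.
    """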
if not segments_ms:
return []
merged: List[List[int]] = []
cur_start, cur_end = segments_ms[0]
for next_start, next_end in segments_ms[1:]:
cur_len = cur_end - cur_start
gap_ms = next_start - cur_end
merged_len = next_end - cur_start
should_merge = (
(cur_len < min_len_ms and gap_ms < short_seg_merge_gap_ms)
or (gap_ms < small_gap_ms and merged_len < max_len_ms)
)
if should_merge:
cur_end = next_end
continue
if (cur_end - cur_start) >= min_len_ms:
merged.append([cur_start, cur_end])
cur_start, cur_end = next_start, next_end
if (cur_end - cur_start) >= min_len_ms:
merged.append([cur_start, cur_end])
if not merged:
return segments_ms
return merged
def _voiced_to_segments(
voiced: np.ndarray,
*,
hop_ms: int,
smooth_ms: int,
start_ms: int,
end_ms: int,
prepad_ms: int,
postpad_ms: int,
max_len_ms: int,
) -> List[List[int]]:
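    """Convert per-frame voiced flags into padded ``[start_ms, end_ms]`` pairs.

    The flags are smoothed with a ``smooth_ms`` moving average, segmented
    with onset/offset hysteresis (``start_ms`` / ``end_ms``), padded by
    ``prepad_ms`` / ``postpad_ms``, and any segment longer than
    ``max_len_ms`` is split at its longest internal unvoiced run.
    """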
smooth_frames = max(1, int(round(smooth_ms / hop_ms)))
smooth_voiced = _moving_average(voiced.astype(np.float32), smooth_frames)
active = smooth_voiced >= 0.5
segments: List[List[int]] = []
start_idx = None
start_frames = max(1, int(round(start_ms / hop_ms)))
end_frames = max(1, int(round(end_ms / hop_ms)))
prepad_frames = max(0, int(round(prepad_ms / hop_ms)))
postpad_frames = max(0, int(round(postpad_ms / hop_ms)))
active_count = 0
inactive_count = 0
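    # Hysteresis: open a segment after ``start_frames`` consecutive active
    # frames and close it after ``end_frames`` consecutive inactive frames.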
for i, flag in enumerate(active):
if flag:
active_count += 1
inactive_count = 0
else:
inactive_count += 1
active_count = 0
if start_idx is None:
if active_count >= start_frames:
start_idx = max(0, i - start_frames + 1 - prepad_frames)
else:
if inactive_count >= end_frames:
end_idx = min(len(active) - 1, i - end_frames + 1 + postpad_frames)
start_ms_val = start_idx * hop_ms
end_ms_val = end_idx * hop_ms + hop_ms
if end_ms_val > start_ms_val:
segments.append([int(start_ms_val), int(end_ms_val)])
start_idx = None
if start_idx is not None:
start_ms_val = start_idx * hop_ms
        end_idx = len(active) - 1  # postpad cannot extend beyond the last frame
end_ms_val = end_idx * hop_ms + hop_ms
if end_ms_val > start_ms_val:
segments.append([int(start_ms_val), int(end_ms_val)])
def _split_segment(seg: List[int]) -> List[List[int]]:
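        """Split an over-long segment at the midpoint of its longest
        unvoiced run, falling back to the segment midpoint."""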
start_ms_val, end_ms_val = seg
start_frame = int(start_ms_val // hop_ms)
end_frame = int((end_ms_val - 1) // hop_ms)
end_frame = max(start_frame, min(end_frame, len(active) - 1))
best_start = None
best_len = 0
cur_start = None
cur_len = 0
for idx in range(start_frame, end_frame + 1):
if not active[idx]:
if cur_start is None:
cur_start = idx
cur_len = 1
else:
cur_len += 1
else:
if cur_start is not None and cur_len > best_len:
best_start, best_len = cur_start, cur_len
cur_start = None
cur_len = 0
if cur_start is not None and cur_len > best_len:
best_start, best_len = cur_start, cur_len
if best_start is None:
split_frame = (start_frame + end_frame) // 2
else:
split_frame = best_start + best_len // 2
split_ms = split_frame * hop_ms
if split_ms <= start_ms_val:
split_ms = start_ms_val + hop_ms
if split_ms >= end_ms_val:
split_ms = end_ms_val - hop_ms
if split_ms <= start_ms_val or split_ms >= end_ms_val:
return [seg]
return [[start_ms_val, int(split_ms)], [int(split_ms), end_ms_val]]
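    # Repeatedly split any segment longer than ``max_len_ms``; parts go back
    # onto the queue so an over-long half is split again.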
queue = segments[:]
segments = []
while queue:
seg = queue.pop(0)
if (seg[1] - seg[0]) <= max_len_ms:
segments.append(seg)
continue
parts = _split_segment(seg)
if len(parts) == 1:
segments.append(seg)
else:
queue = parts + queue
return segments
class VocalDetector:
"""Detect vocal segments based on f0 voiced decisions.
This component consumes a precomputed ``*_f0.npy`` track and
produces vocal segments (and cuts wav files) for downstream
transcription or singing voice tasks.
"""
def __init__(
self,
cut_wavs_output_dir: str = "cut_wavs",
config: VocalDetectionConfig | None = None,
*,
verbose: bool = True,
):
"""Initialize the vocal detector.
Args:
cut_wavs_output_dir: Directory to save cut wav segments.
config: Detection configuration; uses :class:`VocalDetectionConfig` by default.
verbose: Whether to print verbose logs.
"""
self.cut_wavs_output_dir = cut_wavs_output_dir
self.config = config or VocalDetectionConfig()
self.verbose = verbose
if self.verbose:
print(
"[vocal detection] init success:",
f"cut_wavs_output_dir={self.cut_wavs_output_dir}",
f"hop_ms={self.config.hop_ms}",
)
def process(self, audio_path: str, f0: np.ndarray, *, verbose: Optional[bool] = None) -> List[dict]:
"""Run vocal detection on a single wav.
Args:
audio_path: Path to the input wav file.
f0: The f0 contour to use for vocal detection.
verbose: Override instance-level verbose flag for this call.
Returns:
A list of segment metadata dicts with fields like
``item_name``, ``wav_fn``, ``start_time_ms``, ``end_time_ms``.
"""
verbose = self.verbose if verbose is None else verbose
if verbose:
print(f"[vocal detection] process: start: {audio_path}")
t0 = time.time()
os.makedirs(self.cut_wavs_output_dir, exist_ok=True)
base_name = os.path.basename(audio_path)
base_name_no_ext = os.path.splitext(base_name)[0]
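        # Frames with f0 == 0 are treated as unvoiced; everything else counts
        # as voiced.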
voiced = f0 > 0
segments_ms = _voiced_to_segments(
voiced,
hop_ms=self.config.hop_ms,
smooth_ms=self.config.smooth_ms,
start_ms=self.config.start_ms,
end_ms=self.config.end_ms,
prepad_ms=self.config.prepad_ms,
postpad_ms=self.config.postpad_ms,
max_len_ms=self.config.max_len_ms,
)
if verbose:
print(f"[vocal detection] segments(before_merge)={len(segments_ms)}")
segments_ms = _merge_short_segments(
segments_ms,
min_len_ms=self.config.min_len_ms,
max_len_ms=self.config.max_len_ms,
short_seg_merge_gap_ms=self.config.short_seg_merge_gap_ms,
small_gap_ms=self.config.small_gap_ms,
)
if verbose:
print(f"[vocal detection] segments(after_merge)={len(segments_ms)}")
y, sr = librosa.load(audio_path, sr=None, mono=True)
# Apply global lookback/lookahead in milliseconds
lookback_ms = self.config.lookback_ms
lookahead_ms = self.config.lookahead_ms
adjusted_segments: List[List[int]] = []
prev_end = 0
for start_ms, end_ms in segments_ms:
start_ms = max(0, start_ms - lookback_ms)
end_ms = min(end_ms + lookahead_ms, int(y.shape[0] / sr * 1000))
            # Enforce non-overlap: if this segment's padded start reaches into
            # the previous segment, truncate the previous segment's end.
if start_ms < prev_end and len(adjusted_segments) > 0:
adjusted_segments[-1][1] = start_ms
adjusted_segments.append([start_ms, end_ms])
prev_end = end_ms
segment_infos = []
for idx, (start_ms, end_ms) in enumerate(adjusted_segments):
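            # Merging and padding can push a segment past max_len_ms again;
            # keep the tail and trim from the start.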
if end_ms - start_ms > self.config.max_len_ms:
start_ms = end_ms - self.config.max_len_ms
key = f"{base_name_no_ext}_{idx}"
start_sample = librosa.time_to_samples(start_ms / 1000, sr=sr)
end_sample = librosa.time_to_samples(end_ms / 1000, sr=sr)
segment = y[start_sample:end_sample]
write(f"{self.cut_wavs_output_dir}/{key}.wav", segment, sr)
segment_infos.append(
{
"item_name": key,
"wav_fn": f"{self.cut_wavs_output_dir}/{key}.wav",
"start_time_ms": int(start_sample * 1000 / sr),
"end_time_ms": int(end_sample * 1000 / sr),
"origin_wav_fn": audio_path,
"duration": int((end_sample - start_sample) * 1000 / sr),
}
)
if verbose:
dt = time.time() - t0
print(
"[vocal detection] process: done:",
f"n_segments={len(segment_infos)}",
f"time={dt:.3f}s",
)
return segment_infos
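

# A minimal sketch of producing the f0 contour that ``process`` expects when
# no precomputed ``*_f0.npy`` track is available. The pyin-based estimator,
# the note range, and the helper name ``estimate_f0`` are illustrative
# assumptions, not the project's official f0 extractor.
def estimate_f0(wav_path: str, hop_ms: int = 20) -> np.ndarray:
    y, sr = librosa.load(wav_path, sr=None, mono=True)
    hop_length = max(1, int(round(sr * hop_ms / 1000)))
    f0, _, _ = librosa.pyin(
        y,
        fmin=librosa.note_to_hz("C2"),
        fmax=librosa.note_to_hz("C6"),
        sr=sr,
        hop_length=hop_length,
    )
    # ``VocalDetector`` treats f0 == 0 as unvoiced, so map pyin's NaNs to 0.
    return np.nan_to_num(f0, nan=0.0)
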
if __name__ == "__main__":
    m = VocalDetector(cut_wavs_output_dir="outputs/transcription/cut_wavs")
    # ``process`` requires an f0 contour; load the precomputed ``*_f0.npy``
    # track that accompanies the wav (path assumed from the naming
    # convention in the class docstring).
    f0 = np.load("./outputs/transcription/test_f0.npy")
    segment_infos = m.process("./outputs/transcription/test.wav", f0)
    print(segment_infos)