Xinsheng-Wang's picture
Upload folder using huggingface_hub
c7f3ffb verified
# https://github.com/RickyL-2000/ROSVOT
import math
import sys
import traceback
import json
import time
from pathlib import Path
from typing import Any, Dict, Optional
import librosa
import numpy as np
import torch
import matplotlib.pyplot as plt
from .utils.os_utils import safe_path
from .utils.commons.hparams import set_hparams
from .utils.commons.ckpt_utils import load_ckpt
from .utils.commons.dataset_utils import pad_or_cut_xd
from .utils.audio.mel import MelNet
from .utils.audio.pitch_utils import (
norm_interp_f0,
denorm_f0,
f0_to_coarse,
boundary2Interval,
save_midi,
midi_to_hz,
)
from .utils.rosvot_utils import (
get_mel_len,
align_word,
regulate_real_note_itv,
regulate_ill_slur,
bd_to_durs,
)
from .modules.pe.rmvpe import RMVPE
from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor
@torch.no_grad()
def infer_sample(
item: Dict[str, Any],
hparams: Dict[str, Any],
models: Dict[str, Any],
device: torch.device,
*,
save_dir: Optional[str] = None,
apply_rwbd: Optional[bool] = None,
# outputs
save_plot: bool = False,
no_save_midi: bool = True,
no_save_npy: bool = True,
verbose: bool = False,
) -> Dict[str, Any]:
if "item_name" not in item or "wav_fn" not in item:
raise ValueError('item must contain keys: "item_name" and "wav_fn"')
item_name = item["item_name"]
wav_src = item["wav_fn"]
# Decide RWBD usage
if apply_rwbd is None:
apply_rwbd_ = ("word_durs" not in item)
else:
apply_rwbd_ = bool(apply_rwbd)
# Models
model = models["model"]
mel_net = models["mel_net"]
pe = models.get("pe")
wbd_predictor = models.get("wbd_predictor")
if wbd_predictor is None and apply_rwbd_:
raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")
# ---- Prepare Data ----
if isinstance(wav_src, str):
wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
else:
wav = wav_src
if not isinstance(wav, np.ndarray):
wav = np.asarray(wav)
wav = wav.astype(np.float32)
# Calculate timestamps and alignment lengths
wav_len_samples = wav.shape[-1]
mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])
# Word boundary preparation
mel2word = None
word_durs_filtered = None
if not apply_rwbd_:
if "word_durs" not in item:
raise ValueError('apply_rwbd=False but item has no "word_durs"')
wd_raw = list(item["word_durs"])
min_word_dur = hparams.get("min_word_dur", 20) / 1000
word_durs_filtered = []
for i, wd in enumerate(wd_raw):
if wd < min_word_dur:
if i == 0 and len(wd_raw) > 1:
wd_raw[i + 1] += wd
elif len(word_durs_filtered) > 0:
word_durs_filtered[-1] += wd
else:
word_durs_filtered.append(wd)
mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
mel2word = np.asarray(mel2word)
if mel2word.size > 0 and mel2word[0] == 0:
mel2word = mel2word + 1
mel2word_len = int(np.sum(mel2word > 0))
real_len = min(mel_len, mel2word_len)
else:
real_len = min(mel_len, hparams["max_frames"])
T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]
# ---- Input Tensors & Padding ----
target_samples = T * hparams["hop_size"]
wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0) # [1, L]
if wav_t.shape[-1] < target_samples:
wav_t = pad_or_cut_xd(wav_t, target_samples, 1)
# ---- Pitch Extraction ----
if pe is not None:
f0s, uvs = pe.get_pitch_batch(
wav_t,
sample_rate=hparams["audio_sample_rate"],
hop_size=hparams["hop_size"],
lengths=[real_len],
fmax=hparams["f0_max"],
fmin=hparams["f0_min"],
)
f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
else:
f0_t = uv_t = pitch_coarse = None
f0_np = None
# ---- Mel Extraction ----
mel = mel_net(wav_t) # [1, T_padded, C]
mel = pad_or_cut_xd(mel, T, 1)
# Construct non-padding mask
mel_nonpadding_mask = torch.zeros(1, T, device=device)
mel_nonpadding_mask[:, :real_len] = 1.0
# Apply mask to mel (zero out padding)
mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
# Re-calculate non_padding bool mask
mel_nonpadding = mel.abs().sum(-1) > 0
# ---- Word Boundary ----
word_durs_used = None
if apply_rwbd_:
mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
wbd_outputs = wbd_predictor(
mel=mel_input,
pitch=pitch_coarse,
uv=uv_t,
non_padding=mel_nonpadding,
train=False,
)
word_bd = wbd_outputs["word_bd_pred"] # [1, T]
else:
# Construct word_bd from provided durs
mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
word_bd = torch.zeros_like(mel2word_t)
# Vectorized check
word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
word_bd[real_len:] = 0
word_bd = word_bd.unsqueeze(0) # [1, T]
word_durs_used = np.array(word_durs_filtered)
# ---- Main Inference ----
mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
outputs = model(
mel=mel_input,
word_bd=word_bd,
pitch=pitch_coarse,
uv=uv_t,
non_padding=mel_nonpadding,
train=False,
)
note_lengths = outputs["note_lengths"].detach().cpu().numpy()
note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]
if note_pred.shape == (0,):
if verbose:
print(f"skip {item_name}: no notes detected")
return {
"item_name": item_name,
"pitches": [],
"note_durs": [],
"note2words": None,
}
# ---- Post-Processing & Regulation ----
note_itv_pred = boundary2Interval(note_bd_pred)
note2words = None
if apply_rwbd_:
word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
word_durs_for_reg = word_durs_derived
word_bd_for_reg = word_bd_np
else:
word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
word_durs_for_reg = word_durs_used
should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)
if should_regulate and (word_durs_for_reg is not None):
try:
note_itv_pred_secs, note2words = regulate_real_note_itv(
note_itv_pred,
note_bd_pred,
word_bd_for_reg,
word_durs_for_reg,
hparams["hop_size"],
hparams["audio_sample_rate"],
)
note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
except Exception as err:
if verbose:
_, exc_value, exc_tb = sys.exc_info()
tb = traceback.extract_tb(exc_tb)[-1]
print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
# Fallback
note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
note2words = None
else:
note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
# ---- Output ----
note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]
out = {
"item_name": item_name,
"pitches": note_pred.tolist(),
"note_durs": note_durs,
"note2words": note2words.tolist() if note2words is not None else None,
}
# ---- Saving ----
if save_dir is not None:
save_dir_path = Path(save_dir)
save_dir_path.mkdir(parents=True, exist_ok=True)
fn = str(item_name)
if not no_save_midi:
save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))
if not no_save_npy:
np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)
if save_plot:
fig = plt.figure()
if f0_np is not None:
plt.plot(f0_np, color="red", label="f0")
midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
for i, itv in enumerate(itvs):
midi_pred[itv[0] : itv[1]] = note_pred[i]
plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
plt.legend()
plt.tight_layout()
plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
plt.close(fig)
return out
def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
"""
Load models once to reuse across multiple items.
"""
dev = torch.device(device)
# 1. Hparams
config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
hparams = set_hparams(
config=config_path,
print_hparams=verbose,
hparams_str=f"note_bd_threshold={thr}",
)
# 2. Main Model
model = MidiExtractor(hparams)
load_ckpt(model, ckpt, verbose=verbose)
model.eval().to(dev)
# 3. MelNet
mel_net = MelNet(hparams)
mel_net.to(dev)
# 4. Pitch Extractor
pe = None
if hparams.get("use_pitch_embed", False):
pe = RMVPE(pe_ckpt, device=dev)
# 5. Word Boundary Predictor (optional but we load if ckpt provided or needed)
wbd_predictor = None
if wbd_ckpt:
wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
wbd_hparams = set_hparams(
config=wbd_config_path,
print_hparams=False,
hparams_str="",
)
hparams.update({
"wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
"min_word_dur": wbd_hparams["min_word_dur"],
})
wbd_predictor = WordbdExtractor(wbd_hparams)
load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
wbd_predictor.eval().to(dev)
models = {
"model": model,
"mel_net": mel_net,
"pe": pe,
"wbd_predictor": wbd_predictor
}
return hparams, models
class NoteTranscriber:
"""Note transcription wrapper based on ROSVOT.
Loads ROSVOT and optional RWBD models once in ``__init__`` and
exposes a :py:meth:`process` API that turns an item dict into
aligned note metadata for downstream SVS.
"""
def __init__(
self,
rosvot_model_path: str,
rwbd_model_path: str,
*,
rosvot_config_path: str = "",
rwbd_config_path: str = "",
device: str = "cuda:0",
thr: float = 0.85,
verbose: bool = True,
):
"""Initialize the note transcriber.
Args:
ckpt: Path to the main ROSVOT checkpoint.
config: Optional config YAML path for ROSVOT.
wbd_ckpt: Optional word-boundary checkpoint path.
wbd_config: Optional config YAML path for RWBD.
device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
thr: Note boundary threshold.
verbose: Whether to print verbose logs.
"""
self.verbose = verbose
self.device = torch.device(device)
self.hparams, self.models = load_rosvot_models(
ckpt=rosvot_model_path,
config=rosvot_config_path,
wbd_ckpt=rwbd_model_path,
wbd_config=rwbd_config_path,
device=device,
verbose=verbose,
thr=thr,
)
if self.verbose:
print(
"[note transcription] init success:",
f"device={self.device}",
f"rosvot_model_path={rosvot_model_path}",
f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
f"thr={thr}",
)
def process(
self,
item: Dict[str, Any],
*,
segment_info: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = None,
apply_rwbd: Optional[bool] = None,
save_plot: bool = False,
no_save_midi: bool = True,
no_save_npy: bool = True,
verbose: Optional[bool] = None,
) -> Dict[str, Any]:
"""Run ROSVOT on a single item and post-process outputs.
Args:
item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
segment_info: Optional segment metadata for sliced audio.
save_dir: Optional directory for debug artifacts (plots, midis).
apply_rwbd: Whether to run RWBD-based word boundary refinement.
save_plot: Whether to save diagnostic plots.
no_save_midi: If True, skip saving midi.
no_save_npy: If True, skip saving numpy intermediates.
verbose: Override instance-level verbose flag for this call.
Returns:
Dict with aligned note information for downstream SVS.
"""
v = self.verbose if verbose is None else verbose
if v:
item_name = item.get("item_name", "")
wav_fn = item.get("wav_fn", "")
print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
t0 = time.time()
rosvot_out = infer_sample(
item,
self.hparams,
self.models,
device=self.device,
save_dir=save_dir,
apply_rwbd=apply_rwbd,
save_plot=save_plot,
no_save_midi=no_save_midi,
no_save_npy=no_save_npy,
verbose=v,
)
out = self.post_process(
metadata=item,
segment_info=segment_info,
rosvot_out=rosvot_out,
)
if v:
dt = time.time() - t0
print(
"[note transcription] process: done:",
f"item_name={out.get('item_name','')}",
f"n_notes={len(out.get('note_pitch', []) or [])}",
f"time={dt:.3f}s",
)
return out
@staticmethod
def _normalize_note2words(note2words: list[int]) -> list[int]:
if not note2words:
return []
normalized = [note2words[0]]
for idx in range(1, len(note2words)):
if note2words[idx] < normalized[-1]:
normalized.append(normalized[-1])
else:
normalized.append(note2words[idx])
return normalized
@staticmethod
def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
ep_types: list[int] = []
prev = -1
for i, w in zip(note2words, align_words):
if w == "<SP>":
ep_types.append(1)
else:
ep_types.append(2 if i != prev else 3)
prev = i
return ep_types
def post_process(
self,
*,
metadata: Dict[str, Any],
segment_info: Dict[str, Any],
rosvot_out: Dict[str, Any],
) -> Dict[str, Any]:
"""Build aligned note metadata using ROSVOT outputs."""
note2words_raw = rosvot_out.get("note2words") or []
note2words = self._normalize_note2words(note2words_raw)
align_words = [
metadata["words"][idx - 1]
for idx in note2words_raw
if 0 < idx <= len(metadata["words"])
]
ep_types = self._build_ep_types(note2words, align_words) if align_words else []
return {
"item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
"wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
"origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
"start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
"end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
"language": metadata.get("language", ""),
"note_text": align_words,
"note_dur": rosvot_out.get("note_durs", []),
"note_type": ep_types,
"note_pitch": rosvot_out.get("pitches", []),
}
if __name__ == "__main__":
items = json.load(open("example/test/rosvot_input.json", "r"))
item = items[0]
m = NoteTranscriber(
rosvot_model_path="pretrained_models/rosvot/rosvot/model.pt",
rwbd_model_path="pretrained_models/rosvot/rwbd/model.pt",
device="cuda"
)
out = m.process(item)
print(out)