# File: rescored/backend/yourmt3_wrapper.py
# Commit: "yourmt3 fix 6" (5f595fd), calebhan
"""
YourMT3 Transcription Wrapper
This module provides a simplified interface to the YourMT3+ model for
music transcription. It wraps the HuggingFace Spaces implementation
to provide a clean API for transcription services.
Based on: https://huggingface.co/spaces/mimbres/YourMT3
"""
import sys
import os
import shutil
from pathlib import Path
from typing import Optional
# Make the vendored YourMT3 sources importable. Both entries are inserted at
# index 0, so the one added last (amt/src) ends up first on sys.path and is
# searched before yourmt3_core itself.
_base_dir = Path(__file__).parent  # directory containing this wrapper module
sys.path.insert(0, str(_base_dir / "ymt" / "yourmt3_core")) # For model_helper
sys.path.insert(0, str(_base_dir / "ymt" / "yourmt3_core" / "amt" / "src")) # For model/utils
import torch
import torchaudio
import soundfile as sf
import numpy as np
class YourMT3Transcriber:
    """
    Wrapper class for YourMT3+ music transcription model.

    The model checkpoint is loaded once in the constructor; afterwards
    transcribe_audio() converts an audio file to a MIDI file.
    """

    # Value passed to the loader's ``-p`` flag.
    _PROJECT = "2024"

    # Flag groups that are repeated verbatim across several model variants.
    _ARGS_SPEC = ("-ac", "spec", "-hop", "300", "-atc", "1")
    _ARGS_MULTI_T5 = (
        "-tk", "mc13_full_plus_256", "-dec", "multi-t5",
        "-nl", "26", "-enc", "perceiver-tf",
    )
    _ARGS_MOE = _ARGS_MULTI_T5 + (
        "-sqr", "1", "-ff", "moe", "-wf", "4", "-nmoe", "8", "-kmoe", "2",
        "-act", "silu", "-epe", "rope", "-rp", "1",
    ) + _ARGS_SPEC

    # model name -> (checkpoint file name, variant-specific loader flags).
    # The final argument list is:
    #   [checkpoint, '-p', project, *flags, '-pr', precision]
    _MODEL_CONFIGS = {
        "YMT3+": (
            "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
            (),
        ),
        "YPTF+Single (noPS)": (
            "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
            ("-enc", "perceiver-tf") + _ARGS_SPEC,
        ),
        "YPTF+Multi (PS)": (
            "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
            _ARGS_MULTI_T5 + _ARGS_SPEC,
        ),
        "YPTF.MoE+Multi (noPS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
            _ARGS_MOE,
        ),
        "YPTF.MoE+Multi (PS)": (
            "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
            _ARGS_MOE,
        ),
    }

    def __init__(
        self,
        model_name: str = "YPTF.MoE+Multi (noPS)",
        device: Optional[str] = None,
        checkpoint_dir: Optional[Path] = None
    ):
        """
        Initialize the YourMT3 transcriber.

        Args:
            model_name: Model variant to use. Options:
                - "YMT3+"
                - "YPTF+Single (noPS)"
                - "YPTF+Multi (PS)"
                - "YPTF.MoE+Multi (noPS)" (default, best quality)
                - "YPTF.MoE+Multi (PS)"
            device: Device to run on ('cuda', 'cpu', or None for auto-detect)
            checkpoint_dir: Directory containing model checkpoints

        Raises:
            RuntimeError: If the YourMT3 ``model_helper`` module cannot be imported.
            ValueError: If ``model_name`` is not one of the known variants.
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): checkpoint_dir is stored and printed here but is not
        # passed to the loader anywhere in this file; the loader is instead
        # run from amt/src (see _load_model). Confirm whether that is intended.
        self.checkpoint_dir = checkpoint_dir or Path(__file__).parent / "ymt" / "yourmt3_core" / "logs" / "2024"
        print(f"Initializing YourMT3+ ({model_name}) on {self.device}")
        print(f"Checkpoint dir: {self.checkpoint_dir}")

        # Imported lazily: model_helper only resolves once the module-level
        # sys.path entries have been added.
        try:
            from model_helper import load_model_checkpoint
            self._load_model_checkpoint = load_model_checkpoint
        except ImportError as e:
            # Chain the original error so the failing import stays visible.
            raise RuntimeError(
                f"Failed to import YourMT3 model helpers: {e}\n"
                f"Make sure the amt/src directory is properly set up in yourmt3_core/"
            ) from e

        # Load model
        self.model = self._load_model(model_name)

    def _get_model_args(self, model_name: str) -> list:
        """
        Get command-line arguments for model loading.

        Args:
            model_name: One of the keys of ``_MODEL_CONFIGS``.

        Returns:
            Argument list of the form
            ``[checkpoint, '-p', project, *variant_flags, '-pr', precision]``.

        Raises:
            ValueError: If ``model_name`` is not a known variant.
        """
        config = self._MODEL_CONFIGS.get(model_name)
        if config is None:
            raise ValueError(f"Unknown model name: {model_name}")
        checkpoint, variant_flags = config

        # Use float16 for GPU devices (CUDA/MPS) for better performance and
        # lower memory. Compare only the device *type* so that indexed device
        # strings such as "cuda:0" are also recognised as GPU devices.
        device_type = str(self.device).split(":", 1)[0]
        precision = '16' if device_type in ('cuda', 'mps') else '32'

        return [checkpoint, '-p', self._PROJECT, *variant_flags, '-pr', precision]

    def _load_model(self, model_name: str):
        """
        Load the YourMT3 model checkpoint.

        The working directory is temporarily switched to ``amt/src`` while the
        checkpoint is loaded and is always restored afterwards.

        Args:
            model_name: Model variant to load (see ``_get_model_args``).

        Returns:
            The loaded model, moved to ``self.device``, in eval mode, with
            ``requires_grad`` disabled on all parameters.
        """
        args = self._get_model_args(model_name)
        print(f"Loading model with args: {' '.join(args)}")

        # YourMT3 expects to be run from amt/src directory for checkpoint paths
        # Save current directory and temporarily change to amt/src
        original_cwd = os.getcwd()
        amt_src_dir = _base_dir / "ymt" / "yourmt3_core" / "amt" / "src"
        try:
            os.chdir(str(amt_src_dir))
            # Load on CPU first, then move to target device
            model = self._load_model_checkpoint(args=args, device="cpu")
            model.to(self.device)
            model.eval()
        finally:
            # Always restore original directory
            os.chdir(original_cwd)

        # Enable optimizations for inference. This is a process-wide torch
        # setting; the hasattr guard keeps older torch versions working.
        if hasattr(torch, 'set_float32_matmul_precision'):
            torch.set_float32_matmul_precision('high')  # Use TF32 on Ampere GPUs

        # Disable gradient computation for inference (reduces memory)
        for param in model.parameters():
            param.requires_grad = False

        print(f"Model loaded successfully on {self.device}")
        return model

    @staticmethod
    def _soundfile_load(uri, **kwargs):
        """
        Load audio using soundfile instead of torchcodec.

        Stand-in for ``torchaudio.load``; any extra keyword arguments are
        accepted and ignored.

        Returns:
            Tuple ``(audio_tensor, sample_rate)`` where the tensor is float32
            and 2-D: multi-dimensional data is transposed, 1-D data is
            reshaped to a single row.
        """
        audio_data, sample_rate = sf.read(uri, dtype='float32')
        # Convert to torch tensor and ensure (channels, samples) shape
        audio_tensor = torch.from_numpy(
            audio_data.T if audio_data.ndim > 1 else audio_data.reshape(1, -1)
        )
        return audio_tensor, sample_rate

    def transcribe_audio(self, audio_path: Path, output_dir: Optional[Path] = None) -> Path:
        """
        Transcribe an audio file to MIDI.

        Args:
            audio_path: Path to input audio file (WAV, MP3, etc.)
            output_dir: Directory to save MIDI output (default: current directory)

        Returns:
            Path to the generated MIDI file

        Raises:
            FileNotFoundError: If audio_path doesn't exist
            RuntimeError: If transcription fails
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        output_dir = Path(output_dir) if output_dir else Path("./")
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Transcribing: {audio_path}")

        try:
            # Monkey-patch torchaudio.load to use soundfile instead of torchcodec
            # PyTorch 2.9+ hardcodes torchcodec which requires FFmpeg libraries
            # This workaround uses soundfile which is more portable
            original_load = torchaudio.load
            torchaudio.load = self._soundfile_load
            try:
                # Everything after the patch runs inside this try so that the
                # original loader is restored even if the import below fails.
                from model_helper import transcribe

                # Prepare audio info dict (as expected by transcribe function)
                audio_info = {
                    'filepath': str(audio_path),
                    'track_name': audio_path.stem
                }

                # Run transcription
                midi_path = Path(transcribe(self.model, audio_info))

                # Move to output directory if needed (use shutil.move for cross-filesystem support)
                if midi_path.parent != output_dir:
                    final_path = output_dir / midi_path.name
                    shutil.move(str(midi_path), str(final_path))
                    midi_path = final_path

                print(f"Transcription complete: {midi_path}")
                return midi_path
            finally:
                # Restore original torchaudio.load
                torchaudio.load = original_load
        except Exception as e:
            # Surface every failure as RuntimeError, keeping the cause chained.
            raise RuntimeError(f"Transcription failed: {e}") from e
if __name__ == "__main__":
    # Manual smoke test: transcribe a single audio file from the command line.
    import argparse

    cli = argparse.ArgumentParser(description="Test YourMT3 Transcriber")
    cli.add_argument("audio_file", type=str, help="Path to audio file")
    cli.add_argument(
        "--model",
        type=str,
        default="YPTF.MoE+Multi (noPS)",
        help="Model variant to use",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="./output",
        help="Output directory for MIDI files",
    )
    options = cli.parse_args()

    # Constructing the transcriber loads the checkpoint; then run it on the input.
    engine = YourMT3Transcriber(model_name=options.model)
    result_path = engine.transcribe_audio(
        audio_path=Path(options.audio_file),
        output_dir=Path(options.output),
    )
    print(f"MIDI saved to: {result_path}")