rescored / backend /ensemble_transcriber.py
calebhan's picture
vocal separation and bytedance integration
e7bf1e6
"""
Ensemble Transcription Module
Combines multiple transcription models via voting for improved accuracy.
Ensemble Strategy:
- YourMT3+: Multi-instrument generalist, excellent polyphony & expressive timing (80-85% F1)
- ByteDance: Piano specialist, high precision on piano-only audio (90-95% F1)
- Combined: Voting reduces false positives and false negatives (90-95% F1 expected)
"""
from pathlib import Path
from typing import List, Dict, Optional, Literal
from dataclasses import dataclass
import numpy as np
from mido import MidiFile, MidiTrack, Message, MetaMessage
import pretty_midi
@dataclass
class Note:
    """
    A single transcribed note: pitch plus timing information.

    Attributes:
        pitch: MIDI pitch (0-127).
        onset: Start time in seconds.
        offset: End time in seconds.
        velocity: Note velocity (0-127); defaults to 64.
        confidence: Confidence score (0-1); defaults to 1.0.
    """

    pitch: int
    onset: float
    offset: float
    velocity: int = 64
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Length of the note in seconds (offset minus onset)."""
        start, end = self.onset, self.offset
        return end - start
class EnsembleTranscriber:
    """
    Ensemble transcription using multiple models with voting.

    Voting Strategies:
        1. 'weighted': Sum weighted confidence scores, keep notes at or above threshold
        2. 'intersection': Only keep notes agreed upon by all models (high precision)
        3. 'union': Keep all notes from all models (high recall)
        4. 'majority': Keep notes predicted by >=50% of models

    Two notes are treated as the same musical event when they share a pitch
    and their onsets lie within the onset tolerance of each other.
    """

    # Default per-model weights for 'weighted' voting
    # (ByteDance is weighted higher as the piano specialist).
    DEFAULT_MODEL_WEIGHTS = {'YourMT3+': 0.4, 'ByteDance': 0.6}

    # Timing grid used when writing the ensemble MIDI file.
    TICKS_PER_BEAT = 480
    TEMPO_US_PER_BEAT = 500000  # microseconds per beat == 120 BPM

    def __init__(
        self,
        yourmt3_transcriber,
        bytedance_transcriber,
        voting_strategy: Literal['weighted', 'intersection', 'union', 'majority'] = 'weighted',
        onset_tolerance_ms: int = 50,
        confidence_threshold: float = 0.6,
        model_weights: Optional[Dict[str, float]] = None,
    ):
        """
        Initialize ensemble transcriber.

        Args:
            yourmt3_transcriber: YourMT3Transcriber instance
            bytedance_transcriber: ByteDanceTranscriber instance
            voting_strategy: How to combine predictions
            onset_tolerance_ms: Time window for matching notes (milliseconds)
            confidence_threshold: Minimum confidence for 'weighted' strategy
            model_weights: Optional mapping of model name to weight for the
                'weighted' strategy; defaults to DEFAULT_MODEL_WEIGHTS

        Raises:
            ValueError: If onset_tolerance_ms is negative.
        """
        if onset_tolerance_ms < 0:
            raise ValueError(
                f"onset_tolerance_ms must be non-negative, got {onset_tolerance_ms}"
            )

        self.yourmt3 = yourmt3_transcriber
        self.bytedance = bytedance_transcriber
        self.voting_strategy = voting_strategy
        self.onset_tolerance = onset_tolerance_ms / 1000.0  # Convert to seconds
        self.confidence_threshold = confidence_threshold
        # Copy so later changes to the caller's dict (or the class default)
        # cannot silently alter this instance's weights.
        self.model_weights = dict(
            self.DEFAULT_MODEL_WEIGHTS if model_weights is None else model_weights
        )

    def transcribe(
        self,
        audio_path: Path,
        output_dir: Optional[Path] = None
    ) -> Path:
        """
        Transcribe audio using ensemble of models.

        Runs both models on the same audio, votes on the extracted notes with
        the configured strategy and writes the result as a MIDI file.

        Args:
            audio_path: Path to audio file (should be piano stem)
            output_dir: Directory for output MIDI file (defaults to the
                directory containing audio_path); created if missing

        Returns:
            Path to ensemble MIDI file
        """
        # Accept plain strings as well as Path objects.
        audio_path = Path(audio_path)
        output_dir = audio_path.parent if output_dir is None else Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        print("\n ═══ Ensemble Transcription ═══")
        print(f" Strategy: {self.voting_strategy}")
        print(f" Onset tolerance: {self.onset_tolerance*1000:.0f}ms")

        # Transcribe with YourMT3+
        print("\n [1/2] Transcribing with YourMT3+...")
        yourmt3_midi = self.yourmt3.transcribe_audio(audio_path, output_dir)
        yourmt3_notes = self._extract_notes_from_midi(yourmt3_midi)
        print(f" ✓ YourMT3+ found {len(yourmt3_notes)} notes")

        # Transcribe with ByteDance
        print("\n [2/2] Transcribing with ByteDance...")
        bytedance_midi = self.bytedance.transcribe_audio(audio_path, output_dir)
        bytedance_notes = self._extract_notes_from_midi(bytedance_midi)
        print(f" ✓ ByteDance found {len(bytedance_notes)} notes")

        # Vote and merge
        print(f"\n Voting with '{self.voting_strategy}' strategy...")
        ensemble_notes = self._vote_notes(
            [yourmt3_notes, bytedance_notes],
            model_names=['YourMT3+', 'ByteDance']
        )
        print(f" ✓ Ensemble result: {len(ensemble_notes)} notes")

        # Convert to MIDI
        ensemble_midi_path = output_dir / f"{audio_path.stem}_ensemble.mid"
        self._notes_to_midi(ensemble_notes, ensemble_midi_path)
        print(f" ✓ Ensemble MIDI saved: {ensemble_midi_path.name}")
        print(" ═══════════════════════════════\n")
        return ensemble_midi_path

    def _extract_notes_from_midi(self, midi_path: Path) -> List[Note]:
        """
        Extract notes from MIDI file.

        Drum instruments are skipped. Every note gets confidence 1.0 because
        the MIDI file carries no per-note confidence.

        Args:
            midi_path: Path to MIDI file

        Returns:
            List of Note objects sorted by onset time
        """
        pm = pretty_midi.PrettyMIDI(str(midi_path))
        notes = [
            Note(
                pitch=midi_note.pitch,
                onset=midi_note.start,
                offset=midi_note.end,
                velocity=midi_note.velocity,
                confidence=1.0  # Default confidence (TODO: extract from model if available)
            )
            for instrument in pm.instruments
            if not instrument.is_drum
            for midi_note in instrument.notes
        ]
        notes.sort(key=lambda n: n.onset)
        return notes

    def _vote_notes(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Vote on notes from multiple models.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models, parallel to note_lists

        Returns:
            Merged list of notes after voting

        Raises:
            ValueError: If self.voting_strategy is not a known strategy.
        """
        strategies = {
            'weighted': self._vote_weighted,
            'intersection': self._vote_intersection,
            'union': self._vote_union,
            'majority': self._vote_majority,
        }
        vote = strategies.get(self.voting_strategy)
        if vote is None:
            raise ValueError(f"Unknown voting strategy: {self.voting_strategy}")
        return vote(note_lists, model_names)

    def _cluster_notes(self, note_lists: List[List[Note]]) -> List[List[List[Note]]]:
        """
        Group notes from all models into clusters describing the same event.

        Notes are clustered per pitch: a note joins the current cluster when
        its onset is within self.onset_tolerance of the cluster's first onset,
        otherwise it starts a new cluster. Unlike fixed-width bucketing this
        never separates two notes merely because they straddle a bucket edge,
        and it works with a tolerance of zero.

        Args:
            note_lists: List of note lists from different models

        Returns:
            List of clusters; cluster[i] is the (possibly empty) list of notes
            that model i contributed to that cluster
        """
        tagged = [
            (note.pitch, note.onset, model_idx, note)
            for model_idx, notes in enumerate(note_lists)
            for note in notes
        ]
        # Sort on the first three fields only so Note objects are never compared.
        tagged.sort(key=lambda item: item[:3])

        clusters: List[List[List[Note]]] = []
        anchor_pitch = None
        anchor_onset = 0.0
        for pitch, onset, model_idx, note in tagged:
            starts_new_cluster = (
                not clusters
                or pitch != anchor_pitch
                or onset - anchor_onset > self.onset_tolerance
            )
            if starts_new_cluster:
                clusters.append([[] for _ in note_lists])
                anchor_pitch, anchor_onset = pitch, onset
            clusters[-1][model_idx].append(note)
        return clusters

    @staticmethod
    def _count_votes(cluster: List[List[Note]]) -> int:
        """Return how many distinct models contributed a note to the cluster."""
        return sum(1 for model_notes in cluster if model_notes)

    @staticmethod
    def _merge_cluster(cluster: List[List[Note]], confidence: float) -> Note:
        """Collapse a cluster into one Note with averaged timing and velocity."""
        members = [note for model_notes in cluster for note in model_notes]
        return Note(
            pitch=members[0].pitch,
            # float()/int() keep plain Python numbers rather than NumPy scalars.
            onset=float(np.mean([n.onset for n in members])),
            offset=float(np.mean([n.offset for n in members])),
            velocity=int(np.mean([n.velocity for n in members])),
            confidence=confidence
        )

    def _vote_weighted(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Weighted voting: Sum confidence scores, keep notes above threshold.

        Each model contributes (model weight * its best note confidence) to a
        cluster at most once, so duplicate notes from one model cannot inflate
        the score. Input notes are not modified.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models, used to look up weights; models
                without a configured weight get 1 / len(note_lists)

        Returns:
            Merged notes whose summed weighted confidence is at least
            self.confidence_threshold, sorted by onset
        """
        fallback_weight = 1.0 / len(note_lists) if note_lists else 0.0
        weights = []
        for model_idx in range(len(note_lists)):
            name = model_names[model_idx] if model_idx < len(model_names) else None
            weights.append(self.model_weights.get(name, fallback_weight))

        merged_notes = []
        for cluster in self._cluster_notes(note_lists):
            total_confidence = sum(
                weight * max(n.confidence for n in model_notes)
                for weight, model_notes in zip(weights, cluster)
                if model_notes
            )
            if total_confidence >= self.confidence_threshold:
                merged_notes.append(self._merge_cluster(cluster, total_confidence))

        merged_notes.sort(key=lambda n: n.onset)
        return merged_notes

    def _vote_intersection(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Intersection voting: Only keep notes agreed upon by ALL models.
        High precision, potentially lower recall.

        The first model's notes are the candidates; a candidate is kept
        (unchanged) when every other model has a matching note.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models (unused by this strategy)

        Returns:
            The first model's notes that were matched in all other models
        """
        if not note_lists:
            return []

        base_notes, *other_lists = note_lists
        return [
            base_note
            for base_note in base_notes
            if all(
                self._find_matching_note(base_note, other_notes) is not None
                for other_notes in other_lists
            )
        ]

    def _vote_union(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Union voting: Keep ALL notes from ALL models.
        High recall, potentially more false positives.

        Notes describing the same event are merged into one averaged note
        whose confidence is the fraction of models that predicted it.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models (unused by this strategy)

        Returns:
            Deduplicated notes sorted by onset
        """
        model_count = len(note_lists)
        merged_notes = [
            self._merge_cluster(cluster, self._count_votes(cluster) / model_count)
            for cluster in self._cluster_notes(note_lists)
        ]
        merged_notes.sort(key=lambda n: n.onset)
        return merged_notes

    def _vote_majority(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Majority voting: Keep notes predicted by >=50% of models.
        Balanced precision and recall.

        Votes are counted per model, not per note, so several notes from a
        single model in one cluster still count as one vote.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models (unused by this strategy)

        Returns:
            Merged notes with at least half of the models agreeing, sorted by
            onset; confidence is the fraction of agreeing models
        """
        model_count = len(note_lists)
        threshold = model_count / 2.0

        merged_notes = []
        for cluster in self._cluster_notes(note_lists):
            votes = self._count_votes(cluster)
            if votes >= threshold:
                merged_notes.append(self._merge_cluster(cluster, votes / model_count))

        merged_notes.sort(key=lambda n: n.onset)
        return merged_notes

    def _find_matching_note(self, target: Note, notes: List[Note]) -> Optional[Note]:
        """
        Find a note that matches the target note within tolerance.

        Args:
            target: Note to look for
            notes: Candidate notes to search

        Returns:
            The first candidate with the same pitch whose onset is within
            self.onset_tolerance of the target's onset, or None
        """
        return next(
            (
                candidate
                for candidate in notes
                if candidate.pitch == target.pitch
                and abs(candidate.onset - target.onset) <= self.onset_tolerance
            ),
            None
        )

    def _notes_to_midi(self, notes: List[Note], output_path: Path):
        """
        Convert list of notes to MIDI file.

        Writes a single-track file at a fixed 120 BPM with 480 ticks per beat.
        Simplified: assumes a single instrument and no overlapping notes with
        the same pitch.

        Args:
            notes: List of Note objects
            output_path: Path for output MIDI file
        """
        # Set the resolution explicitly so the tick conversion below does not
        # depend on the library default.
        mid = MidiFile(ticks_per_beat=self.TICKS_PER_BEAT)
        track = MidiTrack()
        mid.tracks.append(track)

        # Add tempo (120 BPM default)
        track.append(MetaMessage('set_tempo', tempo=self.TEMPO_US_PER_BEAT, time=0))

        # 480 ticks/beat * 2 beats/second at 120 BPM
        ticks_per_second = self.TICKS_PER_BEAT * 1_000_000 / self.TEMPO_US_PER_BEAT

        # Collect events with absolute tick times; converted to deltas below.
        # The second tuple field orders note_off (0) before note_on (1) on the
        # same tick so a re-struck pitch is released before it sounds again.
        events = []
        for note in notes:
            onset_ticks = max(0, int(round(float(note.onset) * ticks_per_second)))
            offset_ticks = int(round(float(note.offset) * ticks_per_second))
            # Guarantee at least one tick so the note_off never sorts ahead of
            # its own note_on.
            offset_ticks = max(offset_ticks, onset_ticks + 1)
            # A note_on with velocity 0 is interpreted as note_off, so keep
            # the velocity within 1..127.
            velocity = min(127, max(1, int(note.velocity)))
            events.append((onset_ticks, 1, 'note_on', note.pitch, velocity))
            events.append((offset_ticks, 0, 'note_off', note.pitch, 0))

        events.sort(key=lambda e: (e[0], e[1]))

        # Convert to delta time and add to track
        previous_time = 0
        for abs_time, _, msg_type, pitch, velocity in events:
            track.append(Message(
                msg_type,
                note=pitch,
                velocity=velocity,
                time=abs_time - previous_time
            ))
            previous_time = abs_time

        # Add end of track
        track.append(MetaMessage('end_of_track', time=0))

        # Save
        mid.save(output_path)