| """ |
| Ensemble Transcription Module |
| |
| Combines multiple transcription models via voting for improved accuracy. |
| |
| Ensemble Strategy: |
| - YourMT3+: Multi-instrument generalist, excellent polyphony & expressive timing (80-85% F1) |
| - ByteDance: Piano specialist, high precision on piano-only audio (90-95% F1) |
| - Combined: Depending on the strategy, voting filters out single-model false positives (e.g. 'intersection') or recovers notes one model missed (e.g. 'union'); 90-95% F1 expected |
| """ |
|
|
| from pathlib import Path |
| from typing import List, Dict, Optional, Literal |
| from dataclasses import dataclass |
| import numpy as np |
| from mido import MidiFile, MidiTrack, Message, MetaMessage |
| import pretty_midi |
|
|
|
|
@dataclass
class Note:
    """
    One transcribed note, as exchanged between the models and the voter.

    Times are in seconds and pitch/velocity are MIDI values; ``confidence``
    is the score the ensemble voting sums up.
    """

    pitch: int  # MIDI note number
    onset: float  # start time (seconds)
    offset: float  # end time (seconds)
    velocity: int = 64  # MIDI velocity
    confidence: float = 1.0  # voting score contributed by this note

    @property
    def duration(self) -> float:
        """Length of the note in seconds, i.e. ``offset - onset``."""
        start, end = self.onset, self.offset
        return end - start
|
|
|
|
class EnsembleTranscriber:
    """
    Ensemble transcription using multiple models with voting.

    Notes predicted by different models are matched into groups: same pitch,
    onsets within ``onset_tolerance_ms`` of the group's earliest onset, and at
    most one note per model. A group's vote count is the number of distinct
    models in it. Every kept group becomes one note whose onset, offset and
    velocity are the means over the group.

    Voting Strategies:
    1. 'weighted': Sum model-weighted confidence scores, keep notes at or above threshold
    2. 'intersection': Only keep notes agreed upon by all models (high precision)
    3. 'union': Keep all notes from all models, merging matched duplicates (high recall)
    4. 'majority': Keep notes predicted by >=50% of models
    """

    # Accepted values for ``voting_strategy``.
    VOTING_STRATEGIES = ('weighted', 'intersection', 'union', 'majority')

    # Per-model weights for the 'weighted' strategy: ByteDance (piano
    # specialist) outweighs YourMT3+ (generalist). A model missing from this
    # table gets an equal share, 1 / number of models.
    MODEL_WEIGHTS = {'YourMT3+': 0.4, 'ByteDance': 0.6}

    # Tempo written into ensemble MIDI files, in microseconds per beat
    # (500000 us/beat = 120 BPM).
    MIDI_TEMPO = 500000

    def __init__(
        self,
        yourmt3_transcriber,
        bytedance_transcriber,
        voting_strategy: Literal['weighted', 'intersection', 'union', 'majority'] = 'weighted',
        onset_tolerance_ms: int = 50,
        confidence_threshold: float = 0.6
    ):
        """
        Initialize ensemble transcriber.

        Args:
            yourmt3_transcriber: YourMT3Transcriber instance; must provide
                ``transcribe_audio(audio_path, output_dir)`` returning the path
                of the MIDI file it wrote
            bytedance_transcriber: ByteDanceTranscriber instance (same interface)
            voting_strategy: How to combine predictions; one of VOTING_STRATEGIES
            onset_tolerance_ms: Time window for matching notes (milliseconds)
            confidence_threshold: Minimum summed weighted confidence for the
                'weighted' strategy

        Raises:
            ValueError: If ``voting_strategy`` is unknown or
                ``onset_tolerance_ms`` is negative. Checked here so that a bad
                configuration fails before two expensive transcriptions run.
        """
        if voting_strategy not in self.VOTING_STRATEGIES:
            raise ValueError(f"Unknown voting strategy: {voting_strategy}")
        if onset_tolerance_ms < 0:
            raise ValueError(
                f"onset_tolerance_ms must be non-negative, got {onset_tolerance_ms}"
            )
        self.yourmt3 = yourmt3_transcriber
        self.bytedance = bytedance_transcriber
        self.voting_strategy = voting_strategy
        self.onset_tolerance = onset_tolerance_ms / 1000.0
        self.confidence_threshold = confidence_threshold

    def transcribe(
        self,
        audio_path: Path,
        output_dir: Optional[Path] = None
    ) -> Path:
        """
        Transcribe audio using ensemble of models.

        Runs YourMT3+ and then ByteDance on the same audio, votes on their
        notes and writes the result to ``<audio stem>_ensemble.mid``.

        Args:
            audio_path: Path to audio file (should be piano stem); a ``str``
                path is accepted too
            output_dir: Directory for the ensemble MIDI file, also passed to
                each model's ``transcribe_audio``. Defaults to the audio
                file's directory; created if missing.

        Returns:
            Path to ensemble MIDI file
        """
        audio_path = Path(audio_path)
        output_dir = audio_path.parent if output_dir is None else Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n ═══ Ensemble Transcription ═══")
        print(f" Strategy: {self.voting_strategy}")
        print(f" Onset tolerance: {self.onset_tolerance*1000:.0f}ms")

        print(f"\n [1/2] Transcribing with YourMT3+...")
        yourmt3_midi = self.yourmt3.transcribe_audio(audio_path, output_dir)
        yourmt3_notes = self._extract_notes_from_midi(yourmt3_midi)
        print(f" ✓ YourMT3+ found {len(yourmt3_notes)} notes")

        print(f"\n [2/2] Transcribing with ByteDance...")
        bytedance_midi = self.bytedance.transcribe_audio(audio_path, output_dir)
        bytedance_notes = self._extract_notes_from_midi(bytedance_midi)
        print(f" ✓ ByteDance found {len(bytedance_notes)} notes")

        print(f"\n Voting with '{self.voting_strategy}' strategy...")
        ensemble_notes = self._vote_notes(
            [yourmt3_notes, bytedance_notes],
            model_names=['YourMT3+', 'ByteDance']
        )
        print(f" ✓ Ensemble result: {len(ensemble_notes)} notes")

        ensemble_midi_path = output_dir / f"{audio_path.stem}_ensemble.mid"
        self._notes_to_midi(ensemble_notes, ensemble_midi_path)

        print(f" ✓ Ensemble MIDI saved: {ensemble_midi_path.name}")
        print(f" ═══════════════════════════════\n")

        return ensemble_midi_path

    def _extract_notes_from_midi(self, midi_path: Path) -> List[Note]:
        """
        Extract notes from MIDI file.

        Drum instruments are skipped. Every note gets confidence 1.0 because
        MIDI notes carry no confidence value of their own.

        Args:
            midi_path: Path to MIDI file

        Returns:
            List of Note objects sorted by onset
        """
        pm = pretty_midi.PrettyMIDI(str(midi_path))
        notes = [
            Note(
                pitch=note.pitch,
                onset=note.start,
                offset=note.end,
                velocity=note.velocity,
                confidence=1.0,
            )
            for instrument in pm.instruments
            if not instrument.is_drum
            for note in instrument.notes
        ]
        notes.sort(key=lambda n: n.onset)
        return notes

    def _vote_notes(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Vote on notes from multiple models using ``self.voting_strategy``.

        Args:
            note_lists: List of note lists from different models
            model_names: Names of models, parallel to ``note_lists`` (used to
                look up weights for the 'weighted' strategy)

        Returns:
            Merged list of notes after voting, sorted by onset

        Raises:
            ValueError: If ``self.voting_strategy`` is not a known strategy
                (it may have been reassigned after construction).
        """
        voters = {
            'weighted': self._vote_weighted,
            'intersection': self._vote_intersection,
            'union': self._vote_union,
            'majority': self._vote_majority,
        }
        try:
            voter = voters[self.voting_strategy]
        except KeyError:
            raise ValueError(f"Unknown voting strategy: {self.voting_strategy}") from None
        return voter(note_lists, model_names)

    def _group_matching_notes(self, note_lists: List[List[Note]]) -> List[List[tuple]]:
        """
        Group notes from all models that refer to the same musical event.

        Notes are visited in (pitch, onset) order. A note joins the current
        group when it has the group's pitch, its onset is at most
        ``self.onset_tolerance`` after the group's first onset, and its model
        has not contributed a note to the group yet; otherwise it starts a new
        group. Unlike a fixed time grid, this never separates two matching
        notes just because they straddle a grid boundary, and the one-note-
        per-model rule makes ``len(group)`` the number of agreeing models.

        Args:
            note_lists: List of note lists from different models

        Returns:
            List of groups; each group is a list of ``(model_idx, note)`` pairs.
            The input notes are not modified.
        """
        tagged = [
            (model_idx, note)
            for model_idx, notes in enumerate(note_lists)
            for note in notes
        ]
        tagged.sort(key=lambda item: (item[1].pitch, item[1].onset))

        groups = []
        current = []
        current_models = set()
        anchor_pitch = None
        anchor_onset = 0.0
        for model_idx, note in tagged:
            joins_current = (
                bool(current)
                and note.pitch == anchor_pitch
                and note.onset - anchor_onset <= self.onset_tolerance
                and model_idx not in current_models
            )
            if not joins_current:
                if current:
                    groups.append(current)
                current = []
                current_models = set()
                anchor_pitch = note.pitch
                anchor_onset = note.onset
            current.append((model_idx, note))
            current_models.add(model_idx)
        if current:
            groups.append(current)
        return groups

    @staticmethod
    def _merge_group(group: List[tuple], confidence: float) -> Note:
        """Collapse a group of ``(model_idx, note)`` pairs into one averaged Note."""
        notes = [note for _, note in group]
        return Note(
            pitch=notes[0].pitch,
            onset=float(np.mean([n.onset for n in notes])),
            offset=float(np.mean([n.offset for n in notes])),
            velocity=int(round(float(np.mean([n.velocity for n in notes])))),
            confidence=float(confidence),
        )

    def _vote_by_count(
        self,
        note_lists: List[List[Note]],
        min_votes: float
    ) -> List[Note]:
        """
        Keep groups supported by at least ``min_votes`` distinct models.

        The merged note's confidence is the fraction of models that agree.
        """
        n_models = len(note_lists)
        merged = [
            self._merge_group(group, len(group) / n_models)
            for group in self._group_matching_notes(note_lists)
            if len(group) >= min_votes
        ]
        merged.sort(key=lambda n: n.onset)
        return merged

    def _vote_weighted(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Weighted voting: Sum confidence scores, keep notes above threshold.

        Each note contributes ``note.confidence * weight(model)``, with weights
        from MODEL_WEIGHTS (ByteDance, the piano specialist, weighs more).
        Groups whose summed score is at least ``self.confidence_threshold``
        are kept; the score becomes the merged note's confidence. Input notes
        are left untouched.
        """
        if not note_lists:
            return []

        default_weight = 1.0 / len(note_lists)
        weights = [self.MODEL_WEIGHTS.get(name, default_weight) for name in model_names]

        merged = []
        for group in self._group_matching_notes(note_lists):
            score = sum(note.confidence * weights[model_idx] for model_idx, note in group)
            if score >= self.confidence_threshold:
                merged.append(self._merge_group(group, score))

        merged.sort(key=lambda n: n.onset)
        return merged

    def _vote_intersection(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Intersection voting: Only keep notes agreed upon by ALL models.
        High precision, potentially lower recall.
        """
        if not note_lists:
            return []
        return self._vote_by_count(note_lists, min_votes=len(note_lists))

    def _vote_union(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Union voting: Keep ALL notes from ALL models, merging matched duplicates.
        High recall, potentially more false positives.
        """
        return self._vote_by_count(note_lists, min_votes=1)

    def _vote_majority(
        self,
        note_lists: List[List[Note]],
        model_names: List[str]
    ) -> List[Note]:
        """
        Majority voting: Keep notes predicted by >=50% of models.
        Balanced precision and recall (with two models this equals 'union').
        """
        return self._vote_by_count(note_lists, min_votes=len(note_lists) / 2.0)

    def _find_matching_note(self, target: Note, notes: List[Note]) -> Optional[Note]:
        """
        Find the note in ``notes`` that best matches ``target``.

        A candidate must have the same pitch and an onset within
        ``self.onset_tolerance`` of the target's; among candidates, the one
        with the closest onset is returned. Returns None if nothing matches.
        """
        candidates = [
            note for note in notes
            if note.pitch == target.pitch
            and abs(note.onset - target.onset) <= self.onset_tolerance
        ]
        return min(candidates, key=lambda n: abs(n.onset - target.onset), default=None)

    def _notes_to_midi(self, notes: List[Note], output_path: Path):
        """
        Convert list of notes to a single-track MIDI file at a fixed tempo.

        Times are rounded to the nearest tick. Every note lasts at least one
        tick, and at equal ticks note_off events are written before note_on,
        so a note re-struck exactly when the previous one ends is not cut off
        regardless of the input order. Velocities are clamped to 1..127
        because a note_on with velocity 0 acts as a note_off.

        Args:
            notes: List of Note objects (any order)
            output_path: Path for output MIDI file
        """
        mid = MidiFile()
        track = MidiTrack()
        mid.tracks.append(track)
        track.append(MetaMessage('set_tempo', tempo=self.MIDI_TEMPO, time=0))

        # seconds -> ticks at a constant tempo: ticks/beat * beats/second
        ticks_per_second = mid.ticks_per_beat * 1_000_000 / self.MIDI_TEMPO

        # (absolute_tick, order, message_type, pitch, velocity); order 0 puts
        # note_off ahead of note_on when they share a tick.
        events = []
        for note in notes:
            pitch = int(note.pitch)
            velocity = max(1, min(127, int(note.velocity)))
            onset_tick = max(0, int(round(note.onset * ticks_per_second)))
            offset_tick = max(onset_tick + 1, int(round(note.offset * ticks_per_second)))
            events.append((onset_tick, 1, 'note_on', pitch, velocity))
            events.append((offset_tick, 0, 'note_off', pitch, 0))

        events.sort(key=lambda e: (e[0], e[1]))

        # MIDI messages store time as a delta from the previous message.
        previous_tick = 0
        for abs_tick, _, msg_type, pitch, velocity in events:
            track.append(Message(
                msg_type,
                note=pitch,
                velocity=velocity,
                time=abs_tick - previous_tick
            ))
            previous_tick = abs_tick

        track.append(MetaMessage('end_of_track', time=0))
        mid.save(output_path)
|
|