Spaces:
Sleeping
Sleeping
| """ | |
| Robust per-note feature extraction from MusicXML files using partitura. | |
| Uses per-part part.note_array() with all relevant kwargs to obtain a | |
| vectorised structured numpy array. Only clef and tie information (which are | |
| not exposed in the note array) are resolved separately via the part's | |
| timeline objects. | |
| Extracted features per note: | |
| Positional: | |
| measure_idx 0-based bar index | |
| staff 0-based global staff (part-staff made unique) | |
| voice 0-based voice | |
| grid_position quantised macro position in 16ths within bar (Fraction str) | |
| micro_offset quantised micro offset from 16th grid (Fraction str) | |
| Content: | |
| clef encoded clef token (treble=0, bass=1, alto=2, ...) | |
| key_fifths circle-of-fifths (-7..+7) | |
| key_mode 0=major, 1=minor | |
| pitch_step 0-6 (C D E F G A B) | |
| pitch_alter 0-4 (mapping -2..+2 -> 0..4) | |
| pitch_octave raw octave number | |
| duration_quarters duration as Fraction string (for vocab lookup) | |
| ts_beats numerator of time signature | |
| ts_beat_type denominator of time signature | |
| Output: list[dict], one dict per note. Tied note continuations are merged by | |
| partitura (notes_tied used by note_array), so each entry already carries the | |
| total sounding duration. Grace notes are included. | |
| Usage:: | |
| from src.utils.data.extract_features import extract_features_from_file | |
| notes = extract_features_from_file("path/to/score.mxl", verbose=True) | |
| """ | |
from __future__ import annotations

import warnings
from bisect import bisect_right
from dataclasses import dataclass, field
from fractions import Fraction
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

warnings.filterwarnings("ignore", category=UserWarning, module="partitura")

import partitura as pt
from partitura import score as pt_score
# Diatonic step letter -> integer token (C..B -> 0..6).
STEP_TO_INT: Dict[str, int] = {"C": 0, "D": 1, "E": 2, "F": 3, "G": 4, "A": 5, "B": 6}

# (clef sign, clef line) -> encoded clef token; lookups fall back to 0 (treble)
# for any clef not listed here (see CLEF_MAP.get(..., 0) at the use site).
CLEF_MAP: Dict[Tuple[str, int], int] = {
    ("G", 2): 0,  # Treble
    ("F", 4): 1,  # Bass
    ("C", 3): 2,  # Alto
    ("C", 4): 3,  # Tenor
    ("F", 3): 4,  # Baritone F
    ("G", 1): 5,  # French violin
}
# 16th-note grid resolution, expressed in quarter notes.
GRID_RESOLUTION = Fraction(1, 4)

# Predefined micro-shift vocabulary.
# Largest representable offset from a 16th grid point = half a 16th = 1/8 quarter.
_MAX_MICRO = Fraction(1, 8)


def _build_micro_shifts() -> List[Fraction]:
    """Enumerate the fixed micro-shift vocabulary (symmetric around zero).

    The 16th-note grid has 0.25q resolution; notes falling between grid
    points need a micro-shift.  The vocabulary covers:

    * tuplet offsets from triplets, quintuplets, sextuplets and septuplets
      (e.g. 1/12, 1/20, 1/10, 1/28, 1/14, 3/28),
    * 32nd-note triplet positions inside a single 16th, and
    * binary subdivisions (32nd, 64th, 128th notes) between 16th grid points.

    Without the binary entries, 32nd-note positions would be mis-quantised
    to the nearest tuplet offset (e.g. 1/8 -> 3/28), producing spurious
    tuplet ratios (7:4, 7:5) in MusicXML reconstruction.
    """
    vocab: set[Fraction] = {Fraction(0)}

    def _admit(off: Fraction) -> None:
        # Record a usable offset together with its mirror image.
        if 0 < abs(off) <= _MAX_MICRO:
            vocab.add(off)
            vocab.add(-off)

    # Tuplet offsets relative to the nearest 16th grid point.
    for div in (3, 5, 6, 7):
        for k in range(div):
            exact = Fraction(k, div)
            nearest_16th = Fraction(round(float(exact) * 4), 4)
            _admit(exact - nearest_16th)

    # 32nd-note triplet positions within one 16th.
    for k in range(3):
        _admit(Fraction(k, 12))

    # Binary subdivision offsets (32nd, 64th, 128th notes).
    for den in (8, 16, 32):
        for num in range(1, den):
            _admit(Fraction(num, den))

    return sorted(vocab)


MICRO_SHIFTS: List[Fraction] = _build_micro_shifts()
MICRO_SHIFTS_FLOAT: List[float] = [float(f) for f in MICRO_SHIFTS]

# Index of the zero shift -- the tie-breaking default during snapping.
_MICRO_ZERO_IDX = MICRO_SHIFTS.index(Fraction(0))


def quantize_micro(residual_q: Fraction) -> Fraction:
    """Snap a residual (Fraction in quarters) to the nearest predefined micro-shift."""
    return _quantize_micro_float(float(residual_q))


def _quantize_micro_float(res_f: float) -> Fraction:
    """Nearest-neighbour lookup in MICRO_SHIFTS using float arithmetic only.

    Ties are resolved in favour of the zero shift (the initial candidate);
    after that, only a strictly closer entry can win.  This avoids Fraction
    construction on the hot path.
    """
    winner = _MICRO_ZERO_IDX
    closest = abs(res_f)  # distance to the zero shift
    for idx, cand in enumerate(MICRO_SHIFTS_FLOAT):
        dist = abs(res_f - cand)
        if dist < closest:
            closest = dist
            winner = idx
    return MICRO_SHIFTS[winner]
| class _ClefTracker: | |
| """Sorted list of (onset_div, clef_token) per staff, with bisect lookup.""" | |
| entries: List[Tuple[int, int]] = field(default_factory=list) | |
| def add(self, onset_div: int, clef_token: int) -> None: | |
| self.entries.append((onset_div, clef_token)) | |
| def finalise(self) -> None: | |
| """Sort by onset time -- call once after all clefs are added.""" | |
| self.entries.sort(key=lambda x: x[0]) | |
| def at(self, onset_div: int) -> int: | |
| """Return the active clef token at *onset_div* (last clef <= onset).""" | |
| active = 0 # default treble | |
| for t, tok in self.entries: | |
| if t <= onset_div: | |
| active = tok | |
| else: | |
| break | |
| return active | |
def _extract_part_features(
    part: pt_score.Part,
    part_idx: int,
    global_staff_offset: int,
    verbose: bool = False,
) -> Tuple[List[dict], int]:
    """Extract per-note features from a single Part using part.note_array().

    The heavy lifting is done by partitura's vectorised note_array which
    provides pitch spelling, key/time signature, metrical position, staff and
    divs_per_quarter in one call. Only clef information is resolved
    separately because it is not part of the note-array API.

    Tied note continuations are already merged by note_array (which uses
    part.notes_tied), so each row represents the full sounding note with
    combined duration. Grace notes are not filtered out by partitura's
    note_array; the include_grace_notes flag only controls whether the
    is_grace / grace_type columns are present.

    Parameters
    ----------
    part : partitura Part
    part_idx : int
    global_staff_offset : int
        Staves already counted from previous parts.
    verbose : bool

    Returns
    -------
    notes : list[dict]
    num_staves : int
    """
    # Vectorised note array
    na = part.note_array(
        include_pitch_spelling=True,
        include_key_signature=True,
        include_time_signature=True,
        include_metrical_position=True,
        include_staff=True,
        include_divs_per_quarter=True,
    )
    n_notes = len(na)
    if n_notes == 0:
        # Empty part: contributes no notes and no staves.
        return [], 0
    # Local -> global staff mapping
    local_staves = sorted(np.unique(na["staff"]).tolist())
    local_to_global: Dict[int, int] = {
        s: global_staff_offset + i for i, s in enumerate(local_staves)
    }
    num_staves = len(local_staves)
    if verbose:
        print(
            f" Part {part_idx}: {n_notes} notes, "
            f"local staves {local_staves} -> global {list(local_to_global.values())}"
        )
    # Vectorised global staff
    global_staff_arr = np.array(
        [local_to_global[int(s)] for s in na["staff"]], dtype=np.int32
    )
    # Vectorised position-in-bar -> grid + micro
    divs_pq = na["divs_pq"].astype(np.float64)
    rel_onset_div = na["rel_onset_div"].astype(np.float64)
    # Position within bar in quarter notes (exact integer-div-based)
    pos_in_quarters = rel_onset_div / divs_pq
    grid_res_f = float(GRID_RESOLUTION)  # 0.25
    # Nearest 16th grid slot index; residual below is the micro deviation.
    grid_indices = np.rint(pos_in_quarters / grid_res_f).astype(np.int64)
    grid_points = grid_indices * grid_res_f
    residuals = pos_in_quarters - grid_points
    # Grid position as Fraction strings (vectorised via lookup table)
    # Build a mapping from grid_idx -> Fraction string once
    max_grid_idx = int(grid_indices.max()) if n_notes else 0
    _grid_str_lut = {
        idx: str(Fraction(idx, 4)) for idx in range(max_grid_idx + 1)
    }
    # The .get fallback covers negative grid indices (e.g. rounding below 0),
    # which the LUT (built from 0..max) does not contain.
    grid_position_strs = [_grid_str_lut.get(int(gi), str(Fraction(int(gi), 4)))
                          for gi in grid_indices]
    # Micro offsets (fast float-only quantisation to predefined shifts)
    micro_strs: List[str] = []
    micro_floats: List[float] = []
    for res in residuals:
        micro = _quantize_micro_float(float(res))
        micro_strs.append(str(micro))
        micro_floats.append(float(micro))
    # Vectorised pitch spelling
    step_arr = na["step"]
    # step may arrive as bytes or str depending on numpy/partitura version.
    pitch_step_arr = np.array(
        [STEP_TO_INT.get(s.decode("utf-8") if isinstance(s, bytes) else s, 0)
         for s in step_arr],
        dtype=np.int32,
    )
    alter_arr = na["alter"].astype(np.float64)
    alter_arr = np.where(np.isnan(alter_arr), 0.0, alter_arr)
    pitch_alter_arr = alter_arr.astype(np.int32) + 2  # shift -2..+2 -> 0..4
    pitch_octave_arr = na["octave"].astype(np.int32)
    # Vectorised key signature
    key_fifths_arr = na["ks_fifths"].astype(np.int32)
    key_mode_arr = na["ks_mode"].astype(np.int32)
    key_mode_arr = np.where(key_mode_arr < 0, 0, key_mode_arr)  # None -> major
    # Vectorised time signature
    ts_beats_arr = na["ts_beats"].astype(np.int32)
    ts_beat_type_arr = na["ts_beat_type"].astype(np.int32)
    # Vectorised voice (0-based)
    voice_arr = na["voice"].astype(np.int32) - 1
    # Vectorised duration as Fraction strings
    # Use integer divisions for exact fractions: dur = duration_div / divs_pq
    dur_div = na["duration_div"]
    dur_strs = [
        str(Fraction(int(dd), int(dpq)))
        for dd, dpq in zip(dur_div, na["divs_pq"])
    ]
    # Vectorised measure index (0-based)
    # Build sorted measure-onset array for O(log N) binary search lookup
    # (part.measure_number_map is O(N) per call -- too slow for vectorised use)
    measures = list(part.iter_all(pt_score.Measure))
    if measures:
        _m_starts = np.array([m.start.t for m in measures], dtype=np.int64)
        _m_nums = np.array([m.number for m in measures], dtype=np.int64)
    else:
        # No measures: treat the whole part as a single measure 1 at t=0.
        _m_starts = np.array([0], dtype=np.int64)
        _m_nums = np.array([1], dtype=np.int64)
    onset_divs = na["onset_div"].astype(np.int64)
    # searchsorted(side='right') - 1 gives the last measure whose start <= onset
    # NOTE(review): assumes partitura measure numbers start at 1 so that
    # "- 1" yields a 0-based index -- confirm for pickup-measure scores.
    measure_idx_arr = _m_nums[np.searchsorted(_m_starts, onset_divs, side="right") - 1] - 1
    # Clef (not in note_array, resolved from timeline)
    clef_trackers: Dict[int, _ClefTracker] = {}
    for obj in part.iter_all():
        # Matched by class name rather than isinstance -- presumably to stay
        # independent of the exact Clef import path; verify against partitura.
        if obj.__class__.__name__ == "Clef":
            staff = getattr(obj, "staff", 1) or 1
            if staff not in clef_trackers:
                clef_trackers[staff] = _ClefTracker()
            clef_trackers[staff].add(
                obj.start.t, CLEF_MAP.get((obj.sign, obj.line), 0)
            )
    for ct in clef_trackers.values():
        ct.finalise()
    clef_arr = np.zeros(n_notes, dtype=np.int32)
    for i in range(n_notes):
        local_staff = int(na["staff"][i])
        ct = clef_trackers.get(local_staff)
        clef_arr[i] = ct.at(int(onset_divs[i])) if ct else 0
    na_ids = na["id"]
    # Build output list
    notes_out: List[dict] = []
    for i in range(n_notes):
        nid = na_ids[i]
        if isinstance(nid, bytes):
            nid = nid.decode("utf-8")
        else:
            nid = str(nid)  # np.str_ -> native str
        notes_out.append({
            # Identifiers (debug, not model features)
            "note_id": nid,
            "part_idx": part_idx,
            "midi_pitch": int(na["pitch"][i]),
            # -- Positional --
            "measure_idx": int(measure_idx_arr[i]),
            "staff": int(global_staff_arr[i]),
            "voice": int(voice_arr[i]),
            "grid_position": grid_position_strs[i],
            "grid_position_idx": int(grid_indices[i]),
            "micro_offset": micro_strs[i],
            # -- Content --
            "clef": int(clef_arr[i]),
            "key_fifths": int(key_fifths_arr[i]),
            "key_mode": int(key_mode_arr[i]),
            "pitch_step": int(pitch_step_arr[i]),
            "pitch_alter": int(pitch_alter_arr[i]),
            "pitch_octave": int(pitch_octave_arr[i]),
            "duration_quarters": dur_strs[i],
            "ts_beats": int(ts_beats_arr[i]),
            "ts_beat_type": int(ts_beat_type_arr[i]),
            # -- Debug --
            "onset_div": int(onset_divs[i]),
            "position_in_quarters": float(pos_in_quarters[i]),
            "grid_point_quarters": float(grid_points[i]),
            "micro_offset_quarters": micro_floats[i],
        })
    return notes_out, num_staves
def extract_features(
    score: pt_score.Score,
    *,
    verbose: bool = False,
) -> List[dict]:
    """Extract per-note features from a loaded partitura Score.

    Grace notes are included (partitura does not filter them; they typically
    have zero or very short duration).  Tied note continuations are already
    merged (each note carries total sounding duration).

    Parameters
    ----------
    score : partitura Score
    verbose : bool

    Returns
    -------
    List[dict] - one dict per note, sorted by (onset_div, part, staff, voice).
    """
    collected: List[dict] = []
    staff_base = 0
    for idx, part in enumerate(score.parts):
        extracted, staves_used = _extract_part_features(
            part, idx, staff_base, verbose=verbose
        )
        collected.extend(extracted)
        staff_base += staves_used
    # Stable global ordering: onset first, then part, staff and voice.
    collected.sort(
        key=lambda rec: (rec["onset_div"], rec["part_idx"], rec["staff"], rec["voice"])
    )
    if verbose:
        print(f"Total notes extracted: {len(collected)}")
        print(f"Total global staves: {staff_base}")
    return collected
def extract_features_from_file(
    path: str | Path,
    *,
    verbose: bool = False,
) -> List[dict]:
    """Convenience wrapper: load a MusicXML file and extract features.

    Parameters
    ----------
    path : str or Path
        Path to .xml, .mxl, or .musicxml file.
    verbose : bool

    Returns
    -------
    List[dict]
    """
    loaded = pt.load_score(str(path))
    return extract_features(loaded, verbose=verbose)
| _ALTER_SYM = {0: "𝄫", 1: "♭", 2: "♮", 3: "♯", 4: "𝄪"} | |
| _STEP_NAMES = ["C", "D", "E", "F", "G", "A", "B"] | |
| def pretty_print_notes(notes: List[dict], max_notes: int = 40) -> None: | |
| """Print a human-readable table for visual verification.""" | |
| header = ( | |
| f"{'#':>4} {'id':<8} {'bar':>3} {'st':>2} {'v':>1} " | |
| f"{'grid':>5} {'µ':>6} {'pitch':<6} {'oct':>3} {'dur':<8} " | |
| f"{'clef':>4} {'ks':>3} {'ts':>5} {'midi':>4}" | |
| ) | |
| print(header) | |
| print("-" * len(header)) | |
| for i, n in enumerate(notes[:max_notes]): | |
| step_name = _STEP_NAMES[n["pitch_step"]] | |
| alter_sym = _ALTER_SYM.get(n["pitch_alter"], "?") | |
| if n["pitch_alter"] == 2: | |
| alter_sym = "" # natural - keep it clean | |
| pitch_str = f"{step_name}{alter_sym}" | |
| ks_str = f"{n['key_fifths']:+d}{'m' if n['key_mode'] else 'M'}" | |
| ts_str = f"{n['ts_beats']}/{n['ts_beat_type']}" | |
| print( | |
| f"{i:4d} {str(n['note_id']):<8} {n['measure_idx']:3d} " | |
| f"{n['staff']:2d} {n['voice']:1d} " | |
| f"{n['grid_position']:>5} {n['micro_offset']:>6} " | |
| f"{pitch_str:<6} {n['pitch_octave']:3d} {n['duration_quarters']:<8} " | |
| f"{n['clef']:4d} {ks_str:>3} {ts_str:>5} {n['midi_pitch']:4d}" | |
| ) | |
| if len(notes) > max_notes: | |
| print(f"\t... ({len(notes) - max_notes} more notes)") | |
@dataclass
class VocabStats:
    """Accumulator for vocabulary statistics across many files.

    NOTE(review): the @dataclass decorator is required -- without it every
    ``field(default_factory=set)`` right-hand side is a bare class-level
    ``dataclasses.Field`` object, so the first ``update()`` call crashes on
    ``.add`` (and any mutable default would be shared between instances).
    """

    # Vocabulary sets of observed tokens (Fraction strings or ints).
    durations: set = field(default_factory=set)
    grid_positions: set = field(default_factory=set)
    micro_offsets: set = field(default_factory=set)
    pitch_steps: set = field(default_factory=set)
    pitch_alters: set = field(default_factory=set)
    pitch_octaves: set = field(default_factory=set)
    clefs: set = field(default_factory=set)
    key_fifths: set = field(default_factory=set)
    key_modes: set = field(default_factory=set)
    ts_beats: set = field(default_factory=set)
    ts_beat_types: set = field(default_factory=set)
    voices: set = field(default_factory=set)
    staffs: set = field(default_factory=set)
    # Scalar counters.
    max_measure_idx: int = 0
    n_notes: int = 0
    n_files: int = 0

    def update(self, notes: List[dict]) -> None:
        """Fold one file's extracted note dicts into the accumulated stats."""
        for n in notes:
            self.durations.add(n["duration_quarters"])
            self.grid_positions.add(n["grid_position"])
            self.micro_offsets.add(n["micro_offset"])
            self.pitch_steps.add(n["pitch_step"])
            self.pitch_alters.add(n["pitch_alter"])
            self.pitch_octaves.add(n["pitch_octave"])
            self.clefs.add(n["clef"])
            self.key_fifths.add(n["key_fifths"])
            self.key_modes.add(n["key_mode"])
            self.ts_beats.add(n["ts_beats"])
            self.ts_beat_types.add(n["ts_beat_type"])
            self.voices.add(n["voice"])
            self.staffs.add(n["staff"])
            if n["measure_idx"] > self.max_measure_idx:
                self.max_measure_idx = n["measure_idx"]
        self.n_notes += len(notes)
        self.n_files += 1

    def summary(self) -> str:
        """Return a multi-line human-readable summary of the vocabularies."""
        lines = [
            f"VocabStats ({self.n_files} files, {self.n_notes:,} notes)",
            f" durations: {len(self.durations):>6} unique fraction strings",
            f" grid_positions: {len(self.grid_positions):>6} unique (max grid_idx ~ {max((int(Fraction(g) / GRID_RESOLUTION) for g in self.grid_positions), default=0)})",
            f" micro_offsets: {len(self.micro_offsets):>6} unique",
            f" pitch_steps: {sorted(self.pitch_steps)} (expect 0-6)",
            f" pitch_alters: {sorted(self.pitch_alters)} (expect 0-4)",
            f" pitch_octaves: {sorted(self.pitch_octaves)}",
            f" clefs: {sorted(self.clefs)}",
            f" key_fifths: {sorted(self.key_fifths)} (range -7..+7)",
            f" key_modes: {sorted(self.key_modes)} (0=major, 1=minor)",
            f" ts_beats: {sorted(self.ts_beats)}",
            f" ts_beat_types: {sorted(self.ts_beat_types)}",
            f" voices: {sorted(self.voices)} (0-based)",
            f" staffs: {sorted(self.staffs)} (0-based global)",
            f" max_measure_idx: {self.max_measure_idx}",
        ]
        return "\n".join(lines)
def main() -> None:
    """CLI entry point - extract & pretty-print features from a file."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Extract & display per-note features from a MusicXML file.",
    )
    cli.add_argument("input", type=Path, help="Path to .xml / .mxl file")
    cli.add_argument("-n", "--max-notes", type=int, default=60,
                     help="Max notes to display (default 60)")
    cli.add_argument("-v", "--verbose", action="store_true")
    cli.add_argument("--stats", action="store_true",
                     help="Print vocab stats summary")
    opts = cli.parse_args()

    extracted = extract_features_from_file(opts.input, verbose=opts.verbose)
    pretty_print_notes(extracted, max_notes=opts.max_notes)

    if opts.stats:
        stats = VocabStats()
        stats.update(extracted)
        print("\n" + stats.summary())


if __name__ == "__main__":
    main()