"""Multi-feature score graph dataset for LaM-SLidE autoencoder training.

Provides ScoreGraphMultiFeatureDataset, collate functions, and feature specs.
The dataset loads .pt graph files produced by process_dataset.py and extracts
configurable subsets of 14 per-note features for ablation studies.

Feature tensor: 14 columns in graph['note'].x (from process_dataset.py).

Col  Feature        Raw Range         Transform  Vocab  Notes
  0  grid_position  [0, 32]           none          33  tokenised (16th grid)
  1  micro_offset   [0, 12]           none          13  tokenised micro-shift
  2  measure_idx    [0, 255]          none         256  0-based bar index
  3  voice          [0, 24]           none          25  already 0-based (20 unique)
  4  pitch_step     [0, 6]            none           7  C D E F G A B
  5  pitch_alter    [0, 4]            none           5  shifted +2 in extraction
  6  pitch_octave   [-3, 11]          +3            15  shift to 0-base
  7  duration       [0, 726]          none         727  tokenised duration vocab
  8  clef           [0, 5]            none           6  treble=0 bass=1 ...
  9  ts_beats       [1, 80]           remap         57  remap 57 observed values
 10  ts_beat_type   {1,2,4,8,16,32}   remap          6  remap 6 sparse -> [0,5]
 11  key_fifths     [-7, +7]          +7            15  circle of fifths
 12  key_mode       [0, 1]            none           2  0=major, 1=minor (optional)
 13  staff          [0, 36]           none          37  already 0-based
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset


@dataclass
class NoteFeatureSpec:
    """Specification for a single note feature."""

    name: str
    col_index: int  # Column index in graph['note'].x tensor
    vocab_size: int
    shift: int = 0  # Value added to move negative raw values into [0, vocab_size)
    # Optional remapping for sparse value sets
    # (e.g., ts_beat_type: 1,2,4,8,16,32 -> 0,1,2,3,4,5).
    # When a remap is given, `shift` is ignored by the dataset.
    remap: Optional[Dict[int, int]] = None


# Remapping for ts_beat_type: actual denominators -> consecutive tokens
# After dropping files with beat types {3, 5, 6, 9, 64}: remaining {1, 2, 4, 8, 16, 32}
TS_BEAT_TYPE_REMAP = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4, 32: 5}

# Remapping for ts_beats: observed numerator values -> consecutive tokens.
# 57 unique values found in filtered data (max_notes=1530, max_bars=256).
# Unknown values are mapped to the closest entry in __getitem__.
TS_BEATS_REMAP = {
    1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7,
    9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15,
    17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23,
    25: 24, 26: 25, 27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 33: 31,
    34: 32, 35: 33, 36: 34, 37: 35, 39: 36, 41: 37, 42: 38, 45: 39,
    47: 40, 48: 41, 54: 42, 55: 43, 56: 44, 57: 45, 58: 46, 59: 47,
    60: 48, 61: 49, 62: 50, 63: 51, 64: 52, 67: 53, 69: 54, 73: 55,
    80: 56,
}

# Column indices in graph['note'].x tensor (shape: [num_notes, 14])
# Matches FEATURE_COLUMNS in process_dataset.py (verified 2025-02-07)
COL_INDICES = {
    'grid_position': 0,   # tokenised grid position
    'micro_offset': 1,    # tokenised micro offset
    'measure_idx': 2,     # 0-based bar index
    'voice': 3,           # 0-based voice
    'pitch_step': 4,      # 0-6 (C...B)
    'pitch_alter': 5,     # 0-4 (shifted +2)
    'pitch_octave': 6,    # raw octave (-3..11)
    'duration': 7,        # tokenised duration
    'clef': 8,            # 0-5
    'ts_beats': 9,        # raw numerator
    'ts_beat_type': 10,   # raw denominator {1,2,4,8,16,32}
    'key_fifths': 11,     # circle of fifths (-7..+7)
    'key_mode': 12,       # 0=major, 1=minor
    'staff': 13,          # 0-based global staff
}

# Pre-defined feature specifications for the 14-column feature tensor.
# Updated 2026-02-10 for new pipeline (process_dataset.py, max_notes=1530, max_bars=256).
# Value ranges verified via data_utils/analyze_features.py on full 68,542 graphs.
# NOTE(review): the module docstring lists micro_offset as [0, 12] / vocab 13,
# while the spec below uses vocab_size=21 ([0, 20]) -- confirm which is current.
FEATURE_SPECS = {
    # Position features
    'grid_position': NoteFeatureSpec('grid_position', col_index=0, vocab_size=33),  # tokenised [0, 32]
    'micro_offset': NoteFeatureSpec('micro_offset', col_index=1, vocab_size=21),  # tokenised [0, 20]
    # Bar index
    'measure_idx': NoteFeatureSpec('measure_idx', col_index=2, vocab_size=256),  # 0-based [0, 255]
    # Voice and staff (already 0-based from extraction)
    'voice': NoteFeatureSpec('voice', col_index=3, vocab_size=25),  # [0, 24] (20 unique)
    'staff': NoteFeatureSpec('staff', col_index=13, vocab_size=37),  # [0, 36]
    # Pitch features
    'pitch_step': NoteFeatureSpec('pitch_step', col_index=4, vocab_size=7),  # [0, 6] C...B
    'pitch_alter': NoteFeatureSpec('pitch_alter', col_index=5, vocab_size=5),  # [0, 4] (shifted +2 in extraction)
    'pitch_octave': NoteFeatureSpec('pitch_octave', col_index=6, vocab_size=15, shift=3),  # [-3, 11] -> [0, 14]
    # Duration (tokenised)
    'duration': NoteFeatureSpec('duration', col_index=7, vocab_size=727),  # [0, 726] (filtered vocab)
    # Context: clef, time signature, key
    'clef': NoteFeatureSpec('clef', col_index=8, vocab_size=6),  # [0, 5]
    'ts_beats': NoteFeatureSpec('ts_beats', col_index=9, vocab_size=57, remap=TS_BEATS_REMAP),  # 57 unique values
    'ts_beat_type': NoteFeatureSpec('ts_beat_type', col_index=10, vocab_size=6, remap=TS_BEAT_TYPE_REMAP),  # 6 values: {1,2,4,8,16,32}
    'key_fifths': NoteFeatureSpec('key_fifths', col_index=11, vocab_size=15, shift=7),  # [-7, +7] -> [0, 14]
    'key_mode': NoteFeatureSpec('key_mode', col_index=12, vocab_size=2),  # 0=major, 1=minor (optional)
}


def _build_remap_lookup(remap: Dict[int, int]) -> torch.Tensor:
    """Build a dense lookup tensor for a sparse ``raw value -> token`` remap.

    Index ``v`` (for ``0 <= v <= max(remap)``) holds the token of the known
    key closest to ``v``; ties are resolved towards the lower key. Keys are
    expected to be non-negative integers.
    """
    keys = sorted(remap)
    tokens = []
    pos = 0
    for raw in range(keys[-1] + 1):
        # Keys are sorted and `raw` only grows, so a single forward-moving
        # cursor always lands on the nearest key.
        while pos + 1 < len(keys) and abs(keys[pos + 1] - raw) < abs(keys[pos] - raw):
            pos += 1
        tokens.append(remap[keys[pos]])
    return torch.tensor(tokens, dtype=torch.long)


class ScoreGraphMultiFeatureDataset(Dataset):
    """
    Dataset for loading score graphs with multiple configurable features.

    Supports:
    - Multiple discrete features per note
    - Graph structure (edges) for potential GNN integration
    - Configurable feature selection for ablation studies

    Args:
        graph_dir: Directory containing .pt graph files
        features: List of feature names to extract (from FEATURE_SPECS)
        file_list: Optional list of specific files to use
        max_notes: Maximum number of notes per graph (filter out larger)
        max_bars: Maximum number of bars per graph (filter out larger)
        identifier_pool_size: Size of entity ID pool
        include_graph: Whether to include graph structure (edges)
        id_assignment: 'sequential' (0,1,2,...) or 'random' (random permutation)
        seed: Random seed for entity ID assignment

    Raises:
        ValueError: If a requested feature is not in FEATURE_SPECS.
    """

    def __init__(
        self,
        graph_dir: str,
        features: List[str] = ['grid_position'],
        file_list: Optional[List[str]] = None,
        max_notes: Optional[int] = 256,
        max_bars: Optional[int] = None,
        identifier_pool_size: int = 512,
        include_graph: bool = False,
        id_assignment: str = 'sequential',
        seed: int = 42,
    ):
        self.graph_dir = Path(graph_dir)
        # Copy so the instance never aliases the shared default list (or the
        # caller's list) -- mutating ds.features must not leak elsewhere.
        self.features = list(features)
        self.max_notes = max_notes
        self.max_bars = max_bars
        self.identifier_pool_size = identifier_pool_size
        self.include_graph = include_graph
        self.id_assignment = id_assignment
        self.seed = seed

        # Validate features
        for feat in self.features:
            if feat not in FEATURE_SPECS:
                raise ValueError(f"Unknown feature: {feat}. Available: {list(FEATURE_SPECS.keys())}")
        self.feature_specs = [FEATURE_SPECS[f] for f in self.features]

        # Dense remap lookup tensors, filled lazily per feature name.
        self._remap_lookups = {}

        # Track if file_list was provided (for filtering logic)
        self._has_file_list = file_list is not None

        # Load file list
        if file_list is not None:
            candidates = [self.graph_dir / f for f in file_list]
            self.graph_files = [f for f in candidates if f.exists()]
        else:
            self.graph_files = sorted(self.graph_dir.glob("*.pt"))

        # Filter by max_notes and max_bars
        if max_notes is not None or max_bars is not None:
            self._filter_by_constraints()

        print(f"ScoreGraphMultiFeatureDataset: {len(self.graph_files)} graphs")
        print(f"\tFeatures: {features}")
        print(f"\tMax notes: {max_notes}, Max bars: {max_bars}, ID pool: {identifier_pool_size}")

    def _filter_by_constraints(self):
        """Filter graphs that exceed max_notes or max_bars, using cache when available.

        The cache is a hidden text file in ``graph_dir`` (one filename per
        line) keyed by the active constraints. It is only written when no
        explicit ``file_list`` was given, so that it always describes the
        whole directory.
        """
        # Build cache filename based on constraints
        cache_parts = []
        if self.max_notes is not None:
            cache_parts.append(f"notes{self.max_notes}")
        if self.max_bars is not None:
            cache_parts.append(f"bars{self.max_bars}")
        cache_name = "_".join(cache_parts)
        cache_file = self.graph_dir / f".filtered_files_{cache_name}.txt"

        if cache_file.exists():
            with open(cache_file) as f:
                # Drop empty entries: an empty cache file would otherwise yield
                # the name '' and `graph_dir / ''` is the directory itself.
                cached_filenames = {
                    name for name in f.read().strip().split('\n') if name
                }

            if self._has_file_list:
                # Intersect with provided file_list (don't replace!)
                self.graph_files = [
                    f for f in self.graph_files if f.name in cached_filenames
                ]
            else:
                # Use cache directly - sort for consistent ordering
                self.graph_files = sorted(
                    self.graph_dir / fn
                    for fn in cached_filenames
                    if (self.graph_dir / fn).exists()
                )
            return

        # Build cache by checking each file
        filtered_files = []
        for graph_file in self.graph_files:
            try:
                # NOTE(security): weights_only=False unpickles arbitrary
                # objects; only point graph_dir at trusted files.
                data = torch.load(graph_file, weights_only=False)
                keep = self._check_constraints(data)
            except Exception as exc:
                # Deliberately best-effort: a corrupt or malformed file is
                # skipped, but reported instead of vanishing silently.
                print(f"\tSkipping unreadable graph {graph_file.name}: {exc}")
                continue
            if keep:
                filtered_files.append(graph_file)

        # Save cache (only when not using file_list)
        if not self._has_file_list:
            with open(cache_file, 'w') as cf:
                cf.write('\n'.join(f.name for f in filtered_files))

        self.graph_files = filtered_files

    def _check_constraints(self, data) -> bool:
        """Check if a graph satisfies max_notes and max_bars constraints.

        Returns False for payloads that are not a dict with a 'graph' entry
        containing a 'note' node type.
        """
        if not isinstance(data, dict) or 'graph' not in data:
            return False

        graph = data['graph']
        if 'note' not in graph.node_types:
            return False

        note_x = graph['note'].x
        num_notes = note_x.shape[0]

        # Check max_notes constraint
        if self.max_notes is not None and num_notes > self.max_notes:
            return False

        # Check max_bars constraint (measure_idx is already 0-based)
        if self.max_bars is not None:
            bar_idx_col = note_x[:, COL_INDICES['measure_idx']]
            max_bar = bar_idx_col.max().item() + 1  # +1 since measure_idx is 0-based
            if max_bar > self.max_bars:
                return False

        return True

    def _get_num_notes(self, data) -> Optional[int]:
        """Extract number of notes from graph['note'].x.

        Falls back to the 'num_notes' key of ``data``; returns None when
        neither source is available.
        """
        if isinstance(data, dict) and 'graph' in data:
            graph = data['graph']
            if 'note' in graph.node_types:
                return graph['note'].x.shape[0]
        # Fallback to num_notes key
        if 'num_notes' in data:
            return data['num_notes']
        return None

    def _remap_lookup(self, spec: NoteFeatureSpec) -> torch.Tensor:
        """Return the cached dense lookup tensor for ``spec.remap``."""
        cache = getattr(self, '_remap_lookups', None)
        if cache is None:
            # Instances created without __init__ (e.g. restored from an older
            # pickle) have no cache attribute yet.
            cache = {}
            self._remap_lookups = cache
        if spec.name not in cache:
            cache[spec.name] = _build_remap_lookup(spec.remap)
        return cache[spec.name]

    def __len__(self):
        return len(self.graph_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load a graph and extract configured features from graph['note'].x.

        Returns a dict with 'entity_ids' (long tensor), 'num_entities' (int),
        one long tensor per configured feature, and the raw 'graph' when
        ``include_graph`` is set.

        Raises:
            ValueError: If random ID assignment is requested for a graph with
                more notes than ``identifier_pool_size``.
        """
        graph_path = self.graph_files[idx]
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load graph files from trusted sources.
        data = torch.load(graph_path, weights_only=False)

        # Get the HeteroData graph and note features
        graph = data['graph']
        note_x = graph['note'].x  # Shape: [num_notes, 14]
        num_notes = note_x.shape[0]

        # Generate unique entity IDs for this sample
        if self.id_assignment == 'sequential':
            entity_ids = torch.arange(num_notes, dtype=torch.long)
        else:
            if num_notes > self.identifier_pool_size:
                raise ValueError(
                    f"Cannot draw {num_notes} unique entity IDs from a pool of "
                    f"{self.identifier_pool_size} ({graph_path.name})"
                )
            # Seed per sample so the assignment is reproducible across epochs.
            rng = np.random.RandomState(self.seed + idx)
            entity_ids = torch.from_numpy(
                rng.choice(self.identifier_pool_size, size=num_notes, replace=False)
            ).long()

        # Extract features from graph['note'].x columns
        result = {
            'entity_ids': entity_ids,
            'num_entities': num_notes,
        }

        for spec in self.feature_specs:
            # Get values from the column index
            values = note_x[:, spec.col_index].long()

            if spec.remap is not None:
                # Vectorized remapping using a dense lookup tensor in which
                # gaps are pre-filled with the nearest known key's token.
                # Clamping first sends out-of-range values to the smallest /
                # largest known key.
                lookup = self._remap_lookup(spec).to(values.device)
                values = lookup[values.clamp(0, lookup.numel() - 1)]
            elif spec.shift != 0:
                # Apply shift if needed (only if no remap)
                values = values + spec.shift

            # Clamp to valid range
            values = values.clamp(0, spec.vocab_size - 1)
            result[spec.name] = values

        # Include full graph if requested (for HGT pre-processing)
        if self.include_graph:
            result['graph'] = graph

        return result

    def get_vocab_sizes(self) -> Dict[str, int]:
        """Get vocabulary sizes for configured features."""
        return {spec.name: spec.vocab_size for spec in self.feature_specs}


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-feature dataset.

    Pads all samples to the maximum number of entities in the batch.
    When graphs are present (HGT mode), extracts edge_dicts for each sample.

    Returns a dict with padded 'entity_ids', a boolean 'mask' marking real
    entities, 'num_entities', one padded long tensor per feature, and
    'edge_dicts' (a list, one entry per sample) when graphs are present.
    """
    batch_size = len(batch)
    max_entities = max(sample['num_entities'] for sample in batch)

    # Get feature names from first sample (excluding special keys)
    special_keys = {'entity_ids', 'num_entities', 'edge_index', 'graph'}
    feature_names = [k for k in batch[0].keys() if k not in special_keys]

    # Initialize tensors (zero is the padding value everywhere)
    entity_ids = torch.zeros(batch_size, max_entities, dtype=torch.long)
    mask = torch.zeros(batch_size, max_entities, dtype=torch.bool)
    num_entities = torch.tensor([s['num_entities'] for s in batch])
    features = {
        name: torch.zeros(batch_size, max_entities, dtype=torch.long)
        for name in feature_names
    }

    # Fill tensors
    for i, sample in enumerate(batch):
        n = sample['num_entities']
        entity_ids[i, :n] = sample['entity_ids']
        mask[i, :n] = True
        for name in feature_names:
            features[name][i, :n] = sample[name]

    result = {
        'entity_ids': entity_ids,
        'mask': mask,
        'num_entities': num_entities,
        **features,
    }

    # Handle graph data if present (for HGT): extract edge_dicts
    if 'graph' in batch[0]:
        # Imported lazily so feature-only batches do not depend on the HGT
        # model package.
        from src.model.note_hgt import NoteHGT

        result['edge_dicts'] = [
            NoteHGT.extract_edge_dict(s['graph']) for s in batch
        ]

    return result