# NOTE(review): stripped web-viewer artifact lines ("Spaces: Running") — they are not Python source.
"""
Multi-feature score graph dataset for LaM-SLidE autoencoder training.

Provides ScoreGraphMultiFeatureDataset, collate functions, and feature specs.
The dataset loads .pt graph files produced by process_dataset.py and extracts
configurable subsets of 14 per-note features for ablation studies.

Feature tensor: 14 columns in graph['note'].x (from process_dataset.py).

Col  Feature         Raw Range         Transform  Vocab  Notes
 0   grid_position   [0, 32]           none        33    tokenised (16th grid)
 1   micro_offset    [0, 12]           none        13    tokenised micro-shift
 2   measure_idx     [0, 255]          none       256    0-based bar index
 3   voice           [0, 24]           none        25    already 0-based (20 unique)
 4   pitch_step      [0, 6]            none         7    C D E F G A B
 5   pitch_alter     [0, 4]            none         5    shifted +2 in extraction
 6   pitch_octave    [-3, 11]          +3          15    shift to 0-base
 7   duration        [0, 726]          none       727    tokenised duration vocab
 8   clef            [0, 5]            none         6    treble=0 bass=1 ...
 9   ts_beats        [1, 80]           remap       57    remap 57 observed values
10   ts_beat_type    {1,2,4,8,16,32}   remap        6    sparse -> [0,5]
11   key_fifths      [-7, +7]          +7          15    circle of fifths
12   key_mode        [0, 1]            none         2    0=major, 1=minor (optional)
13   staff           [0, 36]           none        37    already 0-based

NOTE(review): this table lists micro_offset as [0, 12] / vocab 13, but
FEATURE_SPECS below uses vocab_size=21 ([0, 20]) — confirm which reflects the
current extraction pipeline.
"""
# Standard library
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

# Third-party
import numpy as np
import torch
from torch.utils.data import Dataset
@dataclass
class NoteFeatureSpec:
    """Specification for a single note feature.

    Bug fix: the @dataclass decorator was missing, so the annotated fields
    below were plain class annotations and FEATURE_SPECS-style constructor
    calls (e.g. NoteFeatureSpec('voice', col_index=3, vocab_size=25)) would
    raise TypeError. The decorator restores the intended generated __init__.
    """
    name: str
    col_index: int  # Column index in graph['note'].x tensor
    vocab_size: int
    shift: int = 0  # Value to add to shift negative values to positive
    # Optional remapping for sparse value sets (e.g., ts_beat_type: 2,4,8,16,32 -> 0,1,2,3,4)
    remap: Optional[Dict[int, int]] = None
# Dense token ids for the time-signature denominators that survive filtering.
# Files containing beat types {3, 5, 6, 9, 64} are dropped upstream, leaving
# exactly {1, 2, 4, 8, 16, 32} -> tokens [0, 5].
TS_BEAT_TYPE_REMAP = {denom: token for token, denom in enumerate((1, 2, 4, 8, 16, 32))}
# Dense token ids for the 57 time-signature numerators observed in the
# filtered data (max_notes=1530, max_bars=256): 1..31 plus a sparse tail.
# In __getitem__ raw values are clamped to [0, 80] before a LUT lookup, so
# values above 80 collapse onto the token for 80 and unlisted in-range
# values fall back to token 0.
TS_BEATS_REMAP = {
    beats: token
    for token, beats in enumerate(
        list(range(1, 32))
        + [33, 34, 35, 36, 37, 39, 41, 42, 45, 47, 48]
        + [54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 67, 69, 73, 80]
    )
}
# Column index of each per-note feature in graph['note'].x ([num_notes, 14]).
# Order matches FEATURE_COLUMNS in process_dataset.py (verified 2025-02-07):
#   grid_position  tokenised grid position
#   micro_offset   tokenised micro offset
#   measure_idx    0-based bar index
#   voice          0-based voice
#   pitch_step     0-6 (C...B)
#   pitch_alter    0-4 (shifted +2)
#   pitch_octave   raw octave (-3..11)
#   duration       tokenised duration
#   clef           0-5
#   ts_beats       raw numerator
#   ts_beat_type   raw denominator {1,2,4,8,16,32}
#   key_fifths     circle of fifths (-7..+7)
#   key_mode       0=major, 1=minor
#   staff          0-based global staff
COL_INDICES = {
    name: col
    for col, name in enumerate((
        'grid_position', 'micro_offset', 'measure_idx', 'voice',
        'pitch_step', 'pitch_alter', 'pitch_octave', 'duration',
        'clef', 'ts_beats', 'ts_beat_type', 'key_fifths',
        'key_mode', 'staff',
    ))
}
# Pre-defined feature specifications for the 14-column feature tensor.
# Updated 2026-02-10 for new pipeline (process_dataset.py, max_notes=1530, max_bars=256).
# Value ranges verified via data_utils/analyze_features.py on full 68,542 graphs.
# NOTE(review): the module docstring lists micro_offset as [0, 12] / vocab 13,
# but this spec uses vocab_size=21 — confirm which matches the current extraction.
FEATURE_SPECS = {
    # Position features
    'grid_position': NoteFeatureSpec('grid_position', col_index=0, vocab_size=33),  # tokenised [0, 32]
    'micro_offset': NoteFeatureSpec('micro_offset', col_index=1, vocab_size=21),  # tokenised [0, 20]
    # Bar index
    'measure_idx': NoteFeatureSpec('measure_idx', col_index=2, vocab_size=256),  # 0-based [0, 255]
    # Voice and staff (already 0-based from extraction)
    'voice': NoteFeatureSpec('voice', col_index=3, vocab_size=25),  # [0, 24] (20 unique)
    'staff': NoteFeatureSpec('staff', col_index=13, vocab_size=37),  # [0, 36]
    # Pitch features
    'pitch_step': NoteFeatureSpec('pitch_step', col_index=4, vocab_size=7),  # [0, 6] C...B
    'pitch_alter': NoteFeatureSpec('pitch_alter', col_index=5, vocab_size=5),  # [0, 4] (shifted +2 in extraction)
    'pitch_octave': NoteFeatureSpec('pitch_octave', col_index=6, vocab_size=15, shift=3),  # [-3, 11] -> [0, 14]
    # Duration (tokenised)
    'duration': NoteFeatureSpec('duration', col_index=7, vocab_size=727),  # [0, 726] (filtered vocab)
    # Context: clef, time signature, key
    'clef': NoteFeatureSpec('clef', col_index=8, vocab_size=6),  # [0, 5]
    'ts_beats': NoteFeatureSpec('ts_beats', col_index=9, vocab_size=57, remap=TS_BEATS_REMAP),  # 57 unique values
    'ts_beat_type': NoteFeatureSpec('ts_beat_type', col_index=10, vocab_size=6, remap=TS_BEAT_TYPE_REMAP),  # 6 values: {1,2,4,8,16,32}
    'key_fifths': NoteFeatureSpec('key_fifths', col_index=11, vocab_size=15, shift=7),  # [-7, +7] -> [0, 14]
    'key_mode': NoteFeatureSpec('key_mode', col_index=12, vocab_size=2),  # 0=major, 1=minor (optional)
}
class ScoreGraphMultiFeatureDataset(Dataset):
    """
    Dataset for loading score graphs with multiple configurable features.

    Supports:
    - Multiple discrete features per note
    - Graph structure (edges) for potential GNN integration
    - Configurable feature selection for ablation studies

    Args:
        graph_dir: Directory containing .pt graph files
        features: List of feature names to extract (from FEATURE_SPECS).
            Defaults to ['grid_position'] when None (avoids a mutable
            default argument).
        file_list: Optional list of specific files to use
        max_notes: Maximum number of notes per graph (filter out larger)
        max_bars: Maximum number of bars per graph (filter out larger)
        identifier_pool_size: Size of entity ID pool
        include_graph: Whether to include graph structure (edges)
        id_assignment: 'sequential' (0,1,2,...) or 'random' (random permutation)
        seed: Random seed for entity ID assignment

    Raises:
        ValueError: On an unknown feature name or id_assignment mode.
    """

    def __init__(
        self,
        graph_dir: str,
        features: Optional[List[str]] = None,
        file_list: Optional[List[str]] = None,
        max_notes: Optional[int] = 256,
        max_bars: Optional[int] = None,
        identifier_pool_size: int = 512,
        include_graph: bool = False,
        id_assignment: str = 'sequential',
        seed: int = 42,
    ):
        # None sentinel instead of a shared mutable list default.
        if features is None:
            features = ['grid_position']
        # Fail fast on typos instead of silently falling through to 'random'.
        if id_assignment not in ('sequential', 'random'):
            raise ValueError(
                f"Unknown id_assignment: {id_assignment!r} "
                f"(expected 'sequential' or 'random')"
            )
        self.graph_dir = Path(graph_dir)
        self.features = features
        self.max_notes = max_notes
        self.max_bars = max_bars
        self.identifier_pool_size = identifier_pool_size
        self.include_graph = include_graph
        self.id_assignment = id_assignment
        self.seed = seed
        # Validate features
        for feat in features:
            if feat not in FEATURE_SPECS:
                raise ValueError(f"Unknown feature: {feat}. Available: {list(FEATURE_SPECS.keys())}")
        self.feature_specs = [FEATURE_SPECS[f] for f in features]
        # Pre-build remap lookup tensors once (hoisted out of __getitem__,
        # where they were previously rebuilt for every sample).
        self._remap_luts: Dict[str, torch.Tensor] = {}
        for spec in self.feature_specs:
            if spec.remap is not None:
                max_key = max(spec.remap.keys())
                lut = torch.zeros(max_key + 1, dtype=torch.long)
                for old_val, new_val in spec.remap.items():
                    lut[old_val] = new_val
                self._remap_luts[spec.name] = lut
        # Track if file_list was provided (for filtering logic)
        self._has_file_list = file_list is not None
        # Load file list
        if file_list is not None:
            self.graph_files = [self.graph_dir / f for f in file_list]
            self.graph_files = [f for f in self.graph_files if f.exists()]
        else:
            self.graph_files = sorted(self.graph_dir.glob("*.pt"))
        # Filter by max_notes and max_bars
        if max_notes is not None or max_bars is not None:
            self._filter_by_constraints()
        print(f"ScoreGraphMultiFeatureDataset: {len(self.graph_files)} graphs")
        print(f"\tFeatures: {features}")
        print(f"\tMax notes: {max_notes}, Max bars: {max_bars}, ID pool: {identifier_pool_size}")

    def _filter_by_constraints(self):
        """Filter graphs that exceed max_notes or max_bars, using cache when available."""
        # Build cache filename based on constraints
        cache_parts = []
        if self.max_notes is not None:
            cache_parts.append(f"notes{self.max_notes}")
        if self.max_bars is not None:
            cache_parts.append(f"bars{self.max_bars}")
        cache_name = "_".join(cache_parts)
        cache_file = self.graph_dir / f".filtered_files_{cache_name}.txt"
        if cache_file.exists():
            with open(cache_file) as f:
                cached_filenames = set(f.read().strip().split('\n'))
            if self._has_file_list:
                # Intersect with provided file_list (don't replace!)
                current_filenames = {f.name for f in self.graph_files}
                valid_filenames = current_filenames & cached_filenames
                self.graph_files = [f for f in self.graph_files if f.name in valid_filenames]
            else:
                # Use cache directly - sort for consistent ordering
                self.graph_files = sorted([self.graph_dir / fn for fn in cached_filenames if (self.graph_dir / fn).exists()])
        else:
            # Build cache by checking each file. Unreadable/corrupt files are
            # skipped (best-effort scan, matching the caching semantics).
            filtered_files = []
            for f in self.graph_files:
                try:
                    data = torch.load(f, weights_only=False)
                    if self._check_constraints(data):
                        filtered_files.append(f)
                except Exception:
                    continue
            # Save cache (only when not using file_list, so the cache always
            # reflects the full directory scan)
            if not self._has_file_list:
                with open(cache_file, 'w') as cf:
                    cf.write('\n'.join(f.name for f in filtered_files))
            self.graph_files = filtered_files

    def _check_constraints(self, data) -> bool:
        """Check if a graph satisfies max_notes and max_bars constraints."""
        if not isinstance(data, dict) or 'graph' not in data:
            return False
        graph = data['graph']
        if 'note' not in graph.node_types:
            return False
        note_x = graph['note'].x
        num_notes = note_x.shape[0]
        # Check max_notes constraint
        if self.max_notes is not None and num_notes > self.max_notes:
            return False
        # Check max_bars constraint (measure_idx is col 2, already 0-based)
        if self.max_bars is not None:
            bar_idx_col = note_x[:, 2]  # col_index for measure_idx
            max_bar = bar_idx_col.max().item() + 1  # +1 since measure_idx is 0-based
            if max_bar > self.max_bars:
                return False
        return True

    def _get_num_notes(self, data) -> Optional[int]:
        """Extract number of notes from graph['note'].x."""
        if isinstance(data, dict) and 'graph' in data:
            graph = data['graph']
            if 'note' in graph.node_types:
                return graph['note'].x.shape[0]
            # Fallback to num_notes key
            if 'num_notes' in data:
                return data['num_notes']
        return None

    def __len__(self):
        return len(self.graph_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load a graph and extract configured features from graph['note'].x."""
        graph_path = self.graph_files[idx]
        data = torch.load(graph_path, weights_only=False)
        # Get the HeteroData graph and note features
        graph = data['graph']
        note_x = graph['note'].x  # Shape: [num_notes, 14]
        num_notes = note_x.shape[0]
        # Generate unique entity IDs for this sample
        if self.id_assignment == 'sequential':
            entity_ids = torch.arange(num_notes, dtype=torch.long)
        else:
            # Deterministic per-sample permutation; numpy raises ValueError if
            # num_notes exceeds identifier_pool_size (replace=False).
            rng = np.random.RandomState(self.seed + idx)
            entity_ids = torch.from_numpy(
                rng.choice(self.identifier_pool_size, size=num_notes, replace=False)
            ).long()
        # Extract features from graph['note'].x columns
        result = {
            'entity_ids': entity_ids,
            'num_entities': num_notes,
        }
        for spec in self.feature_specs:
            # Get values from the column index
            values = note_x[:, spec.col_index].long()
            lut = self._remap_luts.get(spec.name)
            if lut is not None:
                # Vectorized remap: clamp to [0, max_key] then index the LUT.
                # Raw values above the largest key collapse onto that key's
                # token; unlisted in-range values fall back to token 0.
                values = lut[values.clamp(0, lut.shape[0] - 1)]
            elif spec.shift != 0:
                # Apply shift if needed (only if no remap)
                values = values + spec.shift
            # Clamp to valid range
            values = values.clamp(0, spec.vocab_size - 1)
            result[spec.name] = values
        # Include full graph if requested (for HGT pre-processing)
        if self.include_graph:
            result['graph'] = graph
        return result

    def get_vocab_sizes(self) -> Dict[str, int]:
        """Get vocabulary sizes for configured features."""
        return {spec.name: spec.vocab_size for spec in self.feature_specs}
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-feature dataset.

    Pads all samples to the maximum number of entities in the batch.
    When graphs are present (HGT mode), extracts edge_dicts for each sample.

    Args:
        batch: Non-empty list of sample dicts from ScoreGraphMultiFeatureDataset.

    Returns:
        Dict with padded long 'entity_ids' [B, max_entities], boolean 'mask'
        (True at valid positions), 'num_entities' [B], one [B, max_entities]
        long tensor per feature key, and 'edge_dicts' when graphs are present.
    """
    batch_size = len(batch)
    max_entities = max(sample['num_entities'] for sample in batch)
    # Get feature names from first sample (excluding special keys)
    special_keys = {'entity_ids', 'num_entities', 'edge_index', 'graph'}
    feature_names = [k for k in batch[0].keys() if k not in special_keys]
    # Initialize tensors (zero padding; mask distinguishes real positions)
    entity_ids = torch.zeros(batch_size, max_entities, dtype=torch.long)
    mask = torch.zeros(batch_size, max_entities, dtype=torch.bool)
    num_entities = torch.tensor([s['num_entities'] for s in batch])
    features = {name: torch.zeros(batch_size, max_entities, dtype=torch.long)
                for name in feature_names}
    # Fill tensors (left-aligned; positions >= n stay zero with mask False)
    for i, sample in enumerate(batch):
        n = sample['num_entities']
        entity_ids[i, :n] = sample['entity_ids']
        mask[i, :n] = True
        for name in feature_names:
            features[name][i, :n] = sample[name]
    result = {
        'entity_ids': entity_ids,
        'mask': mask,
        'num_entities': num_entities,
        **features,
    }
    # Handle graph data if present (for HGT): extract edge_dicts.
    if 'graph' in batch[0]:
        # Bug fix: import lazily so the common (non-HGT) path does not
        # require src.model.note_hgt to be importable.
        from src.model.note_hgt import NoteHGT
        result['edge_dicts'] = [
            NoteHGT.extract_edge_dict(s['graph']) for s in batch
        ]
    return result