# score-ae/src/dataset/datamodule.py
"""
Multi-feature score graph dataset for LaM-SLidE autoencoder training.
Provides ScoreGraphMultiFeatureDataset, collate functions, and feature specs.
The dataset loads .pt graph files produced by process_dataset.py and extracts
configurable subsets of 14 per-note features for ablation studies.
Feature tensor: 14 columns in graph['note'].x (from process_dataset.py).
Col Feature Raw Range Transform Vocab Notes
0 grid_position [0, 32] none 33 tokenised (16th grid)
1 micro_offset [0, 20] none 21 tokenised micro-shift
2 measure_idx [0, 255] none 256 0-based bar index
3 voice [0, 24] none 25 already 0-based (20 unique)
4 pitch_step [0, 6] none 7 C D E F G A B
5 pitch_alter [0, 4] none 5 shifted +2 in extraction
6 pitch_octave [-3, 11] +3 15 shift to 0-base
7 duration [0, 726] none 727 tokenised duration vocab
8 clef [0, 5] none 6 treble=0 bass=1 ...
9 ts_beats [1, 80] remap 57 remap 57 observed values
10 ts_beat_type {1,2,4,8,16,32} remap 6 remap 6 sparse -> [0,5]
11 key_fifths [-7, +7] +7 15 circle of fifths
12 key_mode [0, 1] none 2 0=major, 1=minor (optional)
13 staff [0, 36] none 37 already 0-based
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from torch.utils.data import Dataset
@dataclass
class NoteFeatureSpec:
    """Specification for a single note feature.

    Describes where a feature lives in the graph['note'].x tensor and how
    its raw values are converted to token ids: either an additive shift
    (for negative-valued features) or a sparse remap table. When `remap`
    is set it takes precedence over `shift` (see __getitem__ below).
    """
    name: str
    col_index: int  # Column index in graph['note'].x tensor
    vocab_size: int  # Number of token ids after shift/remap (clamp upper bound)
    shift: int = 0  # Value to add to shift negative values to positive
    # Optional remapping for sparse value sets (e.g., ts_beat_type: 2,4,8,16,32 -> 0,1,2,3,4)
    remap: Optional[Dict[int, int]] = None
# ts_beat_type: actual time-signature denominators -> consecutive token ids.
# After dropping files with beat types {3, 5, 6, 9, 64}, the remaining
# denominators are {1, 2, 4, 8, 16, 32}.
TS_BEAT_TYPE_REMAP = {denom: token for token, denom in enumerate((1, 2, 4, 8, 16, 32))}

# ts_beats: observed time-signature numerators -> consecutive token ids.
# 57 unique values found in filtered data (max_notes=1530, max_bars=256):
# 1..31 and 54..64 are fully covered, plus a sparse tail of larger values.
# Raw values outside this table are handled at lookup time in __getitem__.
TS_BEATS_REMAP = {
    beats: token
    for token, beats in enumerate(
        list(range(1, 32))
        + [33, 34, 35, 36, 37, 39, 41, 42, 45, 47, 48]
        + list(range(54, 65))
        + [67, 69, 73, 80]
    )
}
# Column indices in graph['note'].x tensor (shape: [num_notes, 14]).
# Matches FEATURE_COLUMNS in process_dataset.py (verified 2025-02-07).
# NOTE: these raw columns are pre-shift/pre-remap; FEATURE_SPECS below
# defines how each column is converted to token ids.
COL_INDICES = {
    'grid_position': 0,   # tokenised grid position
    'micro_offset': 1,    # tokenised micro offset
    'measure_idx': 2,     # 0-based bar index
    'voice': 3,           # 0-based voice
    'pitch_step': 4,      # 0-6 (C...B)
    'pitch_alter': 5,     # 0-4 (shifted +2)
    'pitch_octave': 6,    # raw octave (-3..11)
    'duration': 7,        # tokenised duration
    'clef': 8,            # 0-5
    'ts_beats': 9,        # raw numerator
    'ts_beat_type': 10,   # raw denominator {1,2,4,8,16,32} (see TS_BEAT_TYPE_REMAP)
    'key_fifths': 11,     # circle of fifths (-7..+7)
    'key_mode': 12,       # 0=major, 1=minor
    'staff': 13,          # 0-based global staff
}
# Pre-defined feature specifications for the 14-column feature tensor.
# Updated 2026-02-10 for new pipeline (process_dataset.py, max_notes=1530, max_bars=256).
# Value ranges verified via data_utils/analyze_features.py on full 68,542 graphs.
# Keyed by feature name; `features` arguments to the dataset select from these keys.
# NOTE(review): micro_offset here is [0, 20] / vocab 21 — newer than the
# [0, 12] / 13 listed in the module docstring's table; this table is authoritative.
FEATURE_SPECS = {
    # Position features
    'grid_position': NoteFeatureSpec('grid_position', col_index=0, vocab_size=33),  # tokenised [0, 32]
    'micro_offset': NoteFeatureSpec('micro_offset', col_index=1, vocab_size=21),  # tokenised [0, 20]
    # Bar index
    'measure_idx': NoteFeatureSpec('measure_idx', col_index=2, vocab_size=256),  # 0-based [0, 255]
    # Voice and staff (already 0-based from extraction)
    'voice': NoteFeatureSpec('voice', col_index=3, vocab_size=25),  # [0, 24] (20 unique)
    'staff': NoteFeatureSpec('staff', col_index=13, vocab_size=37),  # [0, 36]
    # Pitch features
    'pitch_step': NoteFeatureSpec('pitch_step', col_index=4, vocab_size=7),  # [0, 6] C...B
    'pitch_alter': NoteFeatureSpec('pitch_alter', col_index=5, vocab_size=5),  # [0, 4] (shifted +2 in extraction)
    'pitch_octave': NoteFeatureSpec('pitch_octave', col_index=6, vocab_size=15, shift=3),  # [-3, 11] -> [0, 14]
    # Duration (tokenised)
    'duration': NoteFeatureSpec('duration', col_index=7, vocab_size=727),  # [0, 726] (filtered vocab)
    # Context: clef, time signature, key
    'clef': NoteFeatureSpec('clef', col_index=8, vocab_size=6),  # [0, 5]
    'ts_beats': NoteFeatureSpec('ts_beats', col_index=9, vocab_size=57, remap=TS_BEATS_REMAP),  # 57 unique values
    'ts_beat_type': NoteFeatureSpec('ts_beat_type', col_index=10, vocab_size=6, remap=TS_BEAT_TYPE_REMAP),  # 6 values: {1,2,4,8,16,32}
    'key_fifths': NoteFeatureSpec('key_fifths', col_index=11, vocab_size=15, shift=7),  # [-7, +7] -> [0, 14]
    'key_mode': NoteFeatureSpec('key_mode', col_index=12, vocab_size=2),  # 0=major, 1=minor (optional)
}
class ScoreGraphMultiFeatureDataset(Dataset):
    """
    Dataset for loading score graphs with multiple configurable features.

    Each sample is a dict with one token tensor per configured feature,
    per-note entity IDs, and (optionally) the full HeteroData graph.

    Supports:
    - Multiple discrete features per note
    - Graph structure (edges) for potential GNN integration
    - Configurable feature selection for ablation studies

    Args:
        graph_dir: Directory containing .pt graph files
        features: List of feature names to extract (from FEATURE_SPECS).
            None selects the default ['grid_position'].
        file_list: Optional list of specific files to use
        max_notes: Maximum number of notes per graph (filter out larger)
        max_bars: Maximum number of bars per graph (filter out larger)
        identifier_pool_size: Size of entity ID pool; with
            id_assignment='random' it must be >= the note count of every
            graph (sampling is without replacement).
        include_graph: Whether to include graph structure (edges)
        id_assignment: 'sequential' (0,1,2,...) or 'random' (random permutation)
        seed: Random seed for entity ID assignment

    Raises:
        ValueError: On an unknown feature name or id_assignment mode.
    """

    def __init__(
        self,
        graph_dir: str,
        features: Optional[List[str]] = None,
        file_list: Optional[List[str]] = None,
        max_notes: Optional[int] = 256,
        max_bars: Optional[int] = None,
        identifier_pool_size: int = 512,
        include_graph: bool = False,
        id_assignment: str = 'sequential',
        seed: int = 42,
    ):
        # `features` defaults to None instead of a list literal to avoid the
        # shared-mutable-default pitfall; the effective default is unchanged.
        if features is None:
            features = ['grid_position']
        # Fail fast on a bad mode instead of silently falling through to
        # random assignment in __getitem__.
        if id_assignment not in ('sequential', 'random'):
            raise ValueError(
                f"id_assignment must be 'sequential' or 'random', got {id_assignment!r}"
            )

        self.graph_dir = Path(graph_dir)
        self.features = list(features)  # defensive copy: callers keep their list
        self.max_notes = max_notes
        self.max_bars = max_bars
        self.identifier_pool_size = identifier_pool_size
        self.include_graph = include_graph
        self.id_assignment = id_assignment
        self.seed = seed

        # Validate feature names against the spec table.
        for feat in features:
            if feat not in FEATURE_SPECS:
                raise ValueError(f"Unknown feature: {feat}. Available: {list(FEATURE_SPECS.keys())}")
        self.feature_specs = [FEATURE_SPECS[f] for f in features]

        # Track if file_list was provided (for filtering/caching logic)
        self._has_file_list = file_list is not None

        # Load file list; silently drop entries that don't exist on disk.
        if file_list is not None:
            self.graph_files = [self.graph_dir / f for f in file_list]
            self.graph_files = [f for f in self.graph_files if f.exists()]
        else:
            self.graph_files = sorted(self.graph_dir.glob("*.pt"))

        # Filter by max_notes and max_bars
        if max_notes is not None or max_bars is not None:
            self._filter_by_constraints()

        print(f"ScoreGraphMultiFeatureDataset: {len(self.graph_files)} graphs")
        print(f"\tFeatures: {features}")
        print(f"\tMax notes: {max_notes}, Max bars: {max_bars}, ID pool: {identifier_pool_size}")

    def _filter_by_constraints(self):
        """Filter graphs that exceed max_notes or max_bars, using cache when available.

        The cache file lives inside graph_dir and is keyed by the active
        constraints, so different (max_notes, max_bars) settings don't collide.
        """
        # Build cache filename based on constraints
        cache_parts = []
        if self.max_notes is not None:
            cache_parts.append(f"notes{self.max_notes}")
        if self.max_bars is not None:
            cache_parts.append(f"bars{self.max_bars}")
        cache_name = "_".join(cache_parts)
        cache_file = self.graph_dir / f".filtered_files_{cache_name}.txt"

        if cache_file.exists():
            with open(cache_file) as f:
                cached_filenames = set(f.read().strip().split('\n'))
            if self._has_file_list:
                # Intersect with provided file_list (don't replace!)
                current_filenames = {f.name for f in self.graph_files}
                valid_filenames = current_filenames & cached_filenames
                self.graph_files = [f for f in self.graph_files if f.name in valid_filenames]
            else:
                # Use cache directly - sort for consistent ordering
                self.graph_files = sorted([self.graph_dir / fn for fn in cached_filenames if (self.graph_dir / fn).exists()])
        else:
            # Build cache by checking each file (slow path: loads every graph)
            filtered_files = []
            for f in self.graph_files:
                try:
                    # NOTE(review): weights_only=False unpickles arbitrary
                    # objects — only load graphs from trusted sources.
                    data = torch.load(f, weights_only=False)
                    if self._check_constraints(data):
                        filtered_files.append(f)
                except Exception:
                    # Best-effort: unreadable/corrupt graphs are skipped.
                    continue
            # Save cache only when not using file_list; a file_list-restricted
            # scan would write a biased (incomplete) cache.
            if not self._has_file_list:
                with open(cache_file, 'w') as cf:
                    cf.write('\n'.join(f.name for f in filtered_files))
            self.graph_files = filtered_files

    def _check_constraints(self, data) -> bool:
        """Check if a loaded graph payload satisfies max_notes and max_bars.

        Returns False for malformed payloads (not a dict / missing 'graph' /
        missing 'note' node type) so they are filtered out as well.
        """
        if not isinstance(data, dict) or 'graph' not in data:
            return False
        graph = data['graph']
        if 'note' not in graph.node_types:
            return False
        note_x = graph['note'].x
        num_notes = note_x.shape[0]

        # Check max_notes constraint
        if self.max_notes is not None and num_notes > self.max_notes:
            return False

        # Check max_bars constraint (measure_idx is col 2, already 0-based)
        if self.max_bars is not None:
            bar_idx_col = note_x[:, 2]  # col_index for measure_idx
            max_bar = bar_idx_col.max().item() + 1  # +1 since measure_idx is 0-based
            if max_bar > self.max_bars:
                return False
        return True

    def _get_num_notes(self, data) -> Optional[int]:
        """Extract number of notes from graph['note'].x, or None if unavailable."""
        if isinstance(data, dict) and 'graph' in data:
            graph = data['graph']
            if 'note' in graph.node_types:
                return graph['note'].x.shape[0]
        # Fallback to num_notes key
        if 'num_notes' in data:
            return data['num_notes']
        return None

    def __len__(self):
        """Number of graphs after filtering."""
        return len(self.graph_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load a graph and extract configured features from graph['note'].x."""
        graph_path = self.graph_files[idx]
        data = torch.load(graph_path, weights_only=False)

        # Get the HeteroData graph and note features
        graph = data['graph']
        note_x = graph['note'].x  # Shape: [num_notes, 14]
        num_notes = note_x.shape[0]

        # Generate unique entity IDs for this sample
        if self.id_assignment == 'sequential':
            entity_ids = torch.arange(num_notes, dtype=torch.long)
        else:
            # Deterministic per-sample RNG; sampling without replacement
            # requires identifier_pool_size >= num_notes.
            rng = np.random.RandomState(self.seed + idx)
            entity_ids = torch.from_numpy(
                rng.choice(self.identifier_pool_size, size=num_notes, replace=False)
            ).long()

        result = {
            'entity_ids': entity_ids,
            'num_entities': num_notes,
        }
        for spec in self.feature_specs:
            result[spec.name] = self._encode_feature(
                note_x[:, spec.col_index].long(), spec
            )

        # Include full graph if requested (for HGT pre-processing)
        if self.include_graph:
            result['graph'] = graph
        return result

    @staticmethod
    def _encode_feature(values: torch.Tensor, spec: "NoteFeatureSpec") -> torch.Tensor:
        """Convert a raw feature column to token ids.

        Applies the spec's sparse remap (if any) or additive shift, then
        clamps into [0, vocab_size - 1].
        """
        if spec.remap is not None:
            # Vectorized remapping via a lookup tensor.
            max_key = max(spec.remap.keys())
            lookup = torch.zeros(max_key + 1, dtype=torch.long)
            for raw_val, token in spec.remap.items():
                lookup[raw_val] = token
            # Raw values outside [0, max_key] are clamped into the table's
            # index range; unmapped in-range values fall back to token 0.
            values = lookup[values.clamp(0, max_key)]
        elif spec.shift != 0:
            # Shift negative-valued features (e.g. key_fifths) to 0-based.
            values = values + spec.shift
        return values.clamp(0, spec.vocab_size - 1)

    def get_vocab_sizes(self) -> Dict[str, int]:
        """Get vocabulary sizes for configured features."""
        return {spec.name: spec.vocab_size for spec in self.feature_specs}
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Collate function for multi-feature dataset.

    Pads all samples to the maximum number of entities in the batch and
    builds a boolean validity mask. When graphs are present (HGT mode),
    extracts edge_dicts for each sample.

    Args:
        batch: Non-empty list of sample dicts from
            ScoreGraphMultiFeatureDataset.__getitem__.

    Returns:
        Dict with 'entity_ids' [B, N_max], 'mask' [B, N_max] (True = real
        entity), 'num_entities' [B], one padded [B, N_max] long tensor per
        feature, and 'edge_dicts' (list) when samples carry a 'graph'.
    """
    batch_size = len(batch)
    max_entities = max(sample['num_entities'] for sample in batch)

    # Feature columns are everything except the bookkeeping keys.
    special_keys = {'entity_ids', 'num_entities', 'edge_index', 'graph'}
    feature_names = [k for k in batch[0].keys() if k not in special_keys]

    # Initialize zero-padded tensors.
    entity_ids = torch.zeros(batch_size, max_entities, dtype=torch.long)
    mask = torch.zeros(batch_size, max_entities, dtype=torch.bool)
    num_entities = torch.tensor([s['num_entities'] for s in batch])
    features = {name: torch.zeros(batch_size, max_entities, dtype=torch.long)
                for name in feature_names}

    # Fill tensors; positions >= n stay at the zero padding value.
    for i, sample in enumerate(batch):
        n = sample['num_entities']
        entity_ids[i, :n] = sample['entity_ids']
        mask[i, :n] = True
        for name in feature_names:
            features[name][i, :n] = sample[name]

    result = {
        'entity_ids': entity_ids,
        'mask': mask,
        'num_entities': num_entities,
        **features,
    }

    # Handle graph data if present (for HGT): extract edge_dicts.
    # Import lazily so the project HGT module is only required (and loaded)
    # when graphs are actually in the batch — the previous unconditional
    # import made every collate depend on src.model.note_hgt.
    if 'graph' in batch[0]:
        from src.model.note_hgt import NoteHGT
        result['edge_dicts'] = [
            NoteHGT.extract_edge_dict(s['graph']) for s in batch
        ]
    return result