score-ae / inference.py
hroth's picture
Upload 90 files
b57c46e verified
raw
history blame
16.9 kB
#!/usr/bin/env python3
"""
Inference Pipeline for LaM-SLidE Autoencoder.
This module provides functions for:
- Extracting features from score graphs (with preprocessing)
- Running inference through a trained autoencoder
- Undoing preprocessing shifts/remaps for reconstruction
- Full pipeline: graph -> model -> reconstructed features
Usage:
from scripts.inference import reconstruct_from_graph, load_model_and_reconstruct
# With a loaded model
features = reconstruct_from_graph(model, graph_data)
# Load model and reconstruct in one call
features = load_model_and_reconstruct("checkpoint.pt", "graph.pt")
"""
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import sys
# Add app root to path
sys.path.insert(0, str(Path(__file__).resolve().parent))
import numpy as np
import torch
# =============================================================================
# Feature Transformation Specs
# =============================================================================
# ts_beat_type denominators occur as the sparse set {1, 2, 4, 8, 16, 32};
# the model sees dense tokens 0-5.  Both directions are derived from one
# tuple so the forward and inverse maps can never drift apart.
_TS_DENOMINATORS = (1, 2, 4, 8, 16, 32)

# Forward mapping for ts_beat_type: denominator -> token
TS_BEAT_TYPE_REMAP = {denom: token for token, denom in enumerate(_TS_DENOMINATORS)}
# Reverse mapping for ts_beat_type: token -> actual denominator
TS_BEAT_TYPE_INV = dict(enumerate(_TS_DENOMINATORS))

# ts_beats numerators observed in the corpus, in ascending order.  Note the
# gaps (e.g. 32 never occurs); a numerator's token is simply its position
# in this list.
_TS_NUMERATORS = list(range(1, 32)) + [
    33, 34, 35, 36, 37, 39, 41, 42, 45, 47, 48,
    54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 67, 69, 73, 80,
]

# Forward mapping for ts_beats: numerator -> token
TS_BEATS_REMAP = {num: token for token, num in enumerate(_TS_NUMERATORS)}
# Reverse mapping for ts_beats: token -> actual numerator
TS_BEATS_INV = {token: num for num, token in TS_BEATS_REMAP.items()}

# Column layout of graph['note'].x (from datamodule.py).  Order matters:
# each name's position in this tuple is its column in the feature tensor.
_NOTE_COLUMNS = (
    'grid_position',   # tokenised grid position
    'micro_offset',    # tokenised micro offset
    'measure_idx',     # 0-based bar index
    'voice',           # 0-based voice
    'pitch_step',      # 0-6 (C...B)
    'pitch_alter',     # 0-4 (shifted +2 in extraction)
    'pitch_octave',    # raw octave (-3..11)
    'duration',        # tokenised duration
    'clef',            # 0-5
    'ts_beats',        # raw numerator
    'ts_beat_type',    # raw denominator {1,2,4,8,16,32}
    'key_fifths',      # circle of fifths (-7..+7)
    'key_mode',        # 0=major, 1=minor
    'staff',           # 0-based global staff
)
COL_INDICES = {name: idx for idx, name in enumerate(_NOTE_COLUMNS)}

# Shifts applied during preprocessing (from datamodule.py FEATURE_SPECS).
# Shift applied: values = values + shift -> to undo: values = values - shift
FEATURE_SHIFTS = {
    'pitch_octave': 3,  # [-3, 11] -> [0, 14], undo: -3
    'key_fifths': 7,    # [-7, +7] -> [0, 14], undo: -7
}

# Features that use remapping (sparse -> dense), plus the inverse direction.
FEATURE_REMAPS = {
    'ts_beat_type': TS_BEAT_TYPE_REMAP,
    'ts_beats': TS_BEATS_REMAP,
}
FEATURE_REMAPS_INV = {
    'ts_beat_type': TS_BEAT_TYPE_INV,
    'ts_beats': TS_BEATS_INV,
}
# =============================================================================
# Feature Extraction
# =============================================================================
def extract_features_from_graph(
    graph_data: Dict,
    feature_names: List[str],
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Pull per-note features out of a score graph for the autoencoder.

    Applies the same transformations as the datamodule (sparse->dense remaps
    and value shifts), and additionally returns the untouched raw columns so
    callers can reconstruct a score without re-reading the graph.

    Args:
        graph_data: Loaded graph dict; must contain a 'graph' key whose
            'note' node store has an ``x`` feature tensor.  Columns up to
            index 13 are read, so the tensor must be at least 14 columns wide.
        feature_names: Feature names (keys of COL_INDICES) to extract for
            the model.
        identifier_pool_size: Size of the entity-ID pool.  Only used when
            ``id_assignment='random'``; must be >= the number of notes since
            IDs are drawn without replacement.
        seed: RNG seed for random ID assignment.
        id_assignment: 'random' or 'sequential' (any other value falls back
            to sequential).
    Returns:
        model_features: feature_name -> (N,) long tensor, shifted/remapped.
        entity_ids: (N,) long tensor of entity identifiers.
        raw_features: every raw column, keyed for reconstruction.
    """
    note_x = graph_data['graph']['note'].x
    n = note_x.shape[0]

    # Entity IDs: either a random draw from the pool or plain 0..N-1.
    if id_assignment == 'random':
        rng = np.random.RandomState(seed)
        drawn = rng.choice(identifier_pool_size, size=n, replace=False)
        entity_ids = torch.from_numpy(drawn).long()
    else:
        entity_ids = torch.arange(n, dtype=torch.long)

    # Model features with the datamodule's __getitem__ transforms applied.
    model_features: Dict[str, torch.Tensor] = {}
    for name in feature_names:
        col = note_x[:, COL_INDICES[name]].long()
        remap = FEATURE_REMAPS.get(name)
        if remap is not None:
            # Build a dense lookup table; values absent from the remap
            # (and anything clamped into range) fall through to token 0.
            hi = max(remap)
            table = torch.zeros(hi + 1, dtype=torch.long)
            for src, dst in remap.items():
                table[src] = dst
            col = table[col.clamp(0, hi)]
        elif name in FEATURE_SHIFTS:
            col = col + FEATURE_SHIFTS[name]
        model_features[name] = col

    # Raw (untransformed) copies of every column, renamed where the
    # reconstruction code expects different keys than the datamodule uses.
    raw_spec = [
        ('pitch_step', 'pitch_step'),
        ('pitch_alter', 'pitch_alter'),
        ('pitch_octave', 'pitch_octave'),
        ('position_grid_token', 'grid_position'),
        ('position_micro_token', 'micro_offset'),
        ('duration_token', 'duration'),
        ('measure_idx', 'measure_idx'),
        ('voice', 'voice'),
        ('staff', 'staff'),
        ('clef', 'clef'),
        ('ts_beats', 'ts_beats'),
        ('ts_beat_type', 'ts_beat_type'),
        ('key_fifths', 'key_fifths'),
        ('key_mode', 'key_mode'),
    ]
    raw_features = {
        out_key: note_x[:, COL_INDICES[col_name]].long()
        for out_key, col_name in raw_spec
    }
    return model_features, entity_ids, raw_features
# =============================================================================
# Feature Shift/Remap Undo
# =============================================================================
def undo_feature_shifts(
    predictions: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Map model predictions back to their original value ranges.

    Inverts the preprocessing transforms: remapped features go through the
    inverse token table, shifted features get the shift subtracted, and
    everything else passes through unchanged.

    Args:
        predictions: feature_name -> (N,) tensor of predicted token indices.
    Returns:
        Dict with the same keys, values restored to their raw ranges.
    """
    restored: Dict[str, torch.Tensor] = {}
    for name, tokens in predictions.items():
        inv_map = FEATURE_REMAPS_INV.get(name)
        if inv_map is not None:
            # Dense token -> original-value lookup; tokens are clamped into
            # the table's range before indexing.
            size = max(inv_map) + 1
            table = torch.zeros(size, dtype=torch.long)
            for token, actual in inv_map.items():
                table[token] = actual
            restored[name] = table[tokens.clamp(0, size - 1)]
        elif name in FEATURE_SHIFTS:
            restored[name] = tokens - FEATURE_SHIFTS[name]
        else:
            restored[name] = tokens
    return restored
# =============================================================================
# Reconstruction Pipeline
# =============================================================================
def reconstruct_from_graph(
    model: 'LaMSLiDEAutoencoder',
    graph_data: Dict,
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Reconstruct note features from a score graph via the autoencoder.

    Pipeline: extract preprocessed features -> run the model -> take the
    argmax of each feature head -> undo the preprocessing shifts/remaps ->
    merge predictions over the raw features.

    Args:
        model: Trained LaMSLiDEAutoencoder.
        graph_data: Loaded graph data dict with a 'graph' key.
        identifier_pool_size: Size of the entity-ID pool.
        seed: Random seed for entity-ID assignment.
        id_assignment: 'random' or 'sequential'.
        device: Device to run inference on; defaults to the model's device.
    Returns:
        Dict of numpy arrays.  Keys produced here are: pitch_step,
        pitch_alter, pitch_octave, position_grid_token, position_micro_token,
        duration_token, measure_idx, voice, staff, clef, ts_beats,
        ts_beat_type, key_fifths, key_mode.
        NOTE(review): an earlier description mentioned 'bar_idx' and
        'key_signature', but this function emits 'measure_idx' and
        'key_fifths' — confirm which names reconstruct_score() expects.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # The model config decides which features are encoded/decoded.
    wanted = [f.name for f in model.config.input_features]
    model_features, entity_ids, raw_features = extract_features_from_graph(
        graph_data,
        wanted,
        identifier_pool_size=identifier_pool_size,
        seed=seed,
        id_assignment=id_assignment,
    )

    # Batch of one: add a leading dim and move everything onto the device.
    batched = {k: v.unsqueeze(0).to(device) for k, v in model_features.items()}
    ids_batched = entity_ids.unsqueeze(0).to(device)
    mask = torch.ones(1, entity_ids.shape[0], dtype=torch.bool, device=device)

    with torch.no_grad():
        logits = model(batched, ids_batched, mask=mask)

    # Per-feature argmax over the vocabulary dim, back on CPU as (N,) tensors.
    predictions = {
        name: head[0].argmax(dim=-1).cpu() for name, head in logits.items()
    }
    raw_predictions = undo_feature_shifts(predictions)

    # Model-side names -> keys expected by reconstruct_mxl.py.
    feature_key_map = {
        'grid_position': 'position_grid_token',
        'micro_offset': 'position_micro_token',
        'duration': 'duration_token',
        # These keep the same name
        'pitch_step': 'pitch_step',
        'pitch_alter': 'pitch_alter',
        'pitch_octave': 'pitch_octave',
        'measure_idx': 'measure_idx',
        'voice': 'voice',
        'staff': 'staff',
        'clef': 'clef',
        'ts_beats': 'ts_beats',
        'ts_beat_type': 'ts_beat_type',
        'key_fifths': 'key_fifths',
    }

    # Start from the raw graph values, then overwrite whatever the model
    # actually reconstructed.
    output_features = {key: tensor.numpy() for key, tensor in raw_features.items()}
    for model_name, output_key in feature_key_map.items():
        if model_name in raw_predictions:
            output_features[output_key] = raw_predictions[model_name].numpy()
    return output_features
def load_model_and_reconstruct(
    checkpoint_path: Union[str, Path],
    graph_path: Union[str, Path],
    config_path: Optional[Union[str, Path]] = None,
    identifier_pool_size: Optional[int] = None,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Load a trained model from disk and reconstruct features from a graph file.

    Convenience wrapper around :func:`reconstruct_from_graph`.

    Args:
        checkpoint_path: Path to the saved model checkpoint.
        graph_path: Path to the .pt graph file.
        config_path: YAML config path; required only when the checkpoint
            lacks a 'config' key.
        identifier_pool_size: Entity-ID pool size; overrides the model
            config when given.
        seed: Random seed for entity-ID assignment.
        id_assignment: 'random' or 'sequential'.
        device: Inference device; defaults to CUDA when available.
    Returns:
        Dict of reconstructed feature arrays (see reconstruct_from_graph).
    Raises:
        ValueError: if neither the checkpoint nor config_path provides a config.
    """
    # Imported lazily to avoid circular imports at module load time.
    from src.model.autoencoder import create_autoencoder_from_dict
    from omegaconf import OmegaConf

    checkpoint_path = Path(checkpoint_path)
    graph_path = Path(graph_path)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Model config: prefer the one embedded in the checkpoint.
    if 'config' in checkpoint:
        config_dict = checkpoint['config']
    elif config_path is not None:
        loaded_cfg = OmegaConf.load(Path(config_path))
        config_dict = OmegaConf.to_container(loaded_cfg.model, resolve=True)
    else:
        raise ValueError("Checkpoint must contain 'config' key or config_path must be provided")

    model = create_autoencoder_from_dict(config_dict)

    # Weights may live under either key, or the checkpoint may be a bare
    # state dict.
    for state_key in ('model_state_dict', 'state_dict'):
        if state_key in checkpoint:
            model.load_state_dict(checkpoint[state_key])
            break
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    graph_data = torch.load(graph_path, weights_only=False)

    if identifier_pool_size is None:
        identifier_pool_size = model.config.identifier_pool_size

    return reconstruct_from_graph(
        model,
        graph_data,
        identifier_pool_size=identifier_pool_size,
        seed=seed,
        id_assignment=id_assignment,
        device=device,
    )
# =============================================================================
# Testing
# =============================================================================
if __name__ == '__main__':
    # Smoke-test the preprocessing-undo logic with hand-picked token values.
    print("Testing inference pipeline functions...")

    print("\n1. Testing feature shift undo:")
    sample_tokens = {
        'pitch_octave': torch.tensor([0, 3, 10, 14]),
        'key_fifths': torch.tensor([0, 7, 14]),
        'ts_beat_type': torch.tensor([0, 1, 2, 3, 4, 5]),
        'ts_beats': torch.tensor([0, 3, 41]),  # tokens -> actual values
        'grid_position': torch.tensor([0, 16, 32]),
        'voice': torch.tensor([0, 1, 24]),
    }
    restored = undo_feature_shifts(sample_tokens)

    print(f"\tpitch_octave: {sample_tokens['pitch_octave'].tolist()} -> {restored['pitch_octave'].tolist()} (shift -3)")
    print(f"\tkey_fifths: {sample_tokens['key_fifths'].tolist()} -> {restored['key_fifths'].tolist()} (shift -7)")
    print(f"\tts_beat_type: {sample_tokens['ts_beat_type'].tolist()} -> {restored['ts_beat_type'].tolist()} (remap)")
    print(f"\tts_beats: {sample_tokens['ts_beats'].tolist()} -> {restored['ts_beats'].tolist()} (remap)")
    print(f"\tgrid_position: {sample_tokens['grid_position'].tolist()} -> {restored['grid_position'].tolist()} (no change)")
    print(f"\tvoice: {sample_tokens['voice'].tolist()} -> {restored['voice'].tolist()} (no change)")

    # Verify against hand-computed expected values.
    assert restored['pitch_octave'].tolist() == [-3, 0, 7, 11], "pitch_octave shift undo failed"
    assert restored['key_fifths'].tolist() == [-7, 0, 7], "key_fifths shift undo failed"
    assert restored['ts_beat_type'].tolist() == [1, 2, 4, 8, 16, 32], "ts_beat_type remap undo failed"
    assert restored['ts_beats'].tolist() == [1, 4, 48], "ts_beats remap undo failed"
    assert restored['grid_position'].tolist() == [0, 16, 32], "grid_position should be unchanged"
    assert restored['voice'].tolist() == [0, 1, 24], "voice should be unchanged (already 0-based)"
    print("\t[OK] All shift tests passed!")
    print("\n[OK] All tests passed!")