#!/usr/bin/env python3
"""
Inference Pipeline for LaM-SLidE Autoencoder.

This module provides functions for:
- Extracting features from score graphs (with preprocessing)
- Running inference through a trained autoencoder
- Undoing preprocessing shifts/remaps for reconstruction
- Full pipeline: graph -> model -> reconstructed features

Usage:
    from scripts.inference import reconstruct_from_graph, load_model_and_reconstruct

    # With a loaded model
    features = reconstruct_from_graph(model, graph_data)

    # Load model and reconstruct in one call
    features = load_model_and_reconstruct("checkpoint.pt", "graph.pt")
"""

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import sys

# Add app root to path
sys.path.insert(0, str(Path(__file__).resolve().parent))

import numpy as np
import torch

# =============================================================================
# Feature Transformation Specs
# =============================================================================

# Reverse mapping for ts_beat_type: token -> actual denominator
TS_BEAT_TYPE_INV = {0: 1, 1: 2, 2: 4, 3: 8, 4: 16, 5: 32}

# Forward mapping for ts_beat_type: denominator -> token
TS_BEAT_TYPE_REMAP = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4, 32: 5}

# Forward mapping for ts_beats: actual numerator -> dense token.
# (Keys are the raw time-signature numerators seen in the data; values are
# dense token ids. The inverse, token -> numerator, is TS_BEATS_INV below.)
TS_BEATS_REMAP = {
    1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9,
    11: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17,
    19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 26: 25,
    27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 33: 31, 34: 32, 35: 33,
    36: 34, 37: 35, 39: 36, 41: 37, 42: 38, 45: 39, 47: 40, 48: 41,
    54: 42, 55: 43, 56: 44, 57: 45, 58: 46, 59: 47, 60: 48, 61: 49,
    62: 50, 63: 51, 64: 52, 67: 53, 69: 54, 73: 55, 80: 56,
}
TS_BEATS_INV = {v: k for k, v in TS_BEATS_REMAP.items()}

# Column indices in graph['note'].x tensor (from datamodule.py)
COL_INDICES = {
    'grid_position': 0,   # tokenised grid position
    'micro_offset': 1,    # tokenised micro offset
    'measure_idx': 2,     # 0-based bar index
    'voice': 3,           # 0-based voice
    'pitch_step': 4,      # 0-6 (C...B)
    'pitch_alter': 5,     # 0-4 (shifted +2 in extraction)
    'pitch_octave': 6,    # raw octave (-3..11)
    'duration': 7,        # tokenised duration
    'clef': 8,            # 0-5
    'ts_beats': 9,        # raw numerator
    'ts_beat_type': 10,   # raw denominator {1,2,4,8,16,32}
    'key_fifths': 11,     # circle of fifths (-7..+7)
    'key_mode': 12,       # 0=major, 1=minor
    'staff': 13,          # 0-based global staff
}

# Shifts applied during preprocessing (from datamodule.py FEATURE_SPECS)
# Shift applied: values = values + shift -> to undo: values = values - shift
FEATURE_SHIFTS = {
    'pitch_octave': 3,  # [-3, 11] -> [0, 14], undo: -3
    'key_fifths': 7,    # [-7, +7] -> [0, 14], undo: -7
}

# Features that use remapping (sparse -> dense)
FEATURE_REMAPS = {
    'ts_beat_type': TS_BEAT_TYPE_REMAP,
    'ts_beats': TS_BEATS_REMAP,
}
FEATURE_REMAPS_INV = {
    'ts_beat_type': TS_BEAT_TYPE_INV,
    'ts_beats': TS_BEATS_INV,
}

# =============================================================================
# Feature Extraction
# =============================================================================


def extract_features_from_graph(
    graph_data: Dict,
    feature_names: List[str],
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Extract features from a score graph in the format expected by the autoencoder.

    Applies the same transformations as the datamodule (shifts, remaps).
    Also returns the original (non-model) features for reconstruction.

    Args:
        graph_data: Loaded graph data dict with 'graph' key
        feature_names: List of feature names to extract for the model
        identifier_pool_size: Size of entity ID pool
        seed: Random seed for entity ID assignment
        id_assignment: 'random' or 'sequential' (any other value falls back
            to sequential, matching the original behaviour)

    Returns:
        model_features: Dict of feature_name -> (N,) tensors (shifted/remapped)
        entity_ids: (N,) entity identifiers
        raw_features: Dict of all raw features from graph (for reconstruction)

    Raises:
        ValueError: If id_assignment is 'random' and the graph has more notes
            than identifier_pool_size (unique IDs cannot be sampled).
    """
    graph = graph_data['graph']
    note_x = graph['note'].x  # Shape: [num_notes, 14] — one column per COL_INDICES entry
    num_notes = note_x.shape[0]

    # Assign entity IDs
    if id_assignment == 'random':
        # Fail fast with a clear message instead of numpy's opaque
        # "Cannot take a larger sample than population" error.
        if num_notes > identifier_pool_size:
            raise ValueError(
                f"Cannot assign {num_notes} unique entity IDs from a pool of "
                f"{identifier_pool_size}; increase identifier_pool_size."
            )
        rng = np.random.RandomState(seed)
        ids = rng.choice(identifier_pool_size, size=num_notes, replace=False)
        entity_ids = torch.from_numpy(ids).long()
    else:  # sequential
        entity_ids = torch.arange(num_notes, dtype=torch.long)

    # Extract model features (with shifts/remaps applied)
    model_features = {}
    for feat_name in feature_names:
        col_idx = COL_INDICES[feat_name]
        values = note_x[:, col_idx].long()

        # Apply transformations (same as datamodule __getitem__)
        if feat_name in FEATURE_REMAPS:
            remap = FEATURE_REMAPS[feat_name]
            max_key = max(remap.keys())
            # Raw values absent from the remap fall through to token 0
            # (zeros init + clamp), mirroring the datamodule.
            remap_tensor = torch.zeros(max_key + 1, dtype=torch.long)
            for old_val, new_val in remap.items():
                remap_tensor[old_val] = new_val
            values = remap_tensor[values.clamp(0, max_key)]
        elif feat_name in FEATURE_SHIFTS:
            values = values + FEATURE_SHIFTS[feat_name]

        model_features[feat_name] = values

    # Extract ALL raw features for reconstruction (original values from graph)
    raw_features = {
        'pitch_step': note_x[:, COL_INDICES['pitch_step']].long(),
        'pitch_alter': note_x[:, COL_INDICES['pitch_alter']].long(),
        'pitch_octave': note_x[:, COL_INDICES['pitch_octave']].long(),
        'position_grid_token': note_x[:, COL_INDICES['grid_position']].long(),
        'position_micro_token': note_x[:, COL_INDICES['micro_offset']].long(),
        'duration_token': note_x[:, COL_INDICES['duration']].long(),
        'measure_idx': note_x[:, COL_INDICES['measure_idx']].long(),
        'voice': note_x[:, COL_INDICES['voice']].long(),
        'staff': note_x[:, COL_INDICES['staff']].long(),
        'clef': note_x[:, COL_INDICES['clef']].long(),
        'ts_beats': note_x[:, COL_INDICES['ts_beats']].long(),
        'ts_beat_type': note_x[:, COL_INDICES['ts_beat_type']].long(),
        'key_fifths': note_x[:, COL_INDICES['key_fifths']].long(),
        'key_mode': note_x[:, COL_INDICES['key_mode']].long(),
    }

    return model_features, entity_ids, raw_features
# =============================================================================
# Feature Shift/Remap Undo
# =============================================================================


def undo_feature_shifts(
    predictions: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Undo the shifts/remaps applied during preprocessing.

    Converts model predictions back to their original value ranges:
    remapped features go through the inverse lookup table, shifted
    features get the shift subtracted, everything else passes through.

    Args:
        predictions: Dict of feature_name -> (N,) predicted token indices

    Returns:
        raw_predictions: Dict with original value ranges
    """
    restored = {}
    for name, tokens in predictions.items():
        inv_map = FEATURE_REMAPS_INV.get(name)
        if inv_map is not None:
            # Inverse remap: token -> original value. Tokens outside the
            # table are clamped; unlisted slots resolve to 0.
            table_size = max(inv_map) + 1
            lookup = torch.tensor(
                [inv_map.get(tok, 0) for tok in range(table_size)],
                dtype=torch.long,
            )
            restored[name] = lookup[tokens.clamp(0, table_size - 1)]
        elif name in FEATURE_SHIFTS:
            # Undo additive shift applied during preprocessing.
            restored[name] = tokens - FEATURE_SHIFTS[name]
        else:
            # Feature was neither remapped nor shifted; pass through.
            restored[name] = tokens
    return restored
# =============================================================================
# Reconstruction Pipeline
# =============================================================================


def reconstruct_from_graph(
    model: 'LaMSLiDEAutoencoder',
    graph_data: Dict,
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Reconstruct features from a score graph using the autoencoder.

    This is the main inference pipeline:
    1. Extract features from graph (with preprocessing)
    2. Run through autoencoder (encode + decode)
    3. Get predicted tokens (argmax)
    4. Undo preprocessing shifts
    5. Return features in format expected by reconstruct_score()

    Args:
        model: Trained LaMSLiDEAutoencoder
        graph_data: Loaded graph data dict with 'graph' key
        identifier_pool_size: Size of entity ID pool
        seed: Random seed for entity ID assignment
        id_assignment: 'random' or 'sequential'
        device: Device to run inference on (defaults to the model's device)

    Returns:
        features: Dict of numpy arrays ready for reconstruct_score().
            Keys (predicted features overwrite the raw ones):
            - pitch_step, pitch_alter, pitch_octave
            - position_grid_token, position_micro_token
            - duration_token, measure_idx, voice, staff
            - clef, ts_beats, ts_beat_type, key_fifths, key_mode
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    # Get feature names from model config
    model_feature_names = [f.name for f in model.config.input_features]

    # Extract features from graph
    model_features, entity_ids, raw_features = extract_features_from_graph(
        graph_data,
        model_feature_names,
        identifier_pool_size=identifier_pool_size,
        seed=seed,
        id_assignment=id_assignment,
    )

    # Move to device and add batch dimension (batch size 1)
    model_features_batch = {
        k: v.unsqueeze(0).to(device) for k, v in model_features.items()
    }
    entity_ids_batch = entity_ids.unsqueeze(0).to(device)
    num_notes = entity_ids.shape[0]
    # All notes are real (no padding), so the mask is all-True.
    mask = torch.ones(1, num_notes, dtype=torch.bool, device=device)

    # Forward pass
    with torch.no_grad():
        logits = model(model_features_batch, entity_ids_batch, mask=mask)

    # Get predictions (argmax over the class dimension), back on CPU
    predictions = {
        name: feature_logits[0].argmax(dim=-1).cpu()  # (N,)
        for name, feature_logits in logits.items()
    }

    # Undo preprocessing shifts
    raw_predictions = undo_feature_shifts(predictions)

    # Build output features dict
    # Use reconstructed features for what the model predicts,
    # keep original features for everything else
    output_features = {}

    # Map model feature names to reconstruct_mxl.py expected keys
    feature_key_map = {
        'grid_position': 'position_grid_token',
        'micro_offset': 'position_micro_token',
        'duration': 'duration_token',
        # These keep the same name
        'pitch_step': 'pitch_step',
        'pitch_alter': 'pitch_alter',
        'pitch_octave': 'pitch_octave',
        'measure_idx': 'measure_idx',
        'voice': 'voice',
        'staff': 'staff',
        'clef': 'clef',
        'ts_beats': 'ts_beats',
        'ts_beat_type': 'ts_beat_type',
        'key_fifths': 'key_fifths',
    }

    # Start with all raw features (for non-predicted features)
    for key, tensor in raw_features.items():
        output_features[key] = tensor.numpy()

    # Override with predictions for features the model reconstructed
    for model_name, output_key in feature_key_map.items():
        if model_name in raw_predictions:
            output_features[output_key] = raw_predictions[model_name].numpy()

    return output_features
def load_model_and_reconstruct(
    checkpoint_path: Union[str, Path],
    graph_path: Union[str, Path],
    config_path: Optional[Union[str, Path]] = None,
    identifier_pool_size: Optional[int] = None,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Load a trained model and reconstruct features from a graph file.

    Convenience function for the full inference pipeline.

    Args:
        checkpoint_path: Path to saved model checkpoint
        graph_path: Path to .pt graph file
        config_path: Path to YAML config (required if checkpoint lacks 'config' key)
        identifier_pool_size: Size of entity ID pool (overrides config if provided)
        seed: Random seed for entity ID assignment
        id_assignment: 'random' or 'sequential'
        device: Device to run inference on (defaults to CUDA when available)

    Returns:
        features: Dict with reconstructed features ready for reconstruct_score()

    Raises:
        ValueError: If the checkpoint has no 'config' key and no config_path
            was supplied.
    """
    # Import here to avoid circular imports
    from src.model.autoencoder import create_autoencoder_from_dict
    from omegaconf import OmegaConf

    checkpoint_path = Path(checkpoint_path)
    graph_path = Path(graph_path)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load checkpoint
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Extract config and create model.
    # Precedence: config embedded in the checkpoint wins over config_path.
    if 'config' in checkpoint:
        config_dict = checkpoint['config']
    elif config_path is not None:
        config_path = Path(config_path)
        raw_config = OmegaConf.load(config_path)
        # Only the 'model' sub-tree of the YAML is used, resolved to plain dicts.
        config_dict = OmegaConf.to_container(raw_config.model, resolve=True)
    else:
        raise ValueError("Checkpoint must contain 'config' key or config_path must be provided")

    model = create_autoencoder_from_dict(config_dict)

    # Load weights — try the common checkpoint layouts in order:
    # wrapped under 'model_state_dict', under 'state_dict', or a bare state dict.
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # Assume checkpoint is just the state dict
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    # Load graph (same caveat as above: trusted files only)
    graph_data = torch.load(graph_path, weights_only=False)

    # Determine identifier pool size: explicit argument overrides model config
    if identifier_pool_size is None:
        identifier_pool_size = model.config.identifier_pool_size

    # Reconstruct
    features = reconstruct_from_graph(
        model,
        graph_data,
        identifier_pool_size=identifier_pool_size,
        seed=seed,
        id_assignment=id_assignment,
        device=device,
    )

    return features
# =============================================================================
# Testing
# =============================================================================

if __name__ == '__main__':
    print("Testing inference pipeline functions...")

    # Test feature shift undo: one feature per transformation kind
    # (shift, remap, pass-through) to cover every branch of undo_feature_shifts.
    print("\n1. Testing feature shift undo:")
    test_preds = {
        'pitch_octave': torch.tensor([0, 3, 10, 14]),
        'key_fifths': torch.tensor([0, 7, 14]),
        'ts_beat_type': torch.tensor([0, 1, 2, 3, 4, 5]),
        'ts_beats': torch.tensor([0, 3, 41]),  # tokens -> actual values
        'grid_position': torch.tensor([0, 16, 32]),
        'voice': torch.tensor([0, 1, 24]),
    }
    raw_preds = undo_feature_shifts(test_preds)

    print(f"\tpitch_octave: {test_preds['pitch_octave'].tolist()} -> {raw_preds['pitch_octave'].tolist()} (shift -3)")
    print(f"\tkey_fifths: {test_preds['key_fifths'].tolist()} -> {raw_preds['key_fifths'].tolist()} (shift -7)")
    print(f"\tts_beat_type: {test_preds['ts_beat_type'].tolist()} -> {raw_preds['ts_beat_type'].tolist()} (remap)")
    print(f"\tts_beats: {test_preds['ts_beats'].tolist()} -> {raw_preds['ts_beats'].tolist()} (remap)")
    print(f"\tgrid_position: {test_preds['grid_position'].tolist()} -> {raw_preds['grid_position'].tolist()} (no change)")
    print(f"\tvoice: {test_preds['voice'].tolist()} -> {raw_preds['voice'].tolist()} (no change)")

    # Verify expected values against the FEATURE_SHIFTS / FEATURE_REMAPS_INV tables
    assert raw_preds['pitch_octave'].tolist() == [-3, 0, 7, 11], "pitch_octave shift undo failed"
    assert raw_preds['key_fifths'].tolist() == [-7, 0, 7], "key_fifths shift undo failed"
    assert raw_preds['ts_beat_type'].tolist() == [1, 2, 4, 8, 16, 32], "ts_beat_type remap undo failed"
    assert raw_preds['ts_beats'].tolist() == [1, 4, 48], "ts_beats remap undo failed"
    assert raw_preds['grid_position'].tolist() == [0, 16, 32], "grid_position should be unchanged"
    assert raw_preds['voice'].tolist() == [0, 1, 24], "voice should be unchanged (already 0-based)"

    print("\t[OK] All shift tests passed!")
    print("\n[OK] All tests passed!")