Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Inference Pipeline for LaM-SLidE Autoencoder. | |
| This module provides functions for: | |
| - Extracting features from score graphs (with preprocessing) | |
| - Running inference through a trained autoencoder | |
| - Undoing preprocessing shifts/remaps for reconstruction | |
| - Full pipeline: graph -> model -> reconstructed features | |
| Usage: | |
| from scripts.inference import reconstruct_from_graph, load_model_and_reconstruct | |
| # With a loaded model | |
| features = reconstruct_from_graph(model, graph_data) | |
| # Load model and reconstruct in one call | |
| features = load_model_and_reconstruct("checkpoint.pt", "graph.pt") | |
| """ | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import sys | |
| # Add app root to path | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| import numpy as np | |
| import torch | |
# =============================================================================
# Feature Transformation Specs
# =============================================================================
# ts_beat_type: forward map (denominator -> dense token) and its inverse
# (token -> denominator). Denominators are the powers of two {1..32}.
TS_BEAT_TYPE_REMAP = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4, 32: 5}
TS_BEAT_TYPE_INV = {tok: denom for denom, tok in TS_BEAT_TYPE_REMAP.items()}
# ts_beats: numerators observed in the corpus, in ascending order. The set is
# sparse (e.g. 32, 38, 40 never occur), so tokenisation is positional.
_TS_BEATS_VALUES = (
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
    11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
    21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
    31, 33, 34, 35, 36, 37, 39, 41, 42, 45,
    47, 48, 54, 55, 56, 57, 58, 59, 60, 61,
    62, 63, 64, 67, 69, 73, 80,
)
# Forward mapping for ts_beats: actual numerator -> dense token
TS_BEATS_REMAP = {beats: tok for tok, beats in enumerate(_TS_BEATS_VALUES)}
# Reverse mapping for ts_beats: token -> actual numerator
TS_BEATS_INV = dict(enumerate(_TS_BEATS_VALUES))
# Column layout of graph['note'].x (defined by datamodule.py).
COL_INDICES = {
    'grid_position': 0,   # tokenised grid position
    'micro_offset': 1,    # tokenised micro offset
    'measure_idx': 2,     # 0-based bar index
    'voice': 3,           # 0-based voice
    'pitch_step': 4,      # 0-6 (C...B)
    'pitch_alter': 5,     # 0-4 (shifted +2 in extraction)
    'pitch_octave': 6,    # raw octave (-3..11)
    'duration': 7,        # tokenised duration
    'clef': 8,            # 0-5
    'ts_beats': 9,        # raw numerator
    'ts_beat_type': 10,   # raw denominator {1,2,4,8,16,32}
    'key_fifths': 11,     # circle of fifths (-7..+7)
    'key_mode': 12,       # 0=major, 1=minor
    'staff': 13,          # 0-based global staff
}
# Additive shifts applied during preprocessing (datamodule.py FEATURE_SPECS).
# Forward: values + shift; undo: values - shift.
FEATURE_SHIFTS = {
    'pitch_octave': 3,  # [-3, 11] -> [0, 14], undo: -3
    'key_fifths': 7,    # [-7, +7] -> [0, 14], undo: -7
}
# Features whose sparse raw values are densified via a lookup table.
FEATURE_REMAPS = {
    'ts_beat_type': TS_BEAT_TYPE_REMAP,
    'ts_beats': TS_BEATS_REMAP,
}
FEATURE_REMAPS_INV = {
    'ts_beat_type': TS_BEAT_TYPE_INV,
    'ts_beats': TS_BEATS_INV,
}
# =============================================================================
# Feature Extraction
# =============================================================================
def extract_features_from_graph(
    graph_data: Dict,
    feature_names: List[str],
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Extract per-note features from a score graph for the autoencoder.

    Applies the same transformations as the datamodule (additive shifts and
    sparse->dense remaps) to the requested model features, and also returns
    every raw (untransformed) column so the caller can reconstruct features
    the model does not predict.

    Args:
        graph_data: Loaded graph data dict with a 'graph' key.
        feature_names: Names of the features to prepare for the model.
        identifier_pool_size: Size of the entity-ID pool (random mode only).
        seed: RNG seed used when id_assignment == 'random'.
        id_assignment: 'random' (sampled without replacement) or 'sequential'.

    Returns:
        model_features: feature_name -> (N,) tensors, shifted/remapped.
        entity_ids: (N,) entity identifiers.
        raw_features: all raw columns keyed by reconstruction names.
    """
    note_x = graph_data['graph']['note'].x  # one row per note
    num_notes = note_x.shape[0]

    # Entity identifiers: a seeded draw without replacement, or simply 0..N-1.
    if id_assignment == 'random':
        rng = np.random.RandomState(seed)
        drawn = rng.choice(identifier_pool_size, size=num_notes, replace=False)
        entity_ids = torch.from_numpy(drawn).long()
    else:  # sequential
        entity_ids = torch.arange(num_notes, dtype=torch.long)

    # Model features: mirror the datamodule's __getitem__ transforms.
    model_features: Dict[str, torch.Tensor] = {}
    for name in feature_names:
        column = note_x[:, COL_INDICES[name]].long()
        remap = FEATURE_REMAPS.get(name)
        if remap is not None:
            # Dense LUT over the raw value range; values outside the map
            # (after clamping) fall through to token 0.
            top = max(remap)
            lut = torch.zeros(top + 1, dtype=torch.long)
            for raw_val, token in remap.items():
                lut[raw_val] = token
            column = lut[column.clamp(0, top)]
        elif name in FEATURE_SHIFTS:
            column = column + FEATURE_SHIFTS[name]
        model_features[name] = column

    # Raw features: untouched column reads under the names reconstruct_score()
    # expects. Order matches the historical output.
    raw_key_to_column = (
        ('pitch_step', 'pitch_step'),
        ('pitch_alter', 'pitch_alter'),
        ('pitch_octave', 'pitch_octave'),
        ('position_grid_token', 'grid_position'),
        ('position_micro_token', 'micro_offset'),
        ('duration_token', 'duration'),
        ('measure_idx', 'measure_idx'),
        ('voice', 'voice'),
        ('staff', 'staff'),
        ('clef', 'clef'),
        ('ts_beats', 'ts_beats'),
        ('ts_beat_type', 'ts_beat_type'),
        ('key_fifths', 'key_fifths'),
        ('key_mode', 'key_mode'),
    )
    raw_features = {
        out_key: note_x[:, COL_INDICES[col]].long()
        for out_key, col in raw_key_to_column
    }
    return model_features, entity_ids, raw_features
# =============================================================================
# Feature Shift/Remap Undo
# =============================================================================
def undo_feature_shifts(
    predictions: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Map model predictions back to their original (pre-processing) value ranges.

    Remapped features go through the inverse lookup table; shifted features
    have the shift subtracted; everything else passes through unchanged.

    Args:
        predictions: feature_name -> (N,) predicted token indices.

    Returns:
        feature_name -> (N,) tensors in the original value ranges.
    """
    restored: Dict[str, torch.Tensor] = {}
    for name, tokens in predictions.items():
        inverse = FEATURE_REMAPS_INV.get(name)
        if inverse is not None:
            # Inverse LUT: token -> original value; out-of-range tokens are
            # clamped into the table.
            size = max(inverse) + 1
            lut = torch.zeros(size, dtype=torch.long)
            for token, actual in inverse.items():
                lut[token] = actual
            restored[name] = lut[tokens.clamp(0, size - 1)]
        elif name in FEATURE_SHIFTS:
            restored[name] = tokens - FEATURE_SHIFTS[name]
        else:
            restored[name] = tokens
    return restored
| # ============================================================================= | |
| # Reconstruction Pipeline | |
| # ============================================================================= | |
def reconstruct_from_graph(
    model: 'LaMSLiDEAutoencoder',
    graph_data: Dict,
    identifier_pool_size: int = 1600,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Run the full autoencoder inference pipeline on a score graph.

    Pipeline:
      1. Extract features from the graph (with preprocessing).
      2. Forward pass through the autoencoder (encode + decode).
      3. Argmax the logits into token predictions.
      4. Undo the preprocessing shifts/remaps.
      5. Merge predictions over the raw graph features.

    Args:
        model: Trained LaMSLiDEAutoencoder.
        graph_data: Loaded graph data dict with a 'graph' key.
        identifier_pool_size: Size of the entity-ID pool.
        seed: Random seed for entity-ID assignment.
        id_assignment: 'random' or 'sequential'.
        device: Device to run inference on; defaults to the model's device.

    Returns:
        Dict of numpy arrays ready for reconstruct_score(): raw graph values
        for features the model does not predict, overridden by reconstructed
        values where it does. Keys produced: pitch_step, pitch_alter,
        pitch_octave, position_grid_token, position_micro_token,
        duration_token, measure_idx, voice, staff, clef, ts_beats,
        ts_beat_type, key_fifths, key_mode.
        NOTE(review): earlier docs mentioned 'bar_idx'/'key_signature' keys;
        the code emits 'measure_idx'/'key_fifths' — confirm against
        reconstruct_mxl.py.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()

    # Features the model was trained on, straight from its config.
    wanted = [f.name for f in model.config.input_features]

    model_feats, ent_ids, raw_feats = extract_features_from_graph(
        graph_data,
        wanted,
        identifier_pool_size=identifier_pool_size,
        seed=seed,
        id_assignment=id_assignment,
    )

    # Batch of one: add a leading dim and move everything to the device.
    batched = {k: v.unsqueeze(0).to(device) for k, v in model_feats.items()}
    ids_batched = ent_ids.unsqueeze(0).to(device)
    mask = torch.ones(1, ent_ids.shape[0], dtype=torch.bool, device=device)

    with torch.no_grad():
        logits = model(batched, ids_batched, mask=mask)

    # Greedy decode: per-feature argmax, back on CPU as (N,) tensors.
    predictions = {name: out[0].argmax(dim=-1).cpu() for name, out in logits.items()}
    restored = undo_feature_shifts(predictions)

    # Model feature name -> key expected by reconstruct_mxl.py. Predicted
    # features not listed here (e.g. key_mode) keep their raw graph values.
    key_map = {
        'grid_position': 'position_grid_token',
        'micro_offset': 'position_micro_token',
        'duration': 'duration_token',
        # Identity renames.
        'pitch_step': 'pitch_step',
        'pitch_alter': 'pitch_alter',
        'pitch_octave': 'pitch_octave',
        'measure_idx': 'measure_idx',
        'voice': 'voice',
        'staff': 'staff',
        'clef': 'clef',
        'ts_beats': 'ts_beats',
        'ts_beat_type': 'ts_beat_type',
        'key_fifths': 'key_fifths',
    }

    # Baseline: every raw feature; then overwrite with model reconstructions.
    output_features = {key: tensor.numpy() for key, tensor in raw_feats.items()}
    for model_name, output_key in key_map.items():
        predicted = restored.get(model_name)
        if predicted is not None:
            output_features[output_key] = predicted.numpy()
    return output_features
def load_model_and_reconstruct(
    checkpoint_path: Union[str, Path],
    graph_path: Union[str, Path],
    config_path: Optional[Union[str, Path]] = None,
    identifier_pool_size: Optional[int] = None,
    seed: int = 42,
    id_assignment: str = 'sequential',
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, np.ndarray]:
    """
    Load a trained model from disk and reconstruct features from a graph file.

    Convenience wrapper around reconstruct_from_graph() that handles
    checkpoint loading, config resolution, and weight restoration.

    Args:
        checkpoint_path: Path to the saved model checkpoint.
        graph_path: Path to the .pt graph file.
        config_path: YAML config path (required if the checkpoint has no
            'config' key).
        identifier_pool_size: Entity-ID pool size; overrides the model config
            when provided.
        seed: Random seed for entity-ID assignment.
        id_assignment: 'random' or 'sequential'.
        device: Device to run inference on; defaults to CUDA when available.

    Returns:
        Dict of reconstructed features ready for reconstruct_score().

    Raises:
        ValueError: if neither the checkpoint nor config_path supplies a config.
    """
    # Deferred imports: keeps module import light and avoids circular imports.
    from src.model.autoencoder import create_autoencoder_from_dict
    from omegaconf import OmegaConf

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(Path(checkpoint_path), map_location=device, weights_only=False)

    # Model config: prefer the one baked into the checkpoint, else the YAML.
    if 'config' in ckpt:
        cfg = ckpt['config']
    elif config_path is not None:
        loaded = OmegaConf.load(Path(config_path))
        cfg = OmegaConf.to_container(loaded.model, resolve=True)
    else:
        raise ValueError("Checkpoint must contain 'config' key or config_path must be provided")

    model = create_autoencoder_from_dict(cfg)

    # Weights may live under either wrapper key, or the checkpoint itself
    # may be a bare state dict.
    for weights_key in ('model_state_dict', 'state_dict'):
        if weights_key in ckpt:
            model.load_state_dict(ckpt[weights_key])
            break
    else:
        model.load_state_dict(ckpt)

    model = model.to(device)
    model.eval()

    graph_data = torch.load(Path(graph_path), weights_only=False)

    pool = identifier_pool_size
    if pool is None:
        pool = model.config.identifier_pool_size

    return reconstruct_from_graph(
        model,
        graph_data,
        identifier_pool_size=pool,
        seed=seed,
        id_assignment=id_assignment,
        device=device,
    )
| # ============================================================================= | |
| # Testing | |
| # ============================================================================= | |
if __name__ == '__main__':
    # Lightweight smoke test for the shift/remap undo logic.
    print("Testing inference pipeline functions...")

    print("\n1. Testing feature shift undo:")
    sample = {
        'pitch_octave': torch.tensor([0, 3, 10, 14]),
        'key_fifths': torch.tensor([0, 7, 14]),
        'ts_beat_type': torch.tensor([0, 1, 2, 3, 4, 5]),
        'ts_beats': torch.tensor([0, 3, 41]),  # tokens -> actual values
        'grid_position': torch.tensor([0, 16, 32]),
        'voice': torch.tensor([0, 1, 24]),
    }
    undone = undo_feature_shifts(sample)

    # feature -> label describing the transform applied.
    labels = {
        'pitch_octave': '(shift -3)',
        'key_fifths': '(shift -7)',
        'ts_beat_type': '(remap)',
        'ts_beats': '(remap)',
        'grid_position': '(no change)',
        'voice': '(no change)',
    }
    for feat, label in labels.items():
        print(f"\t{feat}: {sample[feat].tolist()} -> {undone[feat].tolist()} {label}")

    # Verify expected values.
    assert undone['pitch_octave'].tolist() == [-3, 0, 7, 11], "pitch_octave shift undo failed"
    assert undone['key_fifths'].tolist() == [-7, 0, 7], "key_fifths shift undo failed"
    assert undone['ts_beat_type'].tolist() == [1, 2, 4, 8, 16, 32], "ts_beat_type remap undo failed"
    assert undone['ts_beats'].tolist() == [1, 4, 48], "ts_beats remap undo failed"
    assert undone['grid_position'].tolist() == [0, 16, 32], "grid_position should be unchanged"
    assert undone['voice'].tolist() == [0, 1, 24], "voice should be unchanged (already 0-based)"
    print("\t[OK] All shift tests passed!")
    print("\n[OK] All tests passed!")