Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio demo for LaM-SLidE Music Score Autoencoder. | |
| Upload a MusicXML file (.mxl / .musicxml) or pick an example, choose a | |
| model, and get back the reconstructed MusicXML plus per-feature accuracy | |
| and a visual comparison of original vs. reconstructed scores. | |
| """ | |
| import shutil | |
| import sys | |
| import warnings | |
| from pathlib import Path | |
| # Add app root to path | |
| sys.path.insert(0, str(Path(__file__).resolve().parent)) | |
| warnings.filterwarnings("ignore") | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import verovio | |
| from omegaconf import OmegaConf | |
| VEROVIO_DATA_DIR = os.path.join(os.path.dirname(verovio.__file__), "data") | |
| from src.model.autoencoder import create_autoencoder_from_dict | |
| from inference import ( | |
| extract_features_from_graph, | |
| reconstruct_from_graph, | |
| undo_feature_shifts, | |
| ) | |
| from reconstruct_mxl import ( | |
| reconstruct_score, | |
| load_duration_vocabulary, | |
| load_position_vocabulary, | |
| ) | |
| from convert_mxl import mxl_to_graph_data, load_vocab_forward | |
| # ============================================================================= | |
| # Paths | |
| # ============================================================================= | |
# All resource directories are resolved relative to the directory of this file.
APP_DIR = Path(__file__).resolve().parent
VOCAB_DIR = APP_DIR / "vocabs"
MODELS_DIR = APP_DIR / "models"
EXAMPLES_DIR = APP_DIR / "examples"
# Scratch directory for generated files; created as an import-time side effect.
TMP_DIR = APP_DIR / "tmp"
TMP_DIR.mkdir(exist_ok=True)
# Available models: display label -> model key
MODELS = {
    "Wide Large Factorized (best model)": "wide_large_factorized",
    "HGT Wide Factorized (graph-aware)": "hgt_wide_factorized",
}
| # Example .pt graph files (each may have a paired .mxl for original rendering) | |
| # Sort by note count (extracted from filename pattern *_<N>notes_*) so the | |
| # dropdown goes from smallest to largest score. | |
def _note_count_key(p: Path) -> int:
    """Sort key: the note count embedded in the file stem as ``_<N>notes``.

    Returns 0 when the stem carries no such marker.
    """
    import re

    found = re.search(r"_(\d+)notes", p.stem)
    if found is None:
        return 0
    return int(found[1])
def _example_display_name(p: Path) -> str:
    """Build the dropdown label for an example file.

    A stem such as ``schubert_an_die_laute_64notes_QmNcR2oAkq`` becomes
    ``'64 notes — Schubert An Die Laute (QmNcR2oAkq)'``. A stem that does not
    contain ``_<N>notes_<suffix>`` is returned unchanged.
    """
    import re

    stem = p.stem
    found = re.search(r"_(\d+)notes_(.+)$", stem)
    if found is None:
        return stem
    note_count, suffix = found.groups()
    # Whatever precedes the match is the underscore-separated composer/title.
    words = stem[:found.start()].split("_")
    pretty = " ".join(words).title()
    return f"{note_count} notes \u2014 {pretty} ({suffix})"
# Example graphs bundled with the app, ordered by ascending note count.
# A missing examples directory simply yields no examples.
if EXAMPLES_DIR.exists():
    EXAMPLES = sorted(
        (entry for entry in EXAMPLES_DIR.iterdir() if entry.suffix == ".pt"),
        key=_note_count_key,
    )
else:
    EXAMPLES = []
# Dropdown label -> filename
_EXAMPLE_DISPLAY = {
    _example_display_name(example): example.name for example in EXAMPLES
}
# Filename -> dropdown label (inverse of the mapping above)
_EXAMPLE_LOOKUP = {
    filename: label for label, filename in _EXAMPLE_DISPLAY.items()
}
# File extensions recognised as MusicXML
MXL_EXTENSIONS = {".mxl", ".musicxml", ".xml"}
# Short display names for the accuracy table (feature name -> column header).
# Features missing from this mapping are shown under their full name.
FEATURE_SHORT_NAMES = {
    "grid_position": "grid_pos",
    "micro_offset": "micro",
    "measure_idx": "bar",
    "voice": "voice",
    "pitch_step": "step",
    "pitch_alter": "alter",
    "pitch_octave": "oct",
    "duration": "dur",
    "clef": "clef",
    "ts_beats": "ts_b",
    "ts_beat_type": "ts_bt",
    "key_fifths": "key",
    "staff": "staff",
}
| # ============================================================================= | |
| # Load vocabs once | |
| # ============================================================================= | |
# Vocabularies handed to reconstruct_score when rebuilding MusicXML.
# All six are read from the JSON files in VOCAB_DIR once, at import time.
duration_vocab_inv = load_duration_vocabulary(VOCAB_DIR / "duration_vocab.json")
grid_vocab_inv = load_position_vocabulary(VOCAB_DIR / "grid_vocab.json")
micro_vocab_inv = load_position_vocabulary(VOCAB_DIR / "micro_vocab.json")
# Forward vocabs for MXL → graph conversion (loaded from training data)
duration_vocab_fwd = load_vocab_forward(VOCAB_DIR / "duration_vocab.json")
grid_vocab_fwd = load_vocab_forward(VOCAB_DIR / "grid_vocab.json")
micro_vocab_fwd = load_vocab_forward(VOCAB_DIR / "micro_vocab.json")
| # ============================================================================= | |
| # Model cache | |
| # ============================================================================= | |
# Loaded models keyed by model key, so each checkpoint is read at most once.
_model_cache = {}


def get_model(model_key: str):
    """Return the autoencoder for ``model_key``, loading it on first use.

    The model directory ``MODELS_DIR / model_key`` must contain a ``*.yaml``
    config and a ``best_model.pt`` checkpoint. The model is put in eval mode
    on the CPU and cached for subsequent calls.

    Raises
    ------
    FileNotFoundError
        If the model directory contains no ``*.yaml`` config file.
    """
    cached = _model_cache.get(model_key)
    if cached is not None:
        return cached

    model_dir = MODELS_DIR / model_key
    # Sorting makes the choice deterministic should several configs exist.
    config_paths = sorted(model_dir.glob("*.yaml"))
    if not config_paths:
        raise FileNotFoundError(f"No *.yaml config found in {model_dir}")
    checkpoint_path = model_dir / "best_model.pt"

    cfg = OmegaConf.load(config_paths[0])
    config_dict = OmegaConf.to_container(cfg.model, resolve=True)
    # NOTE: weights_only=False unpickles arbitrary objects. That is only
    # acceptable here because the checkpoint is read from the app's own
    # models directory, never from user input.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    model = create_autoencoder_from_dict(config_dict)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    model.cpu()

    _model_cache[model_key] = model
    return model
| # ============================================================================= | |
| # Score rendering with verovio | |
| # ============================================================================= | |
def render_score_to_svg_pages(
    mxl_path: str, prefix: str,
    page_width: int = 2100, scale: int = 35,
) -> list[str]:
    """Render a MusicXML file to one SVG file per page.

    Parameters
    ----------
    mxl_path
        Path of the MusicXML file to engrave.
    prefix
        Path prefix of the output files; page ``k`` is written to
        ``f"{prefix}_p{k}.svg"``.
    page_width, scale
        Passed to Verovio as the ``pageWidth`` and ``scale`` options.

    Returns
    -------
    List of the written SVG paths, in page order.

    Raises
    ------
    ValueError
        If Verovio reports that it could not load ``mxl_path``.
    """
    tk = verovio.toolkit(False)
    tk.setResourcePath(VEROVIO_DATA_DIR)
    # loadFile signals failure through its boolean result; without this check
    # an unreadable file would silently produce an empty page list.
    if not tk.loadFile(mxl_path):
        raise ValueError(f"Verovio could not load {mxl_path}")
    tk.setOptions({
        "pageWidth": page_width,
        "scale": scale,
        "adjustPageHeight": True,
        "footer": "none",
        "header": "none",
    })
    # The options were set after loading, so lay the document out again.
    tk.redoLayout()

    paths = []
    for page in range(1, tk.getPageCount() + 1):  # Verovio pages are 1-based
        svg_path = f"{prefix}_p{page}.svg"
        Path(svg_path).write_text(tk.renderToSVG(page), encoding="utf-8")
        paths.append(svg_path)
    return paths
| # ============================================================================= | |
| # Core logic | |
| # ============================================================================= | |
# Model feature name -> key under which reconstruct_score expects that feature.
FEATURE_KEY_MAP = {
    "grid_position": "position_grid_token",
    "micro_offset": "position_micro_token",
    "duration": "duration_token",
    "pitch_step": "pitch_step",
    "pitch_alter": "pitch_alter",
    "pitch_octave": "pitch_octave",
    "measure_idx": "measure_idx",
    "voice": "voice",
    "staff": "staff",
    "clef": "clef",
    "ts_beats": "ts_beats",
    "ts_beat_type": "ts_beat_type",
    "key_fifths": "key_fifths",
}


def build_recon_features(gt_raw: dict, raw_predictions: dict) -> dict:
    """Build feature dict for reconstruct_score from raw ground-truth and predictions.

    Every ground-truth entry is copied first; each predicted feature listed in
    ``FEATURE_KEY_MAP`` then overrides the entry under its mapped output key.
    Predictions whose name is not in ``FEATURE_KEY_MAP`` are ignored.

    Tensors are converted to NumPy arrays; any other value is passed through
    unchanged, for ground truth and predictions alike.
    """
    def _as_array(value):
        # detach/cpu keep the conversion valid for any tensor, not only for
        # CPU tensors without gradient tracking.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    recon_features = {key: _as_array(value) for key, value in gt_raw.items()}
    for model_key, output_key in FEATURE_KEY_MAP.items():
        if model_key in raw_predictions:
            recon_features[output_key] = _as_array(raw_predictions[model_key])
    return recon_features
def score_to_mxl(score, path: Path) -> str:
    """Serialise ``score`` as MusicXML at ``path`` and return ``str(path)``.

    Warnings raised while writing are suppressed.
    """
    target = str(path)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        score.write("musicxml", fp=target)
    return target
def compute_per_feature_accuracy(
    predictions: dict, ground_truth: dict, feature_names: list
) -> dict:
    """Return the fraction of matching tokens for each feature.

    Only names from ``feature_names`` that occur in both ``predictions`` and
    ``ground_truth`` are scored. Values that are not already tensors are
    wrapped with ``torch.tensor`` before the element-wise comparison.
    """
    def _tensor(value):
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value)

    scored = (
        name for name in feature_names
        if name in predictions and name in ground_truth
    )
    return {
        name: (_tensor(predictions[name]) == _tensor(ground_truth[name]))
        .float().mean().item()
        for name in scored
    }
def _joint_accuracy(predictions: dict, targets: dict, names) -> float:
    """Fraction of notes for which every feature in ``names`` is predicted correctly."""
    all_match = None
    for name in names:
        if name in predictions and name in targets:
            match = predictions[name] == targets[name]
            all_match = match if all_match is None else (all_match & match)
    return all_match.float().mean().item() if all_match is not None else 0.0


def _accuracy_table_md(accs: dict, feature_names, core_joint: float, all_joint: float) -> str:
    """Format per-feature and joint accuracies as a one-row Markdown table."""
    shown = [n for n in feature_names if n in accs]
    header_cells = " | ".join(FEATURE_SHORT_NAMES.get(n, n) for n in shown)
    value_cells = " | ".join(f"{accs[n]*100:.1f}%" for n in shown)
    return (
        f"| {header_cells} | **core** | **all** |\n"
        f"|{'---|' * (len(shown) + 2)}\n"
        f"| {value_cells} | **{core_joint*100:.1f}%** | **{all_joint*100:.1f}%** |"
    )


def run_reconstruction(input_file, model_name: str):
    """
    Main pipeline: load/convert input -> run model -> reconstruct MusicXML.

    Accepts either a pre-processed score graph (.pt) or a MusicXML file
    (.mxl / .musicxml / .xml). MusicXML files are converted on-the-fly.
    A failure to convert or load the input is reported through
    ``accuracy_md`` rather than raised.

    Parameters
    ----------
    input_file
        Path string, or an object whose ``.name`` is the path. ``None``
        yields the empty result.
    model_name
        Display name of the model; must be a key of ``MODELS``.

    Returns
    -------
    (gt_first, recon_first, gt_orig_pages, gt_recon_pages, recon_pages,
     has_original_mxl, num_pages, mxl_path, accuracy_md, info_md)
    """
    empty = (None, None, [], [], [], False, 1, None, "", "")
    if input_file is None:
        return empty

    # Clean tmp directory
    # NOTE(review): TMP_DIR and the fixed output filenames below are shared by
    # every request, so concurrent users can delete or overwrite each other's
    # files.
    if TMP_DIR.exists():
        shutil.rmtree(TMP_DIR)
    TMP_DIR.mkdir(exist_ok=True)

    model_key = MODELS[model_name]
    model = get_model(model_key)

    # Load or convert input
    input_path = Path(input_file) if isinstance(input_file, str) else Path(input_file.name)
    is_mxl = input_path.suffix.lower() in MXL_EXTENSIONS
    original_mxl_path = str(input_path) if is_mxl else None

    # For .pt graphs, check for a paired original .mxl alongside
    if not is_mxl:
        paired_mxl = input_path.with_suffix(".mxl")
        if paired_mxl.exists():
            original_mxl_path = str(paired_mxl)

    try:
        if is_mxl:
            graph_data = mxl_to_graph_data(
                str(input_path), duration_vocab_fwd,
                grid_vocab_fwd, micro_vocab_fwd,
            )
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects, and
            # this path may be a user upload, so a crafted .pt file can
            # execute code. Restrict .pt input to the bundled examples or
            # switch to weights_only=True if the graph format permits.
            graph_data = torch.load(input_path, map_location="cpu", weights_only=False)
    except Exception as e:
        # Both input kinds report failure the same way instead of crashing
        # the UI callback.
        label = "Conversion error" if is_mxl else "Load error"
        return (*empty[:8], f"**{label}**: {e}", "")

    num_notes = graph_data["num_notes"]
    source = graph_data.get("source_file", "unknown")
    # Clean source path: keep only from /mxl/ onwards, prefix with pdmx
    if "/mxl/" in source:
        source = "pdmx" + source[source.index("/mxl/"):]

    feature_names = [f.name for f in model.config.input_features]

    # Extract ground truth features
    gt_features, entity_ids, gt_raw = extract_features_from_graph(
        graph_data, feature_names,
        identifier_pool_size=model.config.identifier_pool_size,
        id_assignment="sequential",
    )

    # Run model on a batch of one score, with every note unmasked
    features_batch = {k: v.unsqueeze(0) for k, v in gt_features.items()}
    entity_ids_batch = entity_ids.unsqueeze(0)
    mask = torch.ones(1, num_notes, dtype=torch.bool)

    kwargs = {}
    if getattr(model.config, "use_hgt", False):
        # Imported lazily: only HGT models need the graph edges.
        from src.model.note_hgt import NoteHGT
        edge_dict = NoteHGT.extract_edge_dict(graph_data["graph"])
        kwargs["edge_dicts"] = [edge_dict]

    with torch.no_grad():
        logits = model(features_batch, entity_ids_batch, mask=mask, **kwargs)

    predictions = {
        name: logits[name][0].argmax(dim=-1).cpu()
        for name in logits.keys()
    }

    # ---- Accuracy ----
    accs = compute_per_feature_accuracy(predictions, gt_features, feature_names)
    core_names = [
        "grid_position", "micro_offset", "measure_idx",
        "pitch_step", "pitch_alter", "pitch_octave",
        "duration", "staff", "voice",
    ]
    core_joint = _joint_accuracy(predictions, gt_features, core_names)
    all_joint = _joint_accuracy(predictions, gt_features, feature_names)

    # Horizontal accuracy table
    accuracy_md = _accuracy_table_md(accs, feature_names, core_joint, all_joint)

    # ---- Reconstruct MusicXML (predicted) ----
    raw_predictions = undo_feature_shifts(predictions)
    recon_features = build_recon_features(gt_raw, raw_predictions)

    mxl_path = None
    recon_pages: list[str] = []
    try:
        recon_score = reconstruct_score(
            recon_features, grid_vocab_inv, micro_vocab_inv,
            duration_vocab_inv, verbose=False,
        )
        mxl_path = score_to_mxl(recon_score, TMP_DIR / "reconstructed.musicxml")
        recon_pages = render_score_to_svg_pages(
            mxl_path, str(TMP_DIR / "reconstructed"),
        )
    except Exception as e:
        accuracy_md += f"\n\n**Reconstruction error**: {e}"

    # ---- Render ground truth: original MXL (if available) ----
    gt_orig_pages: list[str] = []
    if original_mxl_path is not None:
        try:
            gt_orig_pages = render_score_to_svg_pages(
                original_mxl_path, str(TMP_DIR / "gt_original"),
            )
        except Exception:
            pass  # best-effort: fall back to the feature-based rendering

    # ---- Render ground truth: reconstructed from features (always) ----
    gt_recon_pages: list[str] = []
    gt_raw_np = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in gt_raw.items()}
    try:
        gt_score = reconstruct_score(
            gt_raw_np, grid_vocab_inv, micro_vocab_inv,
            duration_vocab_inv, verbose=False,
        )
        gt_mxl_path = score_to_mxl(gt_score, TMP_DIR / "gt_reconstructed.musicxml")
        gt_recon_pages = render_score_to_svg_pages(
            gt_mxl_path, str(TMP_DIR / "gt_reconstructed"),
        )
    except Exception:
        pass  # ground-truth rendering is best-effort

    has_original_mxl = bool(gt_orig_pages)
    # Default: show original MXL if available, otherwise features-reconstructed
    gt_pages = gt_orig_pages if has_original_mxl else gt_recon_pages

    info_md = (
        f"**Source**: {source} | "
        f"**Notes**: {num_notes} | "
        f"**Model**: {model_name}"
    )
    if is_mxl and graph_data.get("truncated", False):
        info_md += (
            f" | **Truncated** to {num_notes} notes "
            f"({graph_data['total_bars']} bars) to fit model limit"
        )

    # Page count = max of the *active* GT variant and the reconstructed pages
    num_pages = max(len(gt_pages), len(recon_pages), 1)
    gt_first = gt_pages[0] if gt_pages else None
    recon_first = recon_pages[0] if recon_pages else None

    return (
        gt_first, recon_first,
        gt_orig_pages, gt_recon_pages, recon_pages,
        has_original_mxl, num_pages,
        mxl_path, accuracy_md, info_md,
    )
| # ============================================================================= | |
| # Gradio UI | |
| # ============================================================================= | |
def build_demo():
    """Assemble the Gradio Blocks UI and wire up its callbacks.

    Returns the ``gr.Blocks`` instance; the caller is responsible for
    launching it.
    """
    # NOTE(review): the component nesting below was reconstructed from the
    # section comments; confirm it against the deployed layout.
    with gr.Blocks(
        title="A Fixed-size Latent Space Autoencoder for Music Scores",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# A Fixed-size Latent Space Autoencoder for Music Scores\n"
            "Upload a MusicXML file (`.mxl` / `.musicxml`) or select an example from [PDMX](https://zenodo.org/records/15571083), "
            "choose a model, and reconstruct a MusicXML file. Scores are rendered "
            "with [Verovio](https://www.verovio.org).\n\n"
            "*Companion demo for: A Fixed-size Latent Space Autoencoder for Music Scores "
            "(Hendrik Roth, Emmanouil Karystinaios & Gerhard Widmer, JKU Linz) — "
            "[GitHub](https://github.com/hendrik-roth/score-ae)*"
        )
        # Hidden state for page lists
        gt_orig_pages_state = gr.State([])
        gt_recon_pages_state = gr.State([])
        recon_pages_state = gr.State([])
        has_orig_mxl_state = gr.State(False)
        with gr.Row(equal_height=False):
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=280):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Model",
                )
                # Pre-processed .pt graphs are accepted as well as MusicXML.
                graph_input = gr.File(
                    label="Upload MusicXML (.mxl / .musicxml / .xml)",
                    file_types=[".mxl", ".musicxml", ".xml", ".pt"],
                )
                example_dropdown = gr.Dropdown(
                    choices=["(none)"] + list(_EXAMPLE_DISPLAY.keys()),
                    value="(none)",
                    label="Or pick an example from our PDMX test set (deduplicated, no licence conflict subset)",
                )
                run_btn = gr.Button("Reconstruct", variant="primary", size="lg")
                info_output = gr.Markdown()
            # ---- Right column: download, scores, accuracy ----
            with gr.Column(scale=3):
                with gr.Row():
                    mxl_output = gr.File(label="Download reconstructed MusicXML", scale=1, min_width=200)
                    gr.Column(scale=2)  # spacer
                with gr.Row():
                    # Hidden until a result with more than one page exists.
                    page_slider = gr.Slider(
                        minimum=1, maximum=1, step=1, value=1,
                        label="Page", visible=False, scale=1,
                    )
                    gr.Column(scale=2)  # spacer
                gr.Markdown(
                    "*Note: The **Ground Truth** column shows the score rendered from "
                    "our discrete feature representation. When the original MusicXML is "
                    "available, tick \"Raw MXL engraving\" to see the original "
                    "publisher engraving instead.*",
                )
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.Markdown("#### Ground Truth")
                            # Hidden until a run produces original-MXL pages.
                            show_orig_mxl = gr.Checkbox(
                                label="Raw MXL engraving",
                                value=True,
                                visible=False,
                                scale=0,
                            )
                        gt_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("#### Reconstructed")
                        recon_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                gr.Markdown("#### Note-level accuracy")
                accuracy_output = gr.Markdown()

        # -- Callbacks --
        def on_run(graph_file, example_name, model_name):
            """Run the pipeline and map its results onto the UI outputs.

            An uploaded file takes precedence; the selected example is used
            only when nothing is uploaded.
            """
            if graph_file is None and example_name != "(none)":
                # Resolve display name back to filename
                filename = _EXAMPLE_DISPLAY.get(example_name, example_name)
                graph_file = str(EXAMPLES_DIR / filename)
            (
                gt_first, recon_first,
                gt_orig_pgs, gt_recon_pgs, recon_pgs,
                has_orig, num_pages,
                mxl, acc_md, info_md,
            ) = run_reconstruction(graph_file, model_name)
            # Reset the slider to page 1 and show it only for multi-page results.
            slider_update = gr.Slider(
                minimum=1, maximum=num_pages, step=1, value=1,
                label="Page", visible=(num_pages > 1),
            )
            checkbox_update = gr.Checkbox(
                label="Raw MXL engraving",
                value=True,
                visible=has_orig,
            )
            return (
                gt_first, recon_first,
                gt_orig_pgs, gt_recon_pgs, recon_pgs,
                has_orig,
                slider_update, checkbox_update,
                mxl, acc_md, info_md,
            )

        def on_page_change(page_num, show_orig, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            """Return the ground-truth and reconstructed images for the selected page.

            A side with fewer pages than ``page_num`` yields ``None``.
            """
            idx = int(page_num) - 1
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            return gt_img, recon_img

        def on_toggle_orig(show_orig, page_num, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            """Switch the ground-truth variant and re-clamp the page slider."""
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            # Recompute page count from both active GT and reconstructed
            num_pages = max(len(gt_pgs), len(recon_pgs), 1)
            clamped_page = min(int(page_num), num_pages)
            idx = clamped_page - 1
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            slider_update = gr.Slider(
                minimum=1, maximum=num_pages,
                step=1, value=clamped_page,
                label="Page", visible=(num_pages > 1),
            )
            return gt_img, recon_img, slider_update

        # The order of each outputs list must match the tuple returned by its callback.
        run_btn.click(
            fn=on_run,
            inputs=[graph_input, example_dropdown, model_dropdown],
            outputs=[
                gt_image, recon_image,
                gt_orig_pages_state, gt_recon_pages_state, recon_pages_state,
                has_orig_mxl_state,
                page_slider, show_orig_mxl,
                mxl_output, accuracy_output, info_output,
            ],
        )
        page_slider.change(
            fn=on_page_change,
            inputs=[page_slider, show_orig_mxl, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image],
        )
        show_orig_mxl.change(
            fn=on_toggle_orig,
            inputs=[show_orig_mxl, page_slider, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image, page_slider],
        )
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()