#!/usr/bin/env python3 """ Gradio demo for LaM-SLidE Music Score Autoencoder. Upload a MusicXML file (.mxl / .musicxml) or pick an example, choose a model, and get back the reconstructed MusicXML plus per-feature accuracy and a visual comparison of original vs. reconstructed scores. """ import shutil import sys import warnings from pathlib import Path # Add app root to path sys.path.insert(0, str(Path(__file__).resolve().parent)) warnings.filterwarnings("ignore") import os import gradio as gr import numpy as np import torch import verovio from omegaconf import OmegaConf VEROVIO_DATA_DIR = os.path.join(os.path.dirname(verovio.__file__), "data") from src.model.autoencoder import create_autoencoder_from_dict from inference import ( extract_features_from_graph, reconstruct_from_graph, undo_feature_shifts, ) from reconstruct_mxl import ( reconstruct_score, load_duration_vocabulary, load_position_vocabulary, ) from convert_mxl import mxl_to_graph_data, load_vocab_forward # ============================================================================= # Paths # ============================================================================= APP_DIR = Path(__file__).resolve().parent VOCAB_DIR = APP_DIR / "vocabs" MODELS_DIR = APP_DIR / "models" EXAMPLES_DIR = APP_DIR / "examples" TMP_DIR = APP_DIR / "tmp" TMP_DIR.mkdir(exist_ok=True) # Available models MODELS = { "Wide Large Factorized (best model)": "wide_large_factorized", "HGT Wide Factorized (graph-aware)": "hgt_wide_factorized", } # Example .pt graph files (each may have a paired .mxl for original rendering) # Sort by note count (extracted from filename pattern *_notes_*) so the # dropdown goes from smallest to largest score. 
def _note_count_key(p: Path) -> int:
    """Sort key for example files: the note count parsed from the filename.

    Filenames follow the pattern ``*_<N>notes_<hash>.pt``; files that do not
    match sort first (key 0).
    """
    import re
    m = re.search(r"_(\d+)notes", p.stem)
    return int(m.group(1)) if m else 0


def _example_display_name(p: Path) -> str:
    """Format: '64 notes — Schubert: An die Laute (QmNcR2oAkq)'"""
    import re
    stem = p.stem
    m = re.search(r"_(\d+)notes_(.+)$", stem)
    if not m:
        # Filename does not follow the expected pattern; show it verbatim.
        return stem
    notes = m.group(1)
    qm_hash = m.group(2)
    # Everything before _notes is the composer_title part
    prefix = stem[:m.start()]
    # Convert underscores to spaces and title-case
    title = prefix.replace("_", " ").title()
    return f"{notes} notes \u2014 {title} ({qm_hash})"


# Pre-processed example graphs, smallest score first (import-time scan).
EXAMPLES = sorted(
    (p for p in EXAMPLES_DIR.iterdir() if p.suffix == ".pt"),
    key=_note_count_key,
) if EXAMPLES_DIR.exists() else []

# Map display name -> filename for dropdown
_EXAMPLE_DISPLAY = {_example_display_name(p): p.name for p in EXAMPLES}
_EXAMPLE_LOOKUP = {v: k for k, v in _EXAMPLE_DISPLAY.items()}  # filename -> display

# File extensions recognised as MusicXML
MXL_EXTENSIONS = {".mxl", ".musicxml", ".xml"}

# Short display names for the accuracy table
FEATURE_SHORT_NAMES = {
    "grid_position": "grid_pos",
    "micro_offset": "micro",
    "measure_idx": "bar",
    "voice": "voice",
    "pitch_step": "step",
    "pitch_alter": "alter",
    "pitch_octave": "oct",
    "duration": "dur",
    "clef": "clef",
    "ts_beats": "ts_b",
    "ts_beat_type": "ts_bt",
    "key_fifths": "key",
    "staff": "staff",
}

# =============================================================================
# Load vocabs once
# =============================================================================
# Inverse vocabularies (token id -> value) for decoding model output back to
# score attributes. Loaded once at import time.
duration_vocab_inv = load_duration_vocabulary(VOCAB_DIR / "duration_vocab.json")
grid_vocab_inv = load_position_vocabulary(VOCAB_DIR / "grid_vocab.json")
micro_vocab_inv = load_position_vocabulary(VOCAB_DIR / "micro_vocab.json")

# Forward vocabs for MXL → graph conversion (loaded from training data)
duration_vocab_fwd = load_vocab_forward(VOCAB_DIR / "duration_vocab.json")
grid_vocab_fwd = load_vocab_forward(VOCAB_DIR / "grid_vocab.json")
micro_vocab_fwd = load_vocab_forward(VOCAB_DIR / "micro_vocab.json")

# =============================================================================
# Model cache
# =============================================================================
# model_key -> loaded autoencoder; models are loaded lazily on first request
# and kept for the lifetime of the process.
_model_cache = {}


def get_model(model_key: str):
    """Load model (cached).

    Reads the first ``*.yaml`` config found in ``MODELS_DIR/<model_key>``
    plus the ``best_model.pt`` checkpoint, builds the autoencoder, and puts
    it in eval mode on CPU. Raises IndexError if no YAML config is present.
    """
    if model_key in _model_cache:
        return _model_cache[model_key]
    model_dir = MODELS_DIR / model_key
    config_path = list(model_dir.glob("*.yaml"))[0]
    checkpoint_path = model_dir / "best_model.pt"
    cfg = OmegaConf.load(config_path)
    config_dict = OmegaConf.to_container(cfg.model, resolve=True)
    # weights_only=False: checkpoint is a trusted local file containing more
    # than bare tensors (a dict with "model_state_dict").
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model = create_autoencoder_from_dict(config_dict)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    model.cpu()
    _model_cache[model_key] = model
    return model


# =============================================================================
# Score rendering with verovio
# =============================================================================
def render_score_to_svg_pages(
    mxl_path: str,
    prefix: str,
    page_width: int = 2100,
    scale: int = 35,
) -> list[str]:
    """Render a MusicXML file to one SVG file per page.

    Returns list of paths."""
    tk = verovio.toolkit(False)
    tk.setResourcePath(VEROVIO_DATA_DIR)
    tk.loadFile(mxl_path)
    # Options are applied after loadFile; redoLayout() below re-paginates
    # with these settings.
    tk.setOptions({
        "pageWidth": page_width,
        "scale": scale,
        "adjustPageHeight": True,
        "footer": "none",
        "header": "none",
    })
    tk.redoLayout()
    paths = []
    # Verovio pages are 1-indexed.
    for page in range(1, tk.getPageCount() + 1):
        svg_path = f"{prefix}_p{page}.svg"
        Path(svg_path).write_text(tk.renderToSVG(page), encoding="utf-8")
        paths.append(svg_path)
    return paths


# =============================================================================
# Core logic
# =============================================================================
# Maps model feature names -> raw feature keys expected by reconstruct_score.
# Position/duration features are renamed to their *_token form; the rest map
# to themselves.
FEATURE_KEY_MAP = {
    "grid_position": "position_grid_token",
    "micro_offset": "position_micro_token",
    "duration": "duration_token",
    "pitch_step": "pitch_step",
    "pitch_alter": "pitch_alter",
    "pitch_octave": "pitch_octave",
    "measure_idx": "measure_idx",
    "voice": "voice",
    "staff": "staff",
    "clef": "clef",
    "ts_beats": "ts_beats",
    "ts_beat_type": "ts_beat_type",
    "key_fifths": "key_fifths",
}


def build_recon_features(gt_raw: dict, raw_predictions: dict) -> dict:
    """Build feature dict for reconstruct_score from raw ground-truth and predictions.

    Starts from all ground-truth arrays, then overwrites the keys the model
    predicted (renamed via FEATURE_KEY_MAP), so unpredicted features fall
    back to ground truth. Values are converted to numpy arrays.
    """
    recon_features = {}
    # Ground-truth values are torch tensors here — converted to numpy.
    for raw_key, tensor in gt_raw.items():
        recon_features[raw_key] = tensor.numpy()
    for model_key, output_key in FEATURE_KEY_MAP.items():
        if model_key in raw_predictions:
            val = raw_predictions[model_key]
            recon_features[output_key] = val.numpy() if isinstance(val, torch.Tensor) else val
    return recon_features


def score_to_mxl(score, path: Path) -> str:
    """Write a music21 score to MusicXML, return the path string."""
    # music21's MusicXML writer can be noisy; silence warnings locally.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        score.write("musicxml", fp=str(path))
    return str(path)


def compute_per_feature_accuracy(
    predictions: dict, ground_truth: dict, feature_names: list
) -> dict:
    """Compute per-feature token accuracy.

    Only features present in BOTH dicts are scored; values are coerced to
    tensors so plain arrays/lists also work. Returns name -> accuracy in
    [0, 1].
    """
    accs = {}
    for name in feature_names:
        if name in predictions and name in ground_truth:
            pred = predictions[name] if isinstance(predictions[name], torch.Tensor) else torch.tensor(predictions[name])
            gt = ground_truth[name] if isinstance(ground_truth[name], torch.Tensor) else torch.tensor(ground_truth[name])
            accs[name] = (pred == gt).float().mean().item()
    return accs


def run_reconstruction(input_file, model_name: str):
    """
    Main pipeline: load/convert input -> run model -> reconstruct MusicXML.

    Accepts either a pre-processed score graph (.pt) or a MusicXML file
    (.mxl / .musicxml / .xml). MusicXML files are converted on-the-fly.

    Returns
    -------
    (gt_first, recon_first, gt_orig_pages, gt_recon_pages, recon_pages,
     has_original_mxl, num_pages, mxl_path, accuracy_md, info_md)
    """
    # Sentinel return for "nothing to show" and for early conversion errors.
    empty = (None, None, [], [], [], False, 1, None, "", "")
    if input_file is None:
        return empty

    # Clean tmp directory — every run starts from a fresh output dir, so
    # stale SVG pages from a previous run can't leak into this one.
    if TMP_DIR.exists():
        shutil.rmtree(TMP_DIR)
    TMP_DIR.mkdir(exist_ok=True)

    model_key = MODELS[model_name]
    model = get_model(model_key)

    # Load or convert input. Gradio may hand us a plain path string or a
    # file-like object with a .name attribute.
    input_path = Path(input_file) if isinstance(input_file, str) else Path(input_file.name)
    is_mxl = input_path.suffix.lower() in MXL_EXTENSIONS
    original_mxl_path = str(input_path) if is_mxl else None

    # For .pt graphs, check for a paired original .mxl alongside
    if not is_mxl:
        paired_mxl = input_path.with_suffix(".mxl")
        if paired_mxl.exists():
            original_mxl_path = str(paired_mxl)

    if is_mxl:
        try:
            graph_data = mxl_to_graph_data(
                str(input_path),
                duration_vocab_fwd,
                grid_vocab_fwd,
                micro_vocab_fwd,
            )
        except Exception as e:
            # Surface the conversion failure in the accuracy panel slot.
            return (*empty[:8], f"**Conversion error**: {e}", "")
    else:
        # weights_only=False: .pt examples store a full dict, not bare tensors.
        graph_data = torch.load(input_path, map_location="cpu", weights_only=False)

    # graph_data is a dict; assumes at least "num_notes" is present —
    # "source_file" is optional. TODO(review): confirm schema in convert_mxl.
    num_notes = graph_data["num_notes"]
    source = graph_data.get("source_file", "unknown")
    # Clean source path: keep only from /mxl/ onwards, prefix with pdmx
    if "/mxl/" in source:
        source = "pdmx" + source[source.index("/mxl/"):]

    feature_names = [f.name for f in model.config.input_features]

    # Extract ground truth features
    gt_features, entity_ids, gt_raw = extract_features_from_graph(
        graph_data,
        feature_names,
        identifier_pool_size=model.config.identifier_pool_size,
        id_assignment="sequential",
    )

    # Run model — batch of one score, all notes valid in the mask.
    features_batch = {k: v.unsqueeze(0) for k, v in gt_features.items()}
    entity_ids_batch = entity_ids.unsqueeze(0)
    mask = torch.ones(1, num_notes, dtype=torch.bool)
    kwargs = {}
    if getattr(model.config, "use_hgt", False):
        # Graph-aware variant needs per-sample edge dicts from the score graph.
        from src.model.note_hgt import NoteHGT
        edge_dict = NoteHGT.extract_edge_dict(graph_data["graph"])
        kwargs["edge_dicts"] = [edge_dict]
    with torch.no_grad():
        logits = model(features_batch, entity_ids_batch, mask=mask, **kwargs)
    # Greedy decoding: argmax over the vocabulary dim, batch index 0.
    predictions = {
        name: logits[name][0].argmax(dim=-1).cpu()
        for name in logits.keys()
    }

    # ---- Accuracy ----
    accs = compute_per_feature_accuracy(predictions, gt_features, feature_names)
    # "core" joint accuracy: all musically essential features correct at once.
    core_names = [
        "grid_position", "micro_offset", "measure_idx", "pitch_step",
        "pitch_alter", "pitch_octave", "duration", "staff", "voice",
    ]
    core_correct = None
    for name in core_names:
        if name in predictions and name in gt_features:
            match = (predictions[name] == gt_features[name])
            core_correct = match if core_correct is None else (core_correct & match)
    core_joint = core_correct.float().mean().item() if core_correct is not None else 0.0
    # "all" joint accuracy: every modeled feature correct simultaneously.
    all_correct = None
    for name in feature_names:
        if name in predictions and name in gt_features:
            match = (predictions[name] == gt_features[name])
            all_correct = match if all_correct is None else (all_correct & match)
    all_joint = all_correct.float().mean().item() if all_correct is not None else 0.0

    # Horizontal accuracy table (one markdown row of per-feature values,
    # plus the two joint columns).
    header_cells = " | ".join(
        FEATURE_SHORT_NAMES.get(n, n) for n in feature_names if n in accs
    )
    value_cells = " | ".join(
        f"{accs[n]*100:.1f}%" for n in feature_names if n in accs
    )
    accuracy_md = (
        f"| {header_cells} | **core** | **all** |\n"
        f"|{'---|' * (sum(1 for n in feature_names if n in accs) + 2)}\n"
        f"| {value_cells} | **{core_joint*100:.1f}%** | **{all_joint*100:.1f}%** |"
    )

    # ---- Reconstruct MusicXML (predicted) ----
    # Undo the token shifts applied during feature extraction before decoding.
    raw_predictions = undo_feature_shifts(predictions)
    recon_features = build_recon_features(gt_raw, raw_predictions)
    mxl_path = None
    recon_pages: list[str] = []
    try:
        recon_score = reconstruct_score(
            recon_features,
            grid_vocab_inv,
            micro_vocab_inv,
            duration_vocab_inv,
            verbose=False,
        )
        mxl_path = score_to_mxl(recon_score, TMP_DIR / "reconstructed.musicxml")
        recon_pages = render_score_to_svg_pages(
            mxl_path,
            str(TMP_DIR / "reconstructed"),
        )
    except Exception as e:
        # Keep the accuracy table; append the failure so the user sees why
        # no reconstructed score is shown.
        accuracy_md += f"\n\n**Reconstruction error**: {e}"

    # ---- Render ground truth: original MXL (if available) ----
    gt_orig_pages: list[str] = []
    if original_mxl_path is not None:
        try:
            gt_orig_pages = render_score_to_svg_pages(
                original_mxl_path,
                str(TMP_DIR / "gt_original"),
            )
        except Exception:
            # Best-effort: fall back to the features-reconstructed GT below.
            pass

    # ---- Render ground truth: reconstructed from features (always) ----
    gt_recon_pages: list[str] = []
    gt_raw_np = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in gt_raw.items()}
    try:
        gt_score = reconstruct_score(
            gt_raw_np,
            grid_vocab_inv,
            micro_vocab_inv,
            duration_vocab_inv,
            verbose=False,
        )
        gt_mxl_path = score_to_mxl(gt_score, TMP_DIR / "gt_reconstructed.musicxml")
        gt_recon_pages = render_score_to_svg_pages(
            gt_mxl_path,
            str(TMP_DIR / "gt_reconstructed"),
        )
    except Exception:
        pass  # ground-truth rendering is best-effort

    has_original_mxl = bool(gt_orig_pages)
    # Default: show original MXL if available, otherwise features-reconstructed
    gt_pages = gt_orig_pages if has_original_mxl else gt_recon_pages

    info_md = (
        f"**Source**: {source}  |  "
        f"**Notes**: {num_notes}  |  "
        f"**Model**: {model_name}"
    )
    if is_mxl and graph_data.get("truncated", False):
        info_md += (
            f"  |  **Truncated** to {num_notes} notes "
            f"({graph_data['total_bars']} bars) to fit model limit"
        )

    # Page count = max of the *active* GT variant and the reconstructed pages
    num_pages = max(len(gt_pages), len(recon_pages), 1)

    gt_first = gt_pages[0] if gt_pages else None
    recon_first = recon_pages[0] if recon_pages else None

    return (
        gt_first,
        recon_first,
        gt_orig_pages,
        gt_recon_pages,
        recon_pages,
        has_original_mxl,
        num_pages,
        mxl_path,
        accuracy_md,
        info_md,
    )


# =============================================================================
# Gradio UI
# =============================================================================
def build_demo():
    """Build and return the Gradio Blocks app (does not launch it)."""
    with gr.Blocks(
        title="A Fixed-size Latent Space Autoencoder for Music Scores",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# A Fixed-size Latent Space Autoencoder for Music Scores\n"
            "Upload a MusicXML file (`.mxl` / `.musicxml`) or select an example from [PDMX](https://zenodo.org/records/15571083), "
            "choose a model, and reconstruct a MusicXML file. Scores are rendered "
            "with [Verovio](https://www.verovio.org).\n\n"
            "*Companion demo for: A Fixed-size Latent Space Autoencoder for Music Scores "
            "(Hendrik Roth, Emmanouil Karystinaios & Gerhard Widmer, JKU Linz) — "
            "[GitHub](https://github.com/hendrik-roth/score-ae)*"
        )

        # Hidden state for page lists — populated by on_run, read by the
        # page slider and the GT-variant toggle callbacks.
        gt_orig_pages_state = gr.State([])
        gt_recon_pages_state = gr.State([])
        recon_pages_state = gr.State([])
        has_orig_mxl_state = gr.State(False)

        with gr.Row(equal_height=False):
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=280):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Model",
                )
                graph_input = gr.File(
                    label="Upload MusicXML (.mxl / .musicxml / .xml)",
                    file_types=[".mxl", ".musicxml", ".xml", ".pt"],
                )
                example_dropdown = gr.Dropdown(
                    choices=["(none)"] + list(_EXAMPLE_DISPLAY.keys()),
                    value="(none)",
                    label="Or pick an example from our PDMX test set (deduplicated, no licence conflict subset)",
                )
                run_btn = gr.Button("Reconstruct", variant="primary", size="lg")
                info_output = gr.Markdown()

            # ---- Right column: download, scores, accuracy ----
            with gr.Column(scale=3):
                with gr.Row():
                    mxl_output = gr.File(label="Download reconstructed MusicXML", scale=1, min_width=200)
                    gr.Column(scale=2)  # spacer
                with gr.Row():
                    page_slider = gr.Slider(
                        minimum=1,
                        maximum=1,
                        step=1,
                        value=1,
                        label="Page",
                        visible=False,
                        scale=1,
                    )
                    gr.Column(scale=2)  # spacer
                gr.Markdown(
                    "*Note: The **Ground Truth** column shows the score rendered from "
                    "our discrete feature representation. When the original MusicXML is "
                    "available, tick \"Raw MXL engraving\" to see the original "
                    "publisher engraving instead.*",
                )
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.Markdown("#### Ground Truth")
                            show_orig_mxl = gr.Checkbox(
                                label="Raw MXL engraving",
                                value=True,
                                visible=False,
                                scale=0,
                            )
                        gt_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("#### Reconstructed")
                        recon_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                gr.Markdown("#### Note-level accuracy")
                accuracy_output = gr.Markdown()

        # -- Callbacks --
        def on_run(graph_file, example_name, model_name):
            """Run the pipeline; an uploaded file takes precedence over the example dropdown."""
            if graph_file is None and example_name != "(none)":
                # Resolve display name back to filename
                filename = _EXAMPLE_DISPLAY.get(example_name, example_name)
                graph_file = str(EXAMPLES_DIR / filename)
            (
                gt_first,
                recon_first,
                gt_orig_pgs,
                gt_recon_pgs,
                recon_pgs,
                has_orig,
                num_pages,
                mxl,
                acc_md,
                info_md,
            ) = run_reconstruction(graph_file, model_name)
            # Slider only shown for multi-page scores; reset to page 1.
            slider_update = gr.Slider(
                minimum=1,
                maximum=num_pages,
                step=1,
                value=1,
                label="Page",
                visible=(num_pages > 1),
            )
            # Toggle only shown when an original MXL engraving exists.
            checkbox_update = gr.Checkbox(
                label="Raw MXL engraving",
                value=True,
                visible=has_orig,
            )
            # Order must match the outputs list of run_btn.click below.
            return (
                gt_first,
                recon_first,
                gt_orig_pgs,
                gt_recon_pgs,
                recon_pgs,
                has_orig,
                slider_update,
                checkbox_update,
                mxl,
                acc_md,
                info_md,
            )

        def on_page_change(page_num, show_orig, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            """Show the selected page in both columns (None if a side has fewer pages)."""
            idx = int(page_num) - 1
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            return gt_img, recon_img

        def on_toggle_orig(show_orig, page_num, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            """Switch GT variant; re-clamp the page slider since page counts may differ."""
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            # Recompute page count from both active GT and reconstructed
            num_pages = max(len(gt_pgs), len(recon_pgs), 1)
            clamped_page = min(int(page_num), num_pages)
            idx = clamped_page - 1
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            slider_update = gr.Slider(
                minimum=1,
                maximum=num_pages,
                step=1,
                value=clamped_page,
                label="Page",
                visible=(num_pages > 1),
            )
            return gt_img, recon_img, slider_update

        run_btn.click(
            fn=on_run,
            inputs=[graph_input, example_dropdown, model_dropdown],
            outputs=[
                gt_image,
                recon_image,
                gt_orig_pages_state,
                gt_recon_pages_state,
                recon_pages_state,
                has_orig_mxl_state,
                page_slider,
                show_orig_mxl,
                mxl_output,
                accuracy_output,
                info_output,
            ],
        )
        page_slider.change(
            fn=on_page_change,
            inputs=[page_slider, show_orig_mxl, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image],
        )
        show_orig_mxl.change(
            fn=on_toggle_orig,
            inputs=[show_orig_mxl, page_slider, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image, page_slider],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()