# score-ae / app.py — Hugging Face Space by hroth (commit 6e51e38, 19.1 kB)
#!/usr/bin/env python3
"""
Gradio demo for LaM-SLidE Music Score Autoencoder.
Upload a MusicXML file (.mxl / .musicxml) or pick an example, choose a
model, and get back the reconstructed MusicXML plus per-feature accuracy
and a visual comparison of original vs. reconstructed scores.
"""
import shutil
import sys
import warnings
from pathlib import Path
# Add app root to path
sys.path.insert(0, str(Path(__file__).resolve().parent))
warnings.filterwarnings("ignore")
import os
import gradio as gr
import numpy as np
import torch
import verovio
from omegaconf import OmegaConf
VEROVIO_DATA_DIR = os.path.join(os.path.dirname(verovio.__file__), "data")
from src.model.autoencoder import create_autoencoder_from_dict
from inference import (
extract_features_from_graph,
reconstruct_from_graph,
undo_feature_shifts,
)
from reconstruct_mxl import (
reconstruct_score,
load_duration_vocabulary,
load_position_vocabulary,
)
from convert_mxl import mxl_to_graph_data, load_vocab_forward
# =============================================================================
# Paths
# =============================================================================
APP_DIR = Path(__file__).resolve().parent
VOCAB_DIR = APP_DIR / "vocabs"        # token vocabularies (JSON)
MODELS_DIR = APP_DIR / "models"       # one sub-directory per model (config + checkpoint)
EXAMPLES_DIR = APP_DIR / "examples"   # bundled example graphs (.pt) and paired .mxl files
TMP_DIR = APP_DIR / "tmp"             # scratch dir for rendered SVG / MusicXML output
TMP_DIR.mkdir(exist_ok=True)
# Available models: dropdown display label -> directory name under MODELS_DIR
MODELS = {
    "Wide Large Factorized (best model)": "wide_large_factorized",
    "HGT Wide Factorized (graph-aware)": "hgt_wide_factorized",
}
# Example .pt graph files (each may have a paired .mxl for original rendering)
# Sort by note count (extracted from filename pattern *_<N>notes_*) so the
# dropdown goes from smallest to largest score.
def _note_count_key(p: Path) -> int:
import re
m = re.search(r"_(\d+)notes", p.stem)
return int(m.group(1)) if m else 0
def _example_display_name(p: Path) -> str:
"""Format: '64 notes — Schubert: An die Laute (QmNcR2oAkq)'"""
import re
stem = p.stem
m = re.search(r"_(\d+)notes_(.+)$", stem)
if not m:
return stem
notes = m.group(1)
qm_hash = m.group(2)
# Everything before _<N>notes is the composer_title part
prefix = stem[:m.start()]
# Convert underscores to spaces and title-case
title = prefix.replace("_", " ").title()
return f"{notes} notes \u2014 {title} ({qm_hash})"
# Bundled example graphs, smallest score first (sorted by embedded note count).
EXAMPLES = sorted(
    (p for p in EXAMPLES_DIR.iterdir() if p.suffix == ".pt"),
    key=_note_count_key,
) if EXAMPLES_DIR.exists() else []
# Map display name -> filename for dropdown
_EXAMPLE_DISPLAY = {_example_display_name(p): p.name for p in EXAMPLES}
_EXAMPLE_LOOKUP = {v: k for k, v in _EXAMPLE_DISPLAY.items()}  # filename -> display
# File extensions recognised as MusicXML
MXL_EXTENSIONS = {".mxl", ".musicxml", ".xml"}
# Short display names for the accuracy table (kept terse so the one-row
# markdown table stays readable)
FEATURE_SHORT_NAMES = {
    "grid_position": "grid_pos",
    "micro_offset": "micro",
    "measure_idx": "bar",
    "voice": "voice",
    "pitch_step": "step",
    "pitch_alter": "alter",
    "pitch_octave": "oct",
    "duration": "dur",
    "clef": "clef",
    "ts_beats": "ts_b",
    "ts_beat_type": "ts_bt",
    "key_fifths": "key",
    "staff": "staff",
}
# =============================================================================
# Load vocabs once
# =============================================================================
# Inverse vocabularies (token id -> symbol), used when decoding model output
# back into a score.
duration_vocab_inv = load_duration_vocabulary(VOCAB_DIR / "duration_vocab.json")
grid_vocab_inv = load_position_vocabulary(VOCAB_DIR / "grid_vocab.json")
micro_vocab_inv = load_position_vocabulary(VOCAB_DIR / "micro_vocab.json")
# Forward vocabs for MXL → graph conversion (loaded from training data)
duration_vocab_fwd = load_vocab_forward(VOCAB_DIR / "duration_vocab.json")
grid_vocab_fwd = load_vocab_forward(VOCAB_DIR / "grid_vocab.json")
micro_vocab_fwd = load_vocab_forward(VOCAB_DIR / "micro_vocab.json")
# =============================================================================
# Model cache
# =============================================================================
# Loaded models keyed by model directory name; populated lazily by get_model().
_model_cache = {}
def get_model(model_key: str):
    """Load an autoencoder checkpoint by directory key, caching the result.

    Parameters
    ----------
    model_key : str
        Sub-directory of ``MODELS_DIR`` holding a ``*.yaml`` config and a
        ``best_model.pt`` checkpoint.

    Returns
    -------
    The model in eval mode on CPU. Repeated calls return the cached instance.
    """
    if model_key in _model_cache:
        return _model_cache[model_key]
    model_dir = MODELS_DIR / model_key
    # sorted() makes the choice deterministic when several configs exist
    # (Path.glob order is filesystem-dependent).
    config_path = sorted(model_dir.glob("*.yaml"))[0]
    checkpoint_path = model_dir / "best_model.pt"
    cfg = OmegaConf.load(config_path)
    config_dict = OmegaConf.to_container(cfg.model, resolve=True)
    # NOTE: weights_only=False unpickles arbitrary objects — acceptable only
    # because checkpoints ship with the app and are trusted.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model = create_autoencoder_from_dict(config_dict)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    model.cpu()
    _model_cache[model_key] = model
    return model
# =============================================================================
# Score rendering with verovio
# =============================================================================
def render_score_to_svg_pages(
    mxl_path: str, prefix: str,
    page_width: int = 2100, scale: int = 35,
) -> list[str]:
    """Engrave a MusicXML file with Verovio, writing one SVG file per page.

    SVGs are written to ``{prefix}_p<page>.svg``; the list of paths is
    returned in page order.
    """
    toolkit = verovio.toolkit(False)
    toolkit.setResourcePath(VEROVIO_DATA_DIR)
    toolkit.loadFile(mxl_path)
    toolkit.setOptions({
        "pageWidth": page_width,
        "scale": scale,
        "adjustPageHeight": True,
        "footer": "none",
        "header": "none",
    })
    # Re-layout so the options above take effect on the already-loaded file.
    toolkit.redoLayout()
    svg_files: list[str] = []
    page_total = toolkit.getPageCount()
    for page_num in range(1, page_total + 1):
        out_path = f"{prefix}_p{page_num}.svg"
        Path(out_path).write_text(toolkit.renderToSVG(page_num), encoding="utf-8")
        svg_files.append(out_path)
    return svg_files
# =============================================================================
# Core logic
# =============================================================================
# Maps model feature-head names (left) to the raw feature keys expected by
# reconstruct_score (right); names that already match map to themselves.
FEATURE_KEY_MAP = {
    "grid_position": "position_grid_token",
    "micro_offset": "position_micro_token",
    "duration": "duration_token",
    "pitch_step": "pitch_step",
    "pitch_alter": "pitch_alter",
    "pitch_octave": "pitch_octave",
    "measure_idx": "measure_idx",
    "voice": "voice",
    "staff": "staff",
    "clef": "clef",
    "ts_beats": "ts_beats",
    "ts_beat_type": "ts_beat_type",
    "key_fifths": "key_fifths",
}
def build_recon_features(gt_raw: dict, raw_predictions: dict) -> dict:
    """Merge ground-truth arrays with model predictions for reconstruct_score.

    Ground-truth entries are copied first (converted to numpy); any predicted
    feature then overwrites its mapped output key, so un-predicted features
    fall back to the ground truth.
    """
    merged = {raw_key: tensor.numpy() for raw_key, tensor in gt_raw.items()}
    for model_key, output_key in FEATURE_KEY_MAP.items():
        if model_key not in raw_predictions:
            continue
        value = raw_predictions[model_key]
        if isinstance(value, torch.Tensor):
            value = value.numpy()
        merged[output_key] = value
    return merged
def score_to_mxl(score, path: Path) -> str:
    """Serialize a music21 score to MusicXML at *path*; return the path as str."""
    target = str(path)
    # music21's MusicXML writer is noisy; silence its warnings locally.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        score.write("musicxml", fp=target)
    return target
def compute_per_feature_accuracy(
    predictions: dict, ground_truth: dict, feature_names: list
) -> dict:
    """Token-level accuracy per feature.

    Only features present in BOTH dicts get an entry; values may be tensors
    or anything torch can convert (lists, numpy arrays).
    """
    accuracies = {}
    for feature in feature_names:
        if feature not in predictions or feature not in ground_truth:
            continue
        pred = torch.as_tensor(predictions[feature])
        gt = torch.as_tensor(ground_truth[feature])
        accuracies[feature] = (pred == gt).float().mean().item()
    return accuracies
def run_reconstruction(input_file, model_name: str):
    """
    Main pipeline: load/convert input -> run model -> reconstruct MusicXML.
    Accepts either a pre-processed score graph (.pt) or a MusicXML file
    (.mxl / .musicxml / .xml). MusicXML files are converted on-the-fly.
    Returns
    -------
    (gt_first, recon_first, gt_orig_pages, gt_recon_pages, recon_pages,
    has_original_mxl, num_pages, mxl_path, accuracy_md, info_md)
    """
    # Placeholder result used when there is no input or conversion fails.
    empty = (None, None, [], [], [], False, 1, None, "", "")
    if input_file is None:
        return empty
    # Clean tmp directory so stale SVG/MusicXML from a previous run never
    # leak into this one.
    if TMP_DIR.exists():
        shutil.rmtree(TMP_DIR)
    TMP_DIR.mkdir(exist_ok=True)
    model_key = MODELS[model_name]
    model = get_model(model_key)
    # Load or convert input (gradio may hand us a str path or a file object).
    input_path = Path(input_file) if isinstance(input_file, str) else Path(input_file.name)
    is_mxl = input_path.suffix.lower() in MXL_EXTENSIONS
    original_mxl_path = str(input_path) if is_mxl else None
    # For .pt graphs, check for a paired original .mxl alongside
    if not is_mxl:
        paired_mxl = input_path.with_suffix(".mxl")
        if paired_mxl.exists():
            original_mxl_path = str(paired_mxl)
    if is_mxl:
        try:
            graph_data = mxl_to_graph_data(
                str(input_path), duration_vocab_fwd,
                grid_vocab_fwd, micro_vocab_fwd,
            )
        except Exception as e:
            # Keep the first 8 slots of `empty`; report the failure in the
            # accuracy-markdown slot so the UI shows it.
            return (*empty[:8], f"**Conversion error**: {e}", "")
    else:
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # assumed safe for the bundled example graphs; confirm for uploads.
        graph_data = torch.load(input_path, map_location="cpu", weights_only=False)
    num_notes = graph_data["num_notes"]
    source = graph_data.get("source_file", "unknown")
    # Clean source path: keep only from /mxl/ onwards, prefix with pdmx
    if "/mxl/" in source:
        source = "pdmx" + source[source.index("/mxl/"):]
    feature_names = [f.name for f in model.config.input_features]
    # Extract ground truth features
    gt_features, entity_ids, gt_raw = extract_features_from_graph(
        graph_data, feature_names,
        identifier_pool_size=model.config.identifier_pool_size,
        id_assignment="sequential",
    )
    # Run model (batch of one; the mask marks every note as valid)
    features_batch = {k: v.unsqueeze(0) for k, v in gt_features.items()}
    entity_ids_batch = entity_ids.unsqueeze(0)
    mask = torch.ones(1, num_notes, dtype=torch.bool)
    kwargs = {}
    if getattr(model.config, "use_hgt", False):
        # Graph-aware variant additionally needs the heterogeneous edge dict.
        from src.model.note_hgt import NoteHGT
        edge_dict = NoteHGT.extract_edge_dict(graph_data["graph"])
        kwargs["edge_dicts"] = [edge_dict]
    with torch.no_grad():
        logits = model(features_batch, entity_ids_batch, mask=mask, **kwargs)
    # Greedy decode: argmax over each feature head's vocabulary.
    predictions = {
        name: logits[name][0].argmax(dim=-1).cpu()
        for name in logits.keys()
    }
    # ---- Accuracy ----
    accs = compute_per_feature_accuracy(predictions, gt_features, feature_names)
    core_names = [
        "grid_position", "micro_offset", "measure_idx",
        "pitch_step", "pitch_alter", "pitch_octave",
        "duration", "staff", "voice",
    ]
    # Joint accuracy over the core features: a note counts as correct only if
    # every core feature is predicted correctly.
    core_correct = None
    for name in core_names:
        if name in predictions and name in gt_features:
            match = (predictions[name] == gt_features[name])
            core_correct = match if core_correct is None else (core_correct & match)
    core_joint = core_correct.float().mean().item() if core_correct is not None else 0.0
    # Same, but jointly over every input feature.
    all_correct = None
    for name in feature_names:
        if name in predictions and name in gt_features:
            match = (predictions[name] == gt_features[name])
            all_correct = match if all_correct is None else (all_correct & match)
    all_joint = all_correct.float().mean().item() if all_correct is not None else 0.0
    # Horizontal accuracy table
    header_cells = " | ".join(
        FEATURE_SHORT_NAMES.get(n, n) for n in feature_names if n in accs
    )
    value_cells = " | ".join(
        f"{accs[n]*100:.1f}%" for n in feature_names if n in accs
    )
    accuracy_md = (
        f"| {header_cells} | **core** | **all** |\n"
        f"|{'---|' * (sum(1 for n in feature_names if n in accs) + 2)}\n"
        f"| {value_cells} | **{core_joint*100:.1f}%** | **{all_joint*100:.1f}%** |"
    )
    # ---- Reconstruct MusicXML (predicted) ----
    # Shift predicted token ids back to raw feature values before rebuilding.
    raw_predictions = undo_feature_shifts(predictions)
    recon_features = build_recon_features(gt_raw, raw_predictions)
    mxl_path = None
    recon_pages: list[str] = []
    try:
        recon_score = reconstruct_score(
            recon_features, grid_vocab_inv, micro_vocab_inv,
            duration_vocab_inv, verbose=False,
        )
        mxl_path = score_to_mxl(recon_score, TMP_DIR / "reconstructed.musicxml")
        recon_pages = render_score_to_svg_pages(
            mxl_path, str(TMP_DIR / "reconstructed"),
        )
    except Exception as e:
        # Surface the failure in the accuracy panel rather than crashing the UI.
        accuracy_md += f"\n\n**Reconstruction error**: {e}"
    # ---- Render ground truth: original MXL (if available) ----
    gt_orig_pages: list[str] = []
    if original_mxl_path is not None:
        try:
            gt_orig_pages = render_score_to_svg_pages(
                original_mxl_path, str(TMP_DIR / "gt_original"),
            )
        except Exception:
            pass  # best-effort: fall back to the features-reconstructed GT
    # ---- Render ground truth: reconstructed from features (always) ----
    gt_recon_pages: list[str] = []
    gt_raw_np = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in gt_raw.items()}
    try:
        gt_score = reconstruct_score(
            gt_raw_np, grid_vocab_inv, micro_vocab_inv,
            duration_vocab_inv, verbose=False,
        )
        gt_mxl_path = score_to_mxl(gt_score, TMP_DIR / "gt_reconstructed.musicxml")
        gt_recon_pages = render_score_to_svg_pages(
            gt_mxl_path, str(TMP_DIR / "gt_reconstructed"),
        )
    except Exception:
        pass  # ground-truth rendering is best-effort
    has_original_mxl = bool(gt_orig_pages)
    # Default: show original MXL if available, otherwise features-reconstructed
    gt_pages = gt_orig_pages if has_original_mxl else gt_recon_pages
    info_md = (
        f"**Source**: {source} &nbsp;|&nbsp; "
        f"**Notes**: {num_notes} &nbsp;|&nbsp; "
        f"**Model**: {model_name}"
    )
    if is_mxl and graph_data.get("truncated", False):
        info_md += (
            f" &nbsp;|&nbsp; **Truncated** to {num_notes} notes "
            f"({graph_data['total_bars']} bars) to fit model limit"
        )
    # Page count = max of the *active* GT variant and the reconstructed pages
    num_pages = max(len(gt_pages), len(recon_pages), 1)
    gt_first = gt_pages[0] if gt_pages else None
    recon_first = recon_pages[0] if recon_pages else None
    return (
        gt_first, recon_first,
        gt_orig_pages, gt_recon_pages, recon_pages,
        has_original_mxl, num_pages,
        mxl_path, accuracy_md, info_md,
    )
# =============================================================================
# Gradio UI
# =============================================================================
def build_demo():
    """Assemble and return the Gradio Blocks demo (layout + event wiring)."""
    with gr.Blocks(
        title="A Fixed-size Latent Space Autoencoder for Music Scores",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            "# A Fixed-size Latent Space Autoencoder for Music Scores\n"
            "Upload a MusicXML file (`.mxl` / `.musicxml`) or select an example from [PDMX](https://zenodo.org/records/15571083), "
            "choose a model, and reconstruct a MusicXML file. Scores are rendered "
            "with [Verovio](https://www.verovio.org).\n\n"
            "*Companion demo for: A Fixed-size Latent Space Autoencoder for Music Scores "
            "(Hendrik Roth, Emmanouil Karystinaios & Gerhard Widmer, JKU Linz) — "
            "[GitHub](https://github.com/hendrik-roth/score-ae)*"
        )
        # Hidden per-session state for the rendered page lists (filled by on_run)
        gt_orig_pages_state = gr.State([])
        gt_recon_pages_state = gr.State([])
        recon_pages_state = gr.State([])
        has_orig_mxl_state = gr.State(False)
        with gr.Row(equal_height=False):
            # ---- Left column: controls ----
            with gr.Column(scale=1, min_width=280):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Model",
                )
                graph_input = gr.File(
                    label="Upload MusicXML (.mxl / .musicxml / .xml)",
                    file_types=[".mxl", ".musicxml", ".xml", ".pt"],
                )
                example_dropdown = gr.Dropdown(
                    choices=["(none)"] + list(_EXAMPLE_DISPLAY.keys()),
                    value="(none)",
                    label="Or pick an example from our PDMX test set (deduplicated, no licence conflict subset)",
                )
                run_btn = gr.Button("Reconstruct", variant="primary", size="lg")
                info_output = gr.Markdown()
            # ---- Right column: download, scores, accuracy ----
            with gr.Column(scale=3):
                with gr.Row():
                    mxl_output = gr.File(label="Download reconstructed MusicXML", scale=1, min_width=200)
                    gr.Column(scale=2)  # spacer
                with gr.Row():
                    page_slider = gr.Slider(
                        minimum=1, maximum=1, step=1, value=1,
                        label="Page", visible=False, scale=1,
                    )
                    gr.Column(scale=2)  # spacer
                gr.Markdown(
                    "*Note: The **Ground Truth** column shows the score rendered from "
                    "our discrete feature representation. When the original MusicXML is "
                    "available, tick \"Raw MXL engraving\" to see the original "
                    "publisher engraving instead.*",
                )
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            gr.Markdown("#### Ground Truth")
                            show_orig_mxl = gr.Checkbox(
                                label="Raw MXL engraving",
                                value=True,
                                visible=False,
                                scale=0,
                            )
                        gt_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("#### Reconstructed")
                        recon_image = gr.Image(
                            type="filepath",
                            show_label=False,
                        )
                gr.Markdown("#### Note-level accuracy")
                accuracy_output = gr.Markdown()
        # -- Callbacks --
        def on_run(graph_file, example_name, model_name):
            # An explicit upload takes precedence over the example dropdown.
            if graph_file is None and example_name != "(none)":
                # Resolve display name back to filename
                filename = _EXAMPLE_DISPLAY.get(example_name, example_name)
                graph_file = str(EXAMPLES_DIR / filename)
            (
                gt_first, recon_first,
                gt_orig_pgs, gt_recon_pgs, recon_pgs,
                has_orig, num_pages,
                mxl, acc_md, info_md,
            ) = run_reconstruction(graph_file, model_name)
            # Page slider is only shown for multi-page scores.
            slider_update = gr.Slider(
                minimum=1, maximum=num_pages, step=1, value=1,
                label="Page", visible=(num_pages > 1),
            )
            # Checkbox is only shown when an original MXL engraving exists.
            checkbox_update = gr.Checkbox(
                label="Raw MXL engraving",
                value=True,
                visible=has_orig,
            )
            return (
                gt_first, recon_first,
                gt_orig_pgs, gt_recon_pgs, recon_pgs,
                has_orig,
                slider_update, checkbox_update,
                mxl, acc_md, info_md,
            )
        def on_page_change(page_num, show_orig, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            # Show the selected page of the active GT variant + reconstruction.
            idx = int(page_num) - 1
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            return gt_img, recon_img
        def on_toggle_orig(show_orig, page_num, gt_orig_pgs, gt_recon_pgs, recon_pgs, has_orig):
            # Switching GT variant can change the page count, so re-clamp the
            # slider as well as swapping images.
            gt_pgs = gt_orig_pgs if (show_orig and has_orig) else gt_recon_pgs
            # Recompute page count from both active GT and reconstructed
            num_pages = max(len(gt_pgs), len(recon_pgs), 1)
            clamped_page = min(int(page_num), num_pages)
            idx = clamped_page - 1
            gt_img = gt_pgs[idx] if idx < len(gt_pgs) else None
            recon_img = recon_pgs[idx] if idx < len(recon_pgs) else None
            slider_update = gr.Slider(
                minimum=1, maximum=num_pages,
                step=1, value=clamped_page,
                label="Page", visible=(num_pages > 1),
            )
            return gt_img, recon_img, slider_update
        run_btn.click(
            fn=on_run,
            inputs=[graph_input, example_dropdown, model_dropdown],
            outputs=[
                gt_image, recon_image,
                gt_orig_pages_state, gt_recon_pages_state, recon_pages_state,
                has_orig_mxl_state,
                page_slider, show_orig_mxl,
                mxl_output, accuracy_output, info_output,
            ],
        )
        page_slider.change(
            fn=on_page_change,
            inputs=[page_slider, show_orig_mxl, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image],
        )
        show_orig_mxl.change(
            fn=on_toggle_orig,
            inputs=[show_orig_mxl, page_slider, gt_orig_pages_state, gt_recon_pages_state, recon_pages_state, has_orig_mxl_state],
            outputs=[gt_image, recon_image, page_slider],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and serve it with Gradio's default launch settings.
    build_demo().launch()