"""Audio encoding and iterative unmasking inference.
Adapted from midmid/prediction/model.py for standalone use.
Device management is caller-controlled (for ZeroGPU compatibility).
"""
import itertools as _it
import json
import math
from typing import Optional
import numpy as np
import torch
from midmid.nn import (
ChartMaskPredictor, ChartMaskPredictorConfig,
MASK_TOKEN, SILENCE_TOKEN,
)
from midmid.datatypes import NoteEvent
MERT_MODEL_ID = "m-a-p/MERT-v1-95M"
DIFF_ID = {"easy": 0, "medium": 1, "hard": 2, "expert": 3}
# Class ID -> fret tuple
_CLASS_TO_FRETS: list[tuple[int, ...]] = []
for _r in range(1, 6):
_CLASS_TO_FRETS.extend(_it.combinations(range(5), _r))
_CLASS_TO_FRETS.append((7,)) # class 31 = open
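# Sanity check: the 2**5 - 1 = 31 non-empty fret combinations occupy IDs 0-30,
# and the open-note class takes ID 31, for 32 classes in total.
assert len(_CLASS_TO_FRETS) == 32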
# Sustain bucket center values in beats
_BUCKET_BEATS = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0]
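# duration_logits pick one of these buckets; predict_notes scales the chosen
# bucket by the local beat length (ms) to get a sustain duration.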
# ---------------------------------------------------------------------------
# Model loading (safetensors from HF Hub)
# ---------------------------------------------------------------------------
def load_model_from_hub(
repo_id: str = "markury/midmid3-19m-0326",
device: str = "cpu",
) -> ChartMaskPredictor:
"""Download and load model from HuggingFace Hub (safetensors)."""
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
config_path = hf_hub_download(repo_id, "config.json")
weights_path = hf_hub_download(repo_id, "model.safetensors")
with open(config_path) as f:
config_dict = json.load(f)
config = ChartMaskPredictorConfig(**config_dict)
model = ChartMaskPredictor(config)
state_dict = load_file(weights_path, device=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
# ---------------------------------------------------------------------------
# MERT audio encoding (lazy-loaded)
# ---------------------------------------------------------------------------
_mert_model = None
_mert_processor = None
_mert_frame_rate = None
def _ensure_mert(device: torch.device):
"""Load MERT model and processor on first use."""
global _mert_model, _mert_processor, _mert_frame_rate
if _mert_model is not None:
# Move to correct device if needed
if next(_mert_model.parameters()).device != device:
_mert_model.to(device)
return
from transformers import AutoModel, Wav2Vec2FeatureExtractor
print(f"Loading MERT ({MERT_MODEL_ID}) ...")
_mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
MERT_MODEL_ID, trust_remote_code=True,
)
_mert_model = AutoModel.from_pretrained(MERT_MODEL_ID, trust_remote_code=True)
_mert_model.to(device)
_mert_model.eval()
# Compute frame rate dynamically
sr = _mert_processor.sampling_rate
test_wav = np.zeros(sr, dtype=np.float32)
inputs = _mert_processor(test_wav, sampling_rate=sr, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
out = _mert_model(**inputs, output_hidden_states=False)
_mert_frame_rate = float(out.last_hidden_state.shape[1])
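    # Measured from 1 s of silence rather than hard-coded; for MERT-v1-95M
    # at its 24 kHz input rate this is nominally ~75 Hz.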
print(f" MERT frame rate: {_mert_frame_rate:.2f} Hz")
def move_models_to_device(device: torch.device):
"""Move all cached models to the specified device (for ZeroGPU)."""
global _mert_model
if _mert_model is not None:
_mert_model.to(device)
@torch.no_grad()
def encode_audio_mert(
audio_path: str,
device: torch.device,
chunk_sec: float = 60.0,
) -> tuple[torch.Tensor, float]:
"""Encode audio with MERT, return (embeddings, frame_rate)."""
import librosa
_ensure_mert(device)
sr = _mert_processor.sampling_rate
wav, _ = librosa.load(audio_path, sr=sr, mono=True)
chunk_samples = int(chunk_sec * sr)
overlap_sec = 5.0
overlap_samples = int(overlap_sec * sr)
stride_samples = chunk_samples - overlap_samples
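    # Consecutive chunks overlap by `overlap_sec`; each chunk's embeddings are
    # trimmed by half the overlap at every interior edge below, so the
    # concatenated sequence stays contiguous without double-counting frames.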
if len(wav) <= chunk_samples:
inputs = _mert_processor(wav, sampling_rate=sr, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
out = _mert_model(**inputs, output_hidden_states=False)
return out.last_hidden_state.squeeze(0).cpu(), _mert_frame_rate
# Chunked processing for long audio
all_emb = []
pos = 0
idx = 0
while pos < len(wav):
end = min(pos + chunk_samples, len(wav))
chunk = wav[pos:end]
        # Pad very short tail chunks to a minimum length before encoding.
        min_len = chunk_samples // 4
        if len(chunk) < min_len:
            chunk = np.pad(chunk, (0, min_len - len(chunk)))
inputs = _mert_processor(chunk, sampling_rate=sr, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
out = _mert_model(**inputs, output_hidden_states=False)
emb = out.last_hidden_state.squeeze(0)
n = emb.shape[0]
fps = n / (len(chunk) / sr)
half_overlap = int(round((overlap_sec / 2) * fps))
if idx == 0:
keep = n - half_overlap if end < len(wav) else n
all_emb.append(emb[:keep].cpu())
elif end >= len(wav):
all_emb.append(emb[half_overlap:].cpu())
else:
keep = int(round((len(chunk) / sr - overlap_sec) * fps))
all_emb.append(emb[half_overlap:half_overlap + keep].cpu())
pos += stride_samples
idx += 1
return torch.cat(all_emb, dim=0), _mert_frame_rate
# ---------------------------------------------------------------------------
# Grid helpers
# ---------------------------------------------------------------------------
def _build_16th_grid(fretbars):
"""Build 16th-note timestamps (ms) from beat positions."""
if len(fretbars) < 2:
return list(fretbars)
positions = []
for i in range(len(fretbars) - 1):
start = fretbars[i]
interval = fretbars[i + 1] - start
for sub in range(4):
positions.append(start + sub * interval / 4.0)
positions.append(fretbars[-1])
return positions
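# Worked example for _build_16th_grid: fretbars [0, 500, 1000] (ms) expand to
# [0, 125, 250, 375, 500, 625, 750, 875, 1000]: four 16ths per beat interval,
# plus the final beat itself.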
def _get_local_beat_ms(grid_idx, fretbars):
    """Return the local beat duration in ms for a 16th-note grid index."""
    beat_idx = min(grid_idx // 4, len(fretbars) - 2)
    beat_idx = max(0, beat_idx)
    if beat_idx + 1 < len(fretbars):
        return fretbars[beat_idx + 1] - fretbars[beat_idx]
    return 500.0  # fallback: 500 ms per beat, i.e. 120 BPM
# ---------------------------------------------------------------------------
# Main inference
# ---------------------------------------------------------------------------
@torch.no_grad()
def predict_notes(
audio_path: str,
model: ChartMaskPredictor,
beat_times: list[float],
difficulty: str = "expert",
    device: Optional[torch.device] = None,
num_steps: int = 12,
temperature: float = 0.9,
) -> list[NoteEvent]:
"""MaskGIT-style iterative unmasking inference."""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dev = device
model.to(dev)
model.eval()
fretbars = [t * 1000.0 for t in beat_times]
if len(fretbars) < 2:
return []
# MERT embeddings
embeddings, frame_rate = encode_audio_mert(audio_path, dev)
# Build grid and sample MERT frames with windowing
grid_times = _build_16th_grid(fretbars)
num_positions = len(grid_times)
max_frame = embeddings.shape[0] - 1
frame_indices = torch.tensor(
[min(int(round(t / 1000.0 * frame_rate)), max_frame)
for t in grid_times], dtype=torch.long,
)
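    # Average a +/-`window` frame neighborhood around each sampled MERT frame
    # to smooth over beat-grid/frame-rate misalignment; replicate padding
    # keeps the window valid at the sequence edges.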
window = 2
if window > 0 and max_frame >= window * 2:
padded = torch.nn.functional.pad(
embeddings.unsqueeze(0), (0, 0, window, window), mode="replicate",
).squeeze(0)
shifted = frame_indices + window
stacked = torch.stack(
[padded[shifted + d] for d in range(-window, window + 1)], dim=0,
)
grid_emb = stacked.mean(dim=0)
else:
grid_emb = embeddings[frame_indices]
# Compute and concat audio features if model expects them
if model.config.audio_dim > grid_emb.shape[-1]:
import librosa as _lr
wav, _ = _lr.load(audio_path, sr=24000, mono=True)
hop = 320
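        # 24000 / 320 = 75 feature frames per second, matching MERT's nominal
        # frame rate, so these features index cleanly onto the same grid.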
onset = _lr.onset.onset_strength(y=wav, sr=24000, hop_length=hop)
rms_arr = _lr.feature.rms(y=wav, hop_length=hop)[0]
centroid = _lr.feature.spectral_centroid(y=wav, sr=24000, hop_length=hop)[0]
def _norm(x):
mn, mx = x.min(), x.max()
return (x - mn) / max(mx - mn, 1e-8)
onset, rms_arr, centroid = _norm(onset), _norm(rms_arr), _norm(centroid)
af_rate = 24000 / hop
af_max = len(onset) - 1
af_indices = [min(int(round(t / 1000.0 * af_rate)), af_max) for t in grid_times]
af_tensor = torch.tensor(
[[onset[i], rms_arr[i], centroid[i]] for i in af_indices],
dtype=torch.float32,
)
grid_emb = torch.cat([grid_emb, af_tensor], dim=-1)
audio_features = grid_emb.unsqueeze(0).to(dev)
diff_id = DIFF_ID.get(difficulty, 3)
diff_tensor = torch.tensor([diff_id], dtype=torch.long, device=dev)
padding_mask = torch.ones(1, num_positions, dtype=torch.bool, device=dev)
# Start fully masked
chart_tokens = torch.full(
(1, num_positions), MASK_TOKEN, dtype=torch.long, device=dev,
)
# Cosine unmasking schedule
schedule = []
for step in range(num_steps):
r_prev = math.cos(math.pi / 2 * step / num_steps)
r_next = math.cos(math.pi / 2 * (step + 1) / num_steps)
n_unmask = max(1, int((r_prev - r_next) * num_positions))
schedule.append(n_unmask)
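    # Few tokens are revealed early and many late: with num_steps=12 the first
    # step unmasks ~0.9% of positions and the last ~13%.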
# Iterative unmasking
for step in range(num_steps):
outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
token_logits = outputs["token_logits"].squeeze(0)
is_masked = (chart_tokens.squeeze(0) == MASK_TOKEN)
masked_indices = is_masked.nonzero(as_tuple=True)[0]
if len(masked_indices) == 0:
break
probs = torch.softmax(token_logits / temperature, dim=-1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
n_unmask = min(schedule[step], len(masked_indices))
perm = torch.randperm(len(masked_indices), device=dev)
unmask_idx = masked_indices[perm[:n_unmask]]
chart_tokens[0, unmask_idx] = sampled[unmask_idx]
# Final pass for sustain predictions
outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
sustain_prob = outputs["sustain_logits"].squeeze(0).squeeze(-1).sigmoid()
dur_pred = outputs["duration_logits"].squeeze(0).argmax(dim=-1)
# Convert tokens to NoteEvents
tokens = chart_tokens.squeeze(0).cpu()
notes = []
for i in range(num_positions):
tok = tokens[i].item()
if tok >= SILENCE_TOKEN or tok < 0:
continue
fret_set = set(_CLASS_TO_FRETS[tok])
if not fret_set:
continue
        sustain_ticks = 0
        if sustain_prob[i] >= 0.5:
            bucket = dur_pred[i].item()
            beat_ms = _get_local_beat_ms(i, fretbars)
            # Despite the field name, this stores a duration in milliseconds:
            # the bucket length in beats times the local beat length.
            sustain_ticks = _BUCKET_BEATS[bucket] * beat_ms
        notes.append(NoteEvent(
            tick=i,  # 16th-note grid position index
            fret_set=fret_set,
            sustain_ticks=sustain_ticks,
        ))
return notes
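

if __name__ == "__main__":
    # Minimal command-line smoke test. This is a sketch, not part of the
    # original module: it assumes librosa's beat tracker is an acceptable
    # source of beat_times, and that the default Hub checkpoint is wanted.
    import sys

    import librosa

    audio = sys.argv[1]
    wav, sr = librosa.load(audio, sr=None, mono=True)
    _tempo, beats = librosa.beat.beat_track(y=wav, sr=sr, units="time")
    mdl = load_model_from_hub()
    predicted = predict_notes(audio, mdl, beat_times=[float(t) for t in beats])
    print(f"Predicted {len(predicted)} notes from {len(beats)} beats")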