| """Audio encoding and iterative unmasking inference. |
| |
| Adapted from midmid/prediction/model.py for standalone use. |
| Device management is caller-controlled (for ZeroGPU compatibility). |
| """ |

import itertools as _it
import json
import math
from typing import Optional

import numpy as np
import torch

from midmid.nn import (
    ChartMaskPredictor, ChartMaskPredictorConfig,
    MASK_TOKEN, SILENCE_TOKEN,
)
from midmid.datatypes import NoteEvent

MERT_MODEL_ID = "m-a-p/MERT-v1-95M"

DIFF_ID = {"easy": 0, "medium": 1, "hard": 2, "expert": 3}

# Token class -> fret combination: every non-empty subset of the five fret
# lanes (0-4), plus (7,), presumably the open note (lane 7 in .chart files).
_CLASS_TO_FRETS: list[tuple[int, ...]] = []
for _r in range(1, 6):
    _CLASS_TO_FRETS.extend(_it.combinations(range(5), _r))
_CLASS_TO_FRETS.append((7,))
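# A worked example of the mapping (31 combinations + open note = 32 classes;
# the indices follow directly from the loop above):
#   class 0  -> (0,)             single green note
#   class 5  -> (0, 1)           green+red chord
#   class 30 -> (0, 1, 2, 3, 4)  full five-lane chord
#   class 31 -> (7,)             open note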

# Sustain-duration buckets, in beats; bucket 0 decodes to zero sustain.
_BUCKET_BEATS = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0]
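# For example, at 120 BPM (500 ms per beat) bucket 2 decodes to
# 2.0 beats * 500 ms = 1000 ms of sustain (see predict_notes below).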


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_model_from_hub(
    repo_id: str = "markury/midmid3-19m-0326",
    device: str = "cpu",
) -> ChartMaskPredictor:
    """Download and load model from HuggingFace Hub (safetensors)."""
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    config_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "model.safetensors")

    with open(config_path) as f:
        config_dict = json.load(f)

    config = ChartMaskPredictorConfig(**config_dict)
    model = ChartMaskPredictor(config)

    state_dict = load_file(weights_path, device=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
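
# Usage sketch (device strings are standard torch identifiers):
#
#     model = load_model_from_hub()                  # default repo, CPU
#     model = load_model_from_hub(device="cuda")     # GPU, if available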


# ---------------------------------------------------------------------------
# MERT audio encoding
# ---------------------------------------------------------------------------

# Lazily initialised cache for the MERT model, processor, and frame rate.
_mert_model = None
_mert_processor = None
_mert_frame_rate = None


def _ensure_mert(device: torch.device):
    """Load MERT model and processor on first use."""
    global _mert_model, _mert_processor, _mert_frame_rate
    if _mert_model is not None:
        # Already loaded; just make sure it is on the requested device.
        if next(_mert_model.parameters()).device != device:
            _mert_model.to(device)
        return

    from transformers import AutoModel, Wav2Vec2FeatureExtractor

    print(f"Loading MERT ({MERT_MODEL_ID}) ...")
    _mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
        MERT_MODEL_ID, trust_remote_code=True,
    )
    _mert_model = AutoModel.from_pretrained(MERT_MODEL_ID, trust_remote_code=True)
    _mert_model.to(device)
    _mert_model.eval()

    # Probe the frame rate empirically: encode one second of silence and
    # count the frames that come back.
    sr = _mert_processor.sampling_rate
    test_wav = np.zeros(sr, dtype=np.float32)
    inputs = _mert_processor(test_wav, sampling_rate=sr, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        out = _mert_model(**inputs, output_hidden_states=False)
    _mert_frame_rate = float(out.last_hidden_state.shape[1])
    print(f" MERT frame rate: {_mert_frame_rate:.2f} Hz")


def move_models_to_device(device: torch.device):
    """Move all cached models to the specified device (for ZeroGPU)."""
    global _mert_model
    if _mert_model is not None:
        _mert_model.to(device)
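
# ZeroGPU attaches a GPU only inside the decorated handler, so a sketch of
# the intended call pattern (assuming the standard `spaces` decorator) is:
#
#     @spaces.GPU
#     def generate(audio_path):
#         move_models_to_device(torch.device("cuda"))
#         ...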


@torch.no_grad()
def encode_audio_mert(
    audio_path: str,
    device: torch.device,
    chunk_sec: float = 60.0,
) -> tuple[torch.Tensor, float]:
    """Encode audio with MERT, return (embeddings, frame_rate)."""
    import librosa
    _ensure_mert(device)

    sr = _mert_processor.sampling_rate
    wav, _ = librosa.load(audio_path, sr=sr, mono=True)

    chunk_samples = int(chunk_sec * sr)
    overlap_sec = 5.0
    overlap_samples = int(overlap_sec * sr)
    stride_samples = chunk_samples - overlap_samples

    # Short audio fits in a single forward pass.
    if len(wav) <= chunk_samples:
        inputs = _mert_processor(wav, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        return out.last_hidden_state.squeeze(0).cpu(), _mert_frame_rate

    # Long audio: encode overlapping chunks and stitch them together,
    # dropping half of the overlap on each side of every seam.
    all_emb = []
    pos = 0
    idx = 0
    while pos < len(wav):
        end = min(pos + chunk_samples, len(wav))
        chunk = wav[pos:end]
        min_len = chunk_samples // 4
        if len(chunk) < min_len:
            chunk = np.pad(chunk, (0, min_len - len(chunk)))

        inputs = _mert_processor(chunk, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        emb = out.last_hidden_state.squeeze(0)

        n = emb.shape[0]
        fps = n / (len(chunk) / sr)
        half_overlap = int(round((overlap_sec / 2) * fps))

        if idx == 0:
            # First chunk: keep everything up to the seam.
            keep = n - half_overlap if end < len(wav) else n
            all_emb.append(emb[:keep].cpu())
        elif end >= len(wav):
            # Last chunk: keep everything after the seam.
            all_emb.append(emb[half_overlap:].cpu())
        else:
            # Middle chunk: trim half the overlap on both sides.
            keep = int(round((len(chunk) / sr - overlap_sec) * fps))
            all_emb.append(emb[half_overlap:half_overlap + keep].cpu())

        pos += stride_samples
        idx += 1

    return torch.cat(all_emb, dim=0), _mert_frame_rate
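
# Usage sketch ("song.ogg" is a hypothetical path):
#
#     emb, fps = encode_audio_mert("song.ogg", torch.device("cpu"))
#     # emb: (num_frames, hidden_dim) tensor on the CPU; fps: frames per
#     # second as probed in _ensure_mert (roughly 75 Hz for MERT-v1 models)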


# ---------------------------------------------------------------------------
# Beat-grid helpers
# ---------------------------------------------------------------------------

def _build_16th_grid(fretbars):
    """Build 16th-note timestamps (ms) from beat positions."""
    if len(fretbars) < 2:
        return list(fretbars)
    positions = []
    for i in range(len(fretbars) - 1):
        start = fretbars[i]
        interval = fretbars[i + 1] - start
        for sub in range(4):
            positions.append(start + sub * interval / 4.0)
    positions.append(fretbars[-1])
    return positions
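
# Worked example: beats at [0, 500, 1000] ms produce
# [0, 125, 250, 375, 500, 625, 750, 875, 1000]: four subdivisions per beat,
# plus the final beat itself.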


def _get_local_beat_ms(grid_idx, fretbars):
    """Length in ms of the beat containing the given 16th-grid index."""
    beat_idx = min(grid_idx // 4, len(fretbars) - 2)
    beat_idx = max(0, beat_idx)
    if beat_idx + 1 < len(fretbars):
        return fretbars[beat_idx + 1] - fretbars[beat_idx]
    return 500.0  # fallback: 500 ms per beat, i.e. 120 BPM
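
# For example, grid_idx 6 lies in beat 1 (6 // 4 == 1), so the local beat
# length is fretbars[2] - fretbars[1].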


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

@torch.no_grad()
def predict_notes(
    audio_path: str,
    model: ChartMaskPredictor,
    beat_times: list[float],
    difficulty: str = "expert",
    device: Optional[torch.device] = None,
    num_steps: int = 12,
    temperature: float = 0.9,
) -> list[NoteEvent]:
    """MaskGIT-style iterative unmasking inference."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dev = device
    model.to(dev)
    model.eval()

    # Beat times arrive in seconds; the grid math below works in ms.
    fretbars = [t * 1000.0 for t in beat_times]
    if len(fretbars) < 2:
        return []

    # Encode the full track once; every grid position indexes into this.
    embeddings, frame_rate = encode_audio_mert(audio_path, dev)

    # Build the 16th-note grid and map each position to its nearest frame.
    grid_times = _build_16th_grid(fretbars)
    num_positions = len(grid_times)
    max_frame = embeddings.shape[0] - 1
    frame_indices = torch.tensor(
        [min(int(round(t / 1000.0 * frame_rate)), max_frame)
         for t in grid_times], dtype=torch.long,
    )

    # Average a +/-2-frame window around each grid position for a more
    # stable local representation.
    window = 2
    if window > 0 and max_frame >= window * 2:
        padded = torch.nn.functional.pad(
            embeddings.unsqueeze(0), (0, 0, window, window), mode="replicate",
        ).squeeze(0)
        shifted = frame_indices + window
        stacked = torch.stack(
            [padded[shifted + d] for d in range(-window, window + 1)], dim=0,
        )
        grid_emb = stacked.mean(dim=0)
    else:
        grid_emb = embeddings[frame_indices]

    # If the model expects more channels than MERT provides, append cheap
    # librosa features (onset strength, RMS, spectral centroid).
    if model.config.audio_dim > grid_emb.shape[-1]:
        import librosa as _lr
        wav, _ = _lr.load(audio_path, sr=24000, mono=True)
        hop = 320  # 24000 / 320 = 75 frames per second
        onset = _lr.onset.onset_strength(y=wav, sr=24000, hop_length=hop)
        rms_arr = _lr.feature.rms(y=wav, hop_length=hop)[0]
        centroid = _lr.feature.spectral_centroid(y=wav, sr=24000, hop_length=hop)[0]

        def _norm(x):
            # Min-max normalise to [0, 1] with an epsilon for flat signals.
            mn, mx = x.min(), x.max()
            return (x - mn) / max(mx - mn, 1e-8)

        onset, rms_arr, centroid = _norm(onset), _norm(rms_arr), _norm(centroid)
        af_rate = 24000 / hop
        af_max = len(onset) - 1
        af_indices = [min(int(round(t / 1000.0 * af_rate)), af_max) for t in grid_times]
        af_tensor = torch.tensor(
            [[onset[i], rms_arr[i], centroid[i]] for i in af_indices],
            dtype=torch.float32,
        )
        grid_emb = torch.cat([grid_emb, af_tensor], dim=-1)

    audio_features = grid_emb.unsqueeze(0).to(dev)

    diff_id = DIFF_ID.get(difficulty, 3)
    diff_tensor = torch.tensor([diff_id], dtype=torch.long, device=dev)
    padding_mask = torch.ones(1, num_positions, dtype=torch.bool, device=dev)

    # Start from a fully masked chart.
    chart_tokens = torch.full(
        (1, num_positions), MASK_TOKEN, dtype=torch.long, device=dev,
    )

    # Cosine unmasking schedule: commit only a few tokens in the early,
    # high-uncertainty steps and progressively more later on.
    schedule = []
    for step in range(num_steps):
        r_prev = math.cos(math.pi / 2 * step / num_steps)
        r_next = math.cos(math.pi / 2 * (step + 1) / num_steps)
        n_unmask = max(1, int((r_prev - r_next) * num_positions))
        schedule.append(n_unmask)
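    # For instance, num_steps=4 over 100 positions yields [7, 21, 32, 38].
    # int() truncation can leave a few positions masked after the loop; those
    # are skipped by the `tok >= SILENCE_TOKEN` check during decoding below
    # (assuming MASK_TOKEN sits above SILENCE_TOKEN in the vocabulary).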

    # Iterative unmasking: sample every position each step, but commit only
    # a random subset of the still-masked positions.
    for step in range(num_steps):
        outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
        token_logits = outputs["token_logits"].squeeze(0)

        is_masked = (chart_tokens.squeeze(0) == MASK_TOKEN)
        masked_indices = is_masked.nonzero(as_tuple=True)[0]

        if len(masked_indices) == 0:
            break

        probs = torch.softmax(token_logits / temperature, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        n_unmask = min(schedule[step], len(masked_indices))
        perm = torch.randperm(len(masked_indices), device=dev)
        unmask_idx = masked_indices[perm[:n_unmask]]
        chart_tokens[0, unmask_idx] = sampled[unmask_idx]

    # One final pass over the completed chart for the sustain/duration heads.
    outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
    sustain_prob = outputs["sustain_logits"].squeeze(0).squeeze(-1).sigmoid()
    dur_pred = outputs["duration_logits"].squeeze(0).argmax(dim=-1)

    # Decode tokens into NoteEvents on the 16th-note grid.
    tokens = chart_tokens.squeeze(0).cpu()
    notes = []
    for i in range(num_positions):
        tok = tokens[i].item()
        if tok >= SILENCE_TOKEN or tok < 0:
            continue

        fret_set = set(_CLASS_TO_FRETS[tok])
        if not fret_set:
            continue

        # Attach a sustain when the head is confident; the bucket gives a
        # length in beats, scaled by the local beat length in ms.
        sustain_ticks = 0
        if sustain_prob[i] >= 0.5:
            bucket = dur_pred[i].item()
            beat_ms = _get_local_beat_ms(i, fretbars)
            sustain_ticks = _BUCKET_BEATS[bucket] * beat_ms

        notes.append(NoteEvent(
            tick=i,
            fret_set=fret_set,
            sustain_ticks=sustain_ticks,
        ))

    return notes
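
# End-to-end usage sketch ("song.ogg" and the beat list are hypothetical;
# beat times in seconds would normally come from a beat tracker):
#
#     model = load_model_from_hub()
#     beats = [i * 0.5 for i in range(240)]  # two minutes at 120 BPM
#     notes = predict_notes("song.ogg", model, beats, difficulty="expert")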