piano-eval / models /inference.py
Jai-D's picture
Upload folder using huggingface_hub
bfc6d2a verified
"""A1-Max MuQ inference - MuQ embedding extraction and prediction."""
import numpy as np
import torch
from constants import MODEL_CONFIG
from models.loader import ModelCache
@torch.no_grad()
def extract_muq_embeddings(
    audio: np.ndarray,
    cache: ModelCache,
    layer_start: int = None,
    layer_end: int = None,
    max_frames: int = None,
) -> torch.Tensor:
    """Extract frame-level MuQ embeddings from an audio waveform.

    Runs the MuQ model held by ``cache`` on the waveform, averages the
    hidden states of the selected layer range element-wise, and truncates
    the result to at most ``max_frames`` frames.

    Args:
        audio: Mono audio waveform as a 1-D array of samples (the model is
            expected to receive 24 kHz audio).
        cache: Model cache exposing ``device`` and a loaded ``muq_model``.
        layer_start: First index (inclusive) into the model's
            ``hidden_states`` tuple. ``None`` selects
            ``MODEL_CONFIG["muq_layer_start"]``. An explicit ``0`` is honored.
        layer_end: End index (exclusive) into the ``hidden_states`` tuple.
            ``None`` selects ``MODEL_CONFIG["muq_layer_end"]``.
        max_frames: Maximum number of leading frames to keep. ``None``
            selects ``MODEL_CONFIG["max_frames"]``.

    Returns:
        Embeddings tensor of shape ``[T, D]`` where ``T`` is the number of
        frames (after truncation) and ``D`` is the model's hidden size.
    """
    # Compare against None explicitly: ``value or default`` would silently
    # replace a legitimate 0 (e.g. layer_start=0) with the configured default.
    if layer_start is None:
        layer_start = MODEL_CONFIG["muq_layer_start"]
    if layer_end is None:
        layer_end = MODEL_CONFIG["muq_layer_end"]
    if max_frames is None:
        max_frames = MODEL_CONFIG["max_frames"]

    # MuQ expects a batched [B, samples] tensor; add the batch dimension.
    wavs = torch.tensor(audio).unsqueeze(0).to(cache.device)

    # Request the hidden states of every layer so a range can be selected.
    outputs = cache.muq_model(wavs, output_hidden_states=True)

    # hidden_states is a tuple of [B, T, D] tensors; slice the requested
    # layer range, average across layers, then drop the batch dimension.
    hidden_states = outputs.hidden_states[layer_start:layer_end]
    embeddings = torch.stack(hidden_states, dim=0).mean(dim=0).squeeze(0)

    # Keep only the leading frames when the clip exceeds the frame budget.
    if embeddings.shape[0] > max_frames:
        embeddings = embeddings[:max_frames]
    return embeddings
@torch.no_grad()
def predict_with_ensemble(
    embeddings: torch.Tensor,
    cache: ModelCache,
) -> np.ndarray:
    """Average the predictions of every A1-Max head stored in the cache.

    Each head in ``cache.muq_heads`` is called on the same frame-level
    embeddings; its output is moved to the CPU and converted to NumPy, and
    the per-head outputs are averaged element-wise.

    Args:
        embeddings: Frame embeddings from MuQ, passed unchanged to each head.
        cache: Model cache whose ``muq_heads`` holds the loaded fold heads.

    Returns:
        Element-wise mean of the head outputs (one score vector).

    Raises:
        RuntimeError: If the cache holds no heads.
    """
    fold_heads = cache.muq_heads
    if not fold_heads:
        raise RuntimeError("No A1-Max heads loaded in cache")

    # One forward pass per fold head, each result brought back to NumPy.
    fold_outputs = [fold_head(embeddings).cpu().numpy() for fold_head in fold_heads]

    # Ensemble by averaging across the fold axis.
    return np.mean(fold_outputs, axis=0)