"""HuggingFace Inference Endpoints handler for piano performance analysis.
A1-Max MuQ LoRA model using MuQ layers 9-12 with attention pooling.
Returns 6-dimension performance evaluation scores:
dynamics, timing, pedaling, articulation, phrasing, interpretation.
Compatible with HuggingFace Inference Endpoints custom handler pattern.
"""
import base64
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Union
import numpy as np
from constants import MODEL_INFO, PERCEPIANO_DIMENSIONS
from models.loader import get_model_cache
from models.inference import (
extract_muq_embeddings,
predict_with_ensemble,
)
from models.transcription import TranscriptionModel, TranscriptionError
from preprocessing.audio import (
AudioDownloadError,
AudioProcessingError,
download_and_preprocess_audio,
preprocess_audio_from_bytes,
)
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for piano performance analysis.

    Loads the MuQ backbone, the A1-Max ensemble prediction heads, and the
    ByteDance AMT transcription model once at container start, then serves
    per-request audio scoring across the six PercePiano dimensions.
    """

    def __init__(self, path: str = ""):
        """Initialize MuQ model and prediction heads.

        Called once when the endpoint container starts.

        Args:
            path: Path to the model repository (provided by HF Inference
                Endpoints). Contains the checkpoints/ directory with model
                weights.
        """
        print(f"Initializing A1-Max EndpointHandler with path: {path}")

        # Determine checkpoint directory.
        # HF Inference Endpoints mount the repo at the provided path; fall
        # back to /repository (HF default) or the current dir for local
        # testing.
        if path:
            model_path = Path(path)
        else:
            model_path = Path("/repository")
            if not model_path.exists():
                model_path = Path(".")

        checkpoint_dir = model_path / "checkpoints"
        if not checkpoint_dir.exists():
            # Try /app/checkpoints for backward compatibility with the old
            # container layout.
            checkpoint_dir = Path("/app/checkpoints")
        print(f"Using checkpoint directory: {checkpoint_dir}")

        # Initialize model cache (loads MuQ and the ensemble prediction heads).
        self._cache = get_model_cache()
        self._cache.initialize(device="cuda", checkpoint_dir=checkpoint_dir)

        # Initialize AMT transcription model.
        print("Loading ByteDance AMT model...")
        self._transcription = TranscriptionModel(device="cuda")
        print("A1-Max EndpointHandler initialization complete!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process an inference request.

        Args:
            data: Request payload. Supports two formats:

                HuggingFace format:
                {
                    "inputs": "<base64-audio>" or {"audio_url": "..."},
                    "parameters": {"max_duration_seconds": 300}
                }

                Legacy RunPod format (for backward compatibility):
                {
                    "input": {"audio_url": "https://...", "options": {...}}
                }

        Returns:
            Prediction results:
            {
                "predictions": {"timing": 0.85, ...},
                "model_info": {"name": "M1c-MuQ-L9-12", "r2": 0.539},
                "audio_duration_seconds": 180.5,
                "processing_time_ms": 1234
            }
            Or an error payload:
            {
                "error": {"code": "...", "message": "..."}
            }
        """
        start_time = time.time()
        try:
            # Parse input - support both HF and legacy RunPod formats.
            inputs, parameters = self._parse_request(data)
            max_duration = parameters.get("max_duration_seconds", 300)

            # Load and preprocess audio (base64 string, raw bytes, or URL).
            audio, duration = self._load_audio(inputs, max_duration)
            print(f"Audio loaded: {duration:.1f}s")

            # Verify models are loaded before doing any heavy work.
            if not self._cache.muq_model:
                return {
                    "error": {
                        "code": "MODEL_NOT_LOADED",
                        "message": "MuQ model not initialized",
                    }
                }

            # Extract MuQ embeddings (averaged layers 9-12).
            print("Extracting MuQ embeddings (layers 9-12)...")
            embeddings = extract_muq_embeddings(audio, self._cache)
            print(f"MuQ embeddings shape: {embeddings.shape}")

            # Get ensemble predictions (4-fold A1-Max).
            print("Running A1-Max ensemble inference...")
            predictions = predict_with_ensemble(embeddings, self._cache)

            # AMT transcription runs after MuQ scoring and degrades
            # gracefully: a transcription failure must not fail the request.
            midi_notes, transcription_info, amt_error = self._run_transcription(audio)

            # Build combined response.
            processing_time_ms = int((time.time() - start_time) * 1000)
            result = {
                "predictions": self._predictions_to_dict(predictions),
                "midi_notes": midi_notes,
                "transcription_info": transcription_info,
                "model_info": {
                    "name": MODEL_INFO["name"],
                    "type": MODEL_INFO["type"],
                    "pairwise": MODEL_INFO["pairwise"],
                    "architecture": MODEL_INFO["architecture"],
                    "ensemble_folds": len(self._cache.muq_heads),
                },
                "audio_duration_seconds": duration,
                "processing_time_ms": processing_time_ms,
            }
            if amt_error:
                result["amt_error"] = amt_error
            print(f"Inference complete in {processing_time_ms}ms")
            return result
        except AudioDownloadError as e:
            return {
                "error": {
                    "code": "AUDIO_DOWNLOAD_FAILED",
                    "message": str(e),
                }
            }
        except AudioProcessingError as e:
            return {
                "error": {
                    "code": "AUDIO_PROCESSING_FAILED",
                    "message": str(e),
                }
            }
        except Exception as e:
            # Top-level boundary: the endpoint must always answer with JSON.
            # NOTE(review): the traceback is exposed to the caller; handy for
            # debugging, but consider stripping it in production.
            return {
                "error": {
                    "code": "INFERENCE_ERROR",
                    "message": str(e),
                    "traceback": traceback.format_exc(),
                }
            }

    def _run_transcription(self, audio) -> tuple:
        """Run AMT transcription as a best-effort step.

        A transcription failure is reported via the third tuple element
        instead of raising, so MuQ scoring results still reach the caller.

        Returns:
            Tuple of (midi_notes, transcription_info, error_message), where
            midi_notes and transcription_info are None on failure and
            error_message is None on success.
        """
        try:
            print("Running AMT transcription...")
            amt_start = time.time()
            midi_notes = self._transcription.transcribe(audio, 24000)
            amt_elapsed_ms = int((time.time() - amt_start) * 1000)
            pitches = [n["pitch"] for n in midi_notes]
            transcription_info = {
                "note_count": len(midi_notes),
                "pitch_range": [min(pitches), max(pitches)] if pitches else [0, 0],
                "transcription_time_ms": amt_elapsed_ms,
            }
            return midi_notes, transcription_info, None
        except TranscriptionError as e:
            print(f"AMT failed (graceful degradation): {e}")
            return None, None, str(e)

    def _parse_request(self, data: Dict[str, Any]) -> tuple:
        """Parse request data supporting both HF and legacy formats.

        Returns:
            Tuple of (inputs, parameters)
        """
        # HF format: {"inputs": ..., "parameters": ...}
        if "inputs" in data:
            return data["inputs"], data.get("parameters", {})

        # Legacy RunPod format: {"input": {"audio_url": ..., "options": ...}}
        if "input" in data:
            job_input = data["input"]
            inputs = {
                "audio_url": job_input.get("audio_url"),
                "performance_id": job_input.get("performance_id", "unknown"),
            }
            # Copy options so the caller's nested dict is never mutated, and
            # tolerate an explicit "options": None (the original crashed on
            # it and wrote performance_id straight into the caller's dict).
            parameters = dict(job_input.get("options") or {})
            parameters["performance_id"] = inputs["performance_id"]
            return inputs, parameters

        # Fallback: treat the entire payload as the inputs.
        return data, {}

    def _load_audio(
        self, inputs: Union[str, bytes, Dict[str, Any]], max_duration: int
    ) -> tuple:
        """Load audio from various input formats.

        Args:
            inputs: One of:
                - str: Base64-encoded audio bytes, or an http(s) URL
                - bytes: Raw audio bytes
                - dict: {"audio_url": "..."} for URL-based loading
            max_duration: Maximum audio duration in seconds to keep.

        Returns:
            Tuple of (audio_array, duration_seconds)

        Raises:
            AudioProcessingError: If the input cannot be interpreted.
            AudioDownloadError: If a URL download fails.
        """
        if isinstance(inputs, str):
            # Detect URLs up front: b64decode without validate=True silently
            # strips non-alphabet characters, so a URL string could
            # otherwise "decode" into garbage audio bytes.
            if inputs.startswith("http"):
                return download_and_preprocess_audio(inputs, max_duration=max_duration)
            try:
                # binascii.Error (raised on invalid base64) subclasses ValueError.
                audio_bytes = base64.b64decode(inputs, validate=True)
            except ValueError as e:
                raise AudioProcessingError("Invalid input string: not base64 or URL") from e
            return preprocess_audio_from_bytes(audio_bytes, max_duration=max_duration)
        elif isinstance(inputs, bytes):
            # Raw bytes.
            return preprocess_audio_from_bytes(inputs, max_duration=max_duration)
        elif isinstance(inputs, dict):
            # URL-based input.
            audio_url = inputs.get("audio_url")
            if not audio_url:
                raise AudioProcessingError("No audio_url provided in inputs")
            return download_and_preprocess_audio(audio_url, max_duration=max_duration)
        else:
            raise AudioProcessingError(f"Unsupported input type: {type(inputs)}")

    def _predictions_to_dict(self, preds: np.ndarray) -> Dict[str, float]:
        """Convert prediction array to dimension dict (keyed by PercePiano dimension)."""
        return {dim: float(preds[i]) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)}
|