| """HuggingFace Inference Endpoints handler for piano performance analysis. |
| |
| A1-Max MuQ LoRA model using MuQ layers 9-12 with attention pooling. |
| Returns 6-dimension performance evaluation scores: |
| dynamics, timing, pedaling, articulation, phrasing, interpretation. |
| |
| Compatible with HuggingFace Inference Endpoints custom handler pattern. |
| """ |

import base64
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Tuple, Union

import numpy as np

from constants import MODEL_INFO, PERCEPIANO_DIMENSIONS
from models.loader import get_model_cache
from models.inference import (
    extract_muq_embeddings,
    predict_with_ensemble,
)
from models.transcription import TranscriptionModel, TranscriptionError
from preprocessing.audio import (
    AudioDownloadError,
    AudioProcessingError,
    download_and_preprocess_audio,
    preprocess_audio_from_bytes,
)


class EndpointHandler:
    """HuggingFace Inference Endpoints handler for piano performance analysis."""

    def __init__(self, path: str = ""):
        """Initialize the MuQ model and prediction heads.

        Called once when the endpoint container starts.

        Args:
            path: Path to the model repository (provided by HF Inference
                Endpoints). Contains the checkpoints/ directory with model
                weights.
        """
        print(f"Initializing A1-Max EndpointHandler with path: {path}")

        # Resolve the repository root: prefer the path HF passes in, then the
        # standard /repository mount, then the current working directory.
        if path:
            model_path = Path(path)
        else:
            model_path = Path("/repository")
            if not model_path.exists():
                model_path = Path(".")

        checkpoint_dir = model_path / "checkpoints"
        if not checkpoint_dir.exists():
            # Fall back to the checkpoint path baked into the container image.
            checkpoint_dir = Path("/app/checkpoints")

        print(f"Using checkpoint directory: {checkpoint_dir}")

        # Load the MuQ backbone and the ensemble of prediction heads.
        self._cache = get_model_cache()
        self._cache.initialize(device="cuda", checkpoint_dir=checkpoint_dir)

        # Load the piano transcription (AMT) model used for MIDI extraction.
        print("Loading ByteDance AMT model...")
        self._transcription = TranscriptionModel(device="cuda")

        print("A1-Max EndpointHandler initialization complete!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process an inference request.

        Args:
            data: Request payload. Supports two formats:

                HuggingFace format:
                {
                    "inputs": "<base64-audio>" or {"audio_url": "..."},
                    "parameters": {
                        "max_duration_seconds": 300
                    }
                }

                Legacy RunPod format (for backward compatibility):
                {
                    "input": {
                        "audio_url": "https://...",
                        "options": {...}
                    }
                }

        Returns:
            Prediction results:
            {
                "predictions": {"timing": 0.85, ...},
                "midi_notes": [...],
                "transcription_info": {"note_count": ..., ...},
                "model_info": {"name": ..., "type": ..., "pairwise": ...,
                               "architecture": ..., "ensemble_folds": ...},
                "audio_duration_seconds": 180.5,
                "processing_time_ms": 1234
            }

            Or an error:
            {
                "error": {"code": "...", "message": "..."}
            }
        """
        start_time = time.time()

        try:
            # Normalize the payload into (inputs, parameters).
            inputs, parameters = self._parse_request(data)

            # Cap on analyzed audio length, in seconds.
            max_duration = parameters.get("max_duration_seconds", 300)

            # Decode, download, and resample the audio.
            audio, duration = self._load_audio(inputs, max_duration)
            print(f"Audio loaded: {duration:.1f}s")

            if not self._cache.muq_model:
                return {
                    "error": {
                        "code": "MODEL_NOT_LOADED",
                        "message": "MuQ model not initialized",
                    }
                }

            print("Extracting MuQ embeddings (layers 9-12)...")
            embeddings = extract_muq_embeddings(audio, self._cache)
            print(f"MuQ embeddings shape: {embeddings.shape}")

            print("Running A1-Max ensemble inference...")
            predictions = predict_with_ensemble(embeddings, self._cache)

            # AMT transcription is best-effort: a failure degrades gracefully
            # instead of failing the whole request.
            midi_notes = None
            transcription_info = None
            amt_error = None

            try:
                print("Running AMT transcription...")
                amt_start = time.time()
                midi_notes = self._transcription.transcribe(audio, 24000)
                amt_elapsed_ms = int((time.time() - amt_start) * 1000)

                pitches = [n["pitch"] for n in midi_notes]
                transcription_info = {
                    "note_count": len(midi_notes),
                    "pitch_range": [min(pitches), max(pitches)] if pitches else [0, 0],
                    "transcription_time_ms": amt_elapsed_ms,
                }
            except TranscriptionError as e:
                print(f"AMT failed (graceful degradation): {e}")
                amt_error = str(e)

            processing_time_ms = int((time.time() - start_time) * 1000)

            result = {
                "predictions": self._predictions_to_dict(predictions),
                "midi_notes": midi_notes,
                "transcription_info": transcription_info,
                "model_info": {
                    "name": MODEL_INFO["name"],
                    "type": MODEL_INFO["type"],
                    "pairwise": MODEL_INFO["pairwise"],
                    "architecture": MODEL_INFO["architecture"],
                    "ensemble_folds": len(self._cache.muq_heads),
                },
                "audio_duration_seconds": duration,
                "processing_time_ms": processing_time_ms,
            }

            if amt_error:
                result["amt_error"] = amt_error

            print(f"Inference complete in {processing_time_ms}ms")
            return result

        except AudioDownloadError as e:
            return {
                "error": {
                    "code": "AUDIO_DOWNLOAD_FAILED",
                    "message": str(e),
                }
            }

        except AudioProcessingError as e:
            return {
                "error": {
                    "code": "AUDIO_PROCESSING_FAILED",
                    "message": str(e),
                }
            }

        except Exception as e:
            return {
                "error": {
                    "code": "INFERENCE_ERROR",
                    "message": str(e),
                    "traceback": traceback.format_exc(),
                }
            }

    def _parse_request(self, data: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
        """Parse request data, supporting both HF and legacy formats.

        Returns:
            Tuple of (inputs, parameters).
        """
        # HuggingFace Inference Endpoints format.
        if "inputs" in data:
            inputs = data["inputs"]
            parameters = data.get("parameters", {})
            return inputs, parameters

        # Legacy RunPod format.
        if "input" in data:
            job_input = data["input"]
            inputs = {
                "audio_url": job_input.get("audio_url"),
                "performance_id": job_input.get("performance_id", "unknown"),
            }
            parameters = job_input.get("options", {})
            parameters["performance_id"] = inputs.get("performance_id", "unknown")
            return inputs, parameters

        # Fallback: treat the whole payload as inputs.
        return data, {}

    def _load_audio(
        self, inputs: Union[str, bytes, Dict[str, Any]], max_duration: int
    ) -> Tuple[np.ndarray, float]:
        """Load audio from various input formats.

        Args:
            inputs: One of:
                - str: an http(s) URL, or base64-encoded audio bytes
                - bytes: raw audio bytes
                - dict: {"audio_url": "..."} for URL-based loading

        Returns:
            Tuple of (audio_array, duration_seconds).
        """
        if isinstance(inputs, str):
            # Check for a URL first: b64decode() silently discards invalid
            # characters by default, so a URL string could otherwise "decode"
            # into garbage bytes instead of being fetched.
            if inputs.startswith(("http://", "https://")):
                return download_and_preprocess_audio(inputs, max_duration=max_duration)
            try:
                audio_bytes = base64.b64decode(inputs, validate=True)
            except Exception:
                raise AudioProcessingError("Invalid input string: not base64 or URL")
            return preprocess_audio_from_bytes(audio_bytes, max_duration=max_duration)

        elif isinstance(inputs, bytes):
            # Raw audio bytes.
            return preprocess_audio_from_bytes(inputs, max_duration=max_duration)

        elif isinstance(inputs, dict):
            # URL-based loading.
            audio_url = inputs.get("audio_url")
            if not audio_url:
                raise AudioProcessingError("No audio_url provided in inputs")
            return download_and_preprocess_audio(audio_url, max_duration=max_duration)

        else:
            raise AudioProcessingError(f"Unsupported input type: {type(inputs)}")

    def _predictions_to_dict(self, preds: np.ndarray) -> Dict[str, float]:
        """Convert prediction array to dimension dict."""
        return {dim: float(preds[i]) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)}
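

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, not part of the endpoint contract.
# Assumptions: checkpoints resolve as in __init__, CUDA is available, and
# "sample.wav" is a hypothetical audio file you supply yourself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json
    import sys

    audio_file = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"  # hypothetical
    handler = EndpointHandler()
    with open(audio_file, "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("ascii"),
            "parameters": {"max_duration_seconds": 300},
        }
    print(json.dumps(handler(payload), indent=2, default=str))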
|