"""HuggingFace Inference Endpoints handler for piano performance analysis. A1-Max MuQ LoRA model using MuQ layers 9-12 with attention pooling. Returns 6-dimension performance evaluation scores: dynamics, timing, pedaling, articulation, phrasing, interpretation. Compatible with HuggingFace Inference Endpoints custom handler pattern. """ import base64 import time import traceback from pathlib import Path from typing import Any, Dict, Union import numpy as np from constants import MODEL_INFO, PERCEPIANO_DIMENSIONS from models.loader import get_model_cache from models.inference import ( extract_muq_embeddings, predict_with_ensemble, ) from models.transcription import TranscriptionModel, TranscriptionError from preprocessing.audio import ( AudioDownloadError, AudioProcessingError, download_and_preprocess_audio, preprocess_audio_from_bytes, ) class EndpointHandler: """HuggingFace Inference Endpoints handler for piano performance analysis.""" def __init__(self, path: str = ""): """Initialize MuQ model and prediction heads. Called once when the endpoint container starts. Args: path: Path to the model repository (provided by HF Inference Endpoints). Contains the checkpoints/ directory with model weights. """ print(f"Initializing A1-Max EndpointHandler with path: {path}") # Determine checkpoint directory # HF Inference Endpoints mount the repo at the provided path # Fall back to /repository (HF default) or current dir for local testing if path: model_path = Path(path) else: model_path = Path("/repository") if not model_path.exists(): model_path = Path(".") checkpoint_dir = model_path / "checkpoints" if not checkpoint_dir.exists(): # Try /app/checkpoints for backward compatibility checkpoint_dir = Path("/app/checkpoints") print(f"Using checkpoint directory: {checkpoint_dir}") # Initialize model cache (loads MuQ and prediction heads) self._cache = get_model_cache() self._cache.initialize(device="cuda", checkpoint_dir=checkpoint_dir) # Initialize AMT transcription model print("Loading ByteDance AMT model...") self._transcription = TranscriptionModel(device="cuda") print("A1-Max EndpointHandler initialization complete!") def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """Process inference request. Args: data: Request payload. Supports two formats: HuggingFace format: { "inputs": "" or {"audio_url": "..."}, "parameters": { "max_duration_seconds": 300 } } Legacy RunPod format (for backward compatibility): { "input": { "audio_url": "https://...", "options": {...} } } Returns: Prediction results: { "predictions": {"timing": 0.85, ...}, "model_info": {"name": "M1c-MuQ-L9-12", "r2": 0.539}, "audio_duration_seconds": 180.5, "processing_time_ms": 1234 } Or error: { "error": {"code": "...", "message": "..."} } """ start_time = time.time() try: # Parse input - support both HF and legacy RunPod formats inputs, parameters = self._parse_request(data) # Extract parameters max_duration = parameters.get("max_duration_seconds", 300) # Load and preprocess audio audio, duration = self._load_audio(inputs, max_duration) print(f"Audio loaded: {duration:.1f}s") # Verify models are loaded if not self._cache.muq_model: return { "error": { "code": "MODEL_NOT_LOADED", "message": "MuQ model not initialized", } } # Extract MuQ embeddings (averaged layers 9-12) print("Extracting MuQ embeddings (layers 9-12)...") embeddings = extract_muq_embeddings(audio, self._cache) print(f"MuQ embeddings shape: {embeddings.shape}") # Get ensemble predictions (4-fold A1-Max) print("Running A1-Max ensemble inference...") predictions = predict_with_ensemble(embeddings, self._cache) # Run AMT transcription (after MuQ scoring, sequential) midi_notes = None transcription_info = None amt_error = None try: print("Running AMT transcription...") amt_start = time.time() midi_notes = self._transcription.transcribe(audio, 24000) amt_elapsed_ms = int((time.time() - amt_start) * 1000) pitches = [n["pitch"] for n in midi_notes] transcription_info = { "note_count": len(midi_notes), "pitch_range": [min(pitches), max(pitches)] if pitches else [0, 0], "transcription_time_ms": amt_elapsed_ms, } except TranscriptionError as e: print(f"AMT failed (graceful degradation): {e}") amt_error = str(e) # Build combined response processing_time_ms = int((time.time() - start_time) * 1000) result = { "predictions": self._predictions_to_dict(predictions), "midi_notes": midi_notes, "transcription_info": transcription_info, "model_info": { "name": MODEL_INFO["name"], "type": MODEL_INFO["type"], "pairwise": MODEL_INFO["pairwise"], "architecture": MODEL_INFO["architecture"], "ensemble_folds": len(self._cache.muq_heads), }, "audio_duration_seconds": duration, "processing_time_ms": processing_time_ms, } if amt_error: result["amt_error"] = amt_error print(f"Inference complete in {processing_time_ms}ms") return result except AudioDownloadError as e: return { "error": { "code": "AUDIO_DOWNLOAD_FAILED", "message": str(e), } } except AudioProcessingError as e: return { "error": { "code": "AUDIO_PROCESSING_FAILED", "message": str(e), } } except Exception as e: return { "error": { "code": "INFERENCE_ERROR", "message": str(e), "traceback": traceback.format_exc(), } } def _parse_request(self, data: Dict[str, Any]) -> tuple: """Parse request data supporting both HF and legacy formats. Returns: Tuple of (inputs, parameters) """ # HF format: {"inputs": ..., "parameters": ...} if "inputs" in data: inputs = data["inputs"] parameters = data.get("parameters", {}) return inputs, parameters # Legacy RunPod format: {"input": {"audio_url": ..., "options": ...}} if "input" in data: job_input = data["input"] inputs = { "audio_url": job_input.get("audio_url"), "performance_id": job_input.get("performance_id", "unknown"), } parameters = job_input.get("options", {}) parameters["performance_id"] = inputs.get("performance_id", "unknown") return inputs, parameters # Fallback: treat entire data as inputs return data, {} def _load_audio( self, inputs: Union[str, bytes, Dict[str, Any]], max_duration: int ) -> tuple: """Load audio from various input formats. Args: inputs: One of: - str: Base64-encoded audio bytes - bytes: Raw audio bytes - dict: {"audio_url": "..."} for URL-based loading Returns: Tuple of (audio_array, duration_seconds) """ if isinstance(inputs, str): # Base64-encoded audio try: audio_bytes = base64.b64decode(inputs) return preprocess_audio_from_bytes(audio_bytes, max_duration=max_duration) except Exception: # Maybe it's a URL string if inputs.startswith("http"): return download_and_preprocess_audio(inputs, max_duration=max_duration) raise AudioProcessingError("Invalid input string: not base64 or URL") elif isinstance(inputs, bytes): # Raw bytes return preprocess_audio_from_bytes(inputs, max_duration=max_duration) elif isinstance(inputs, dict): # URL-based input audio_url = inputs.get("audio_url") if not audio_url: raise AudioProcessingError("No audio_url provided in inputs") return download_and_preprocess_audio(audio_url, max_duration=max_duration) else: raise AudioProcessingError(f"Unsupported input type: {type(inputs)}") def _predictions_to_dict(self, preds: np.ndarray) -> Dict[str, float]: """Convert prediction array to dimension dict.""" return {dim: float(preds[i]) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)}