File size: 9,968 Bytes
bfc6d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5895b
bfc6d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5895b
 
 
 
bfc6d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5895b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfc6d2a
 
 
 
1f5895b
 
bfc6d2a
 
 
 
 
 
 
 
 
 
 
1f5895b
 
 
bfc6d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""HuggingFace Inference Endpoints handler for piano performance analysis.

A1-Max MuQ LoRA model using MuQ layers 9-12 with attention pooling.
Returns 6-dimension performance evaluation scores:
dynamics, timing, pedaling, articulation, phrasing, interpretation.

Compatible with HuggingFace Inference Endpoints custom handler pattern.
"""

import base64
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Union

import numpy as np

from constants import MODEL_INFO, PERCEPIANO_DIMENSIONS
from models.loader import get_model_cache
from models.inference import (
    extract_muq_embeddings,
    predict_with_ensemble,
)
from models.transcription import TranscriptionModel, TranscriptionError
from preprocessing.audio import (
    AudioDownloadError,
    AudioProcessingError,
    download_and_preprocess_audio,
    preprocess_audio_from_bytes,
)


class EndpointHandler:
    """HuggingFace Inference Endpoints handler for piano performance analysis.

    Loads the MuQ embedding model, the A1-Max prediction-head ensemble and a
    ByteDance AMT transcription model once at start-up. Each request's audio
    is then scored on the PercePiano dimensions, and a MIDI transcription is
    attached when AMT succeeds.
    """

    # Sample rate handed to the AMT model together with the loaded audio.
    # NOTE(review): assumes the preprocessing helpers return audio at 24 kHz;
    # confirm against preprocessing.audio.
    _AUDIO_SAMPLE_RATE = 24000

    # Audio-length cap used when the request does not set max_duration_seconds.
    _DEFAULT_MAX_DURATION_SECONDS = 300

    def __init__(self, path: str = ""):
        """Load the MuQ model, prediction heads and the AMT model.

        Called once when the endpoint container starts.

        Args:
            path: Path to the model repository (provided by HF Inference Endpoints).
                  Expected to contain a ``checkpoints/`` directory with model weights.
        """
        print(f"Initializing A1-Max EndpointHandler with path: {path}")

        # Resolve the repository root. HF Inference Endpoints mount the repo at
        # `path`; otherwise fall back to /repository (HF default), then to the
        # current directory for local testing.
        if path:
            model_path = Path(path)
        else:
            model_path = Path("/repository")
            if not model_path.exists():
                model_path = Path(".")

        checkpoint_dir = model_path / "checkpoints"
        if not checkpoint_dir.exists():
            # Backward compatibility with the legacy container layout.
            checkpoint_dir = Path("/app/checkpoints")
            if not checkpoint_dir.exists():
                print(
                    f"WARNING: no checkpoints/ directory under {model_path} "
                    f"and {checkpoint_dir} does not exist"
                )

        print(f"Using checkpoint directory: {checkpoint_dir}")

        # Initialize model cache (loads MuQ and prediction heads)
        self._cache = get_model_cache()
        self._cache.initialize(device="cuda", checkpoint_dir=checkpoint_dir)

        # Initialize AMT transcription model
        print("Loading ByteDance AMT model...")
        self._transcription = TranscriptionModel(device="cuda")

        print("A1-Max EndpointHandler initialization complete!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process an inference request.

        Args:
            data: Request payload. Supports two formats:

                HuggingFace format:
                {
                    "inputs": "<base64-audio>" | "<http(s) URL>" | {"audio_url": "..."},
                    "parameters": {
                        "max_duration_seconds": 300
                    }
                }

                Legacy RunPod format (for backward compatibility):
                {
                    "input": {
                        "audio_url": "https://...",
                        "performance_id": "...",
                        "options": {...}
                    }
                }

        Returns:
            On success:
            {
                "predictions": {"timing": 0.85, ...},  # one float per dimension
                "midi_notes": [...] or None,           # None when AMT failed
                "transcription_info": {
                    "note_count": ...,
                    "pitch_range": [low, high],
                    "transcription_time_ms": ...
                } or None,
                "model_info": {"name", "type", "pairwise", "architecture",
                               "ensemble_folds"},
                "audio_duration_seconds": 180.5,
                "processing_time_ms": 1234,
                "amt_error": "..."                     # present only if AMT failed
            }

            On failure:
            {
                "error": {"code": "...", "message": "..."}
            }
            Codes: MODEL_NOT_LOADED, AUDIO_DOWNLOAD_FAILED,
            AUDIO_PROCESSING_FAILED, INFERENCE_ERROR. INFERENCE_ERROR also
            carries a "traceback" entry.
        """
        # Monotonic clock: immune to wall-clock adjustments during the request.
        start_time = time.perf_counter()

        try:
            # Parse input - support both HF and legacy RunPod formats
            inputs, parameters = self._parse_request(data)
            max_duration = parameters.get(
                "max_duration_seconds", self._DEFAULT_MAX_DURATION_SECONDS
            )

            # Fail fast, before spending time downloading/decoding audio.
            if not self._cache.muq_model:
                return self._error_response(
                    "MODEL_NOT_LOADED", "MuQ model not initialized"
                )

            # Load and preprocess audio
            audio, duration = self._load_audio(inputs, max_duration)
            print(f"Audio loaded: {duration:.1f}s")

            # Extract MuQ embeddings (averaged layers 9-12)
            print("Extracting MuQ embeddings (layers 9-12)...")
            embeddings = extract_muq_embeddings(audio, self._cache)
            print(f"MuQ embeddings shape: {embeddings.shape}")

            # Get ensemble predictions (4-fold A1-Max)
            print("Running A1-Max ensemble inference...")
            predictions = predict_with_ensemble(embeddings, self._cache)

            # AMT runs after MuQ scoring, sequentially. A TranscriptionError
            # degrades gracefully instead of discarding the scores.
            midi_notes, transcription_info, amt_error = self._run_transcription(audio)

            # Build combined response
            processing_time_ms = int((time.perf_counter() - start_time) * 1000)

            result = {
                "predictions": self._predictions_to_dict(predictions),
                "midi_notes": midi_notes,
                "transcription_info": transcription_info,
                "model_info": {
                    "name": MODEL_INFO["name"],
                    "type": MODEL_INFO["type"],
                    "pairwise": MODEL_INFO["pairwise"],
                    "architecture": MODEL_INFO["architecture"],
                    "ensemble_folds": len(self._cache.muq_heads),
                },
                "audio_duration_seconds": duration,
                "processing_time_ms": processing_time_ms,
            }

            if amt_error:
                result["amt_error"] = amt_error

            print(f"Inference complete in {processing_time_ms}ms")
            return result

        except AudioDownloadError as e:
            print(f"Audio download failed: {e}")
            return self._error_response("AUDIO_DOWNLOAD_FAILED", str(e))

        except AudioProcessingError as e:
            print(f"Audio processing failed: {e}")
            return self._error_response("AUDIO_PROCESSING_FAILED", str(e))

        except Exception as e:
            # Top-level boundary: log server-side and report to the client.
            # NOTE(review): the traceback is sent to the caller; consider
            # dropping it for publicly exposed endpoints.
            tb = traceback.format_exc()
            print(f"Inference error: {e}\n{tb}")
            return self._error_response("INFERENCE_ERROR", str(e), traceback=tb)

    @staticmethod
    def _error_response(code: str, message: str, **extra: Any) -> Dict[str, Any]:
        """Build the ``{"error": {"code", "message", ...}}`` failure payload."""
        return {"error": {"code": code, "message": message, **extra}}

    def _run_transcription(self, audio: Any) -> tuple:
        """Transcribe audio to MIDI notes, degrading gracefully on failure.

        Args:
            audio: Preprocessed audio as returned by ``_load_audio``.

        Returns:
            Tuple of (midi_notes, transcription_info, amt_error).
            - On success: amt_error is None, and transcription_info holds the
              note count, the [min, max] pitch range ([0, 0] when there are no
              notes) and the elapsed time in ms.
            - On a TranscriptionError: midi_notes and transcription_info are
              None, and amt_error holds the error message.
        """
        try:
            print("Running AMT transcription...")
            amt_start = time.perf_counter()
            midi_notes = self._transcription.transcribe(audio, self._AUDIO_SAMPLE_RATE)
            amt_elapsed_ms = int((time.perf_counter() - amt_start) * 1000)
        except TranscriptionError as e:
            print(f"AMT failed (graceful degradation): {e}")
            return None, None, str(e)

        pitches = [n["pitch"] for n in midi_notes]
        transcription_info = {
            "note_count": len(midi_notes),
            "pitch_range": [min(pitches), max(pitches)] if pitches else [0, 0],
            "transcription_time_ms": amt_elapsed_ms,
        }
        return midi_notes, transcription_info, None

    def _parse_request(self, data: Dict[str, Any]) -> tuple:
        """Parse request data supporting both HF and legacy RunPod formats.

        A missing or ``null`` "parameters"/"options" entry is treated as empty.
        The caller's dictionaries are never mutated.

        Returns:
            Tuple of (inputs, parameters).
        """
        # HF format: {"inputs": ..., "parameters": ...}
        if "inputs" in data:
            return data["inputs"], data.get("parameters") or {}

        # Legacy RunPod format: {"input": {"audio_url": ..., "options": ...}}
        if "input" in data:
            job_input = data["input"] or {}
            performance_id = job_input.get("performance_id", "unknown")
            inputs = {
                "audio_url": job_input.get("audio_url"),
                "performance_id": performance_id,
            }
            # Copy so adding performance_id does not mutate the request payload.
            parameters = dict(job_input.get("options") or {})
            parameters["performance_id"] = performance_id
            return inputs, parameters

        # Fallback: treat entire data as inputs
        return data, {}

    def _load_audio(
        self, inputs: Union[str, bytes, Dict[str, Any]], max_duration: int
    ) -> tuple:
        """Load audio from various input formats.

        Args:
            inputs: One of:
                - str: an ``http://``/``https://`` URL, or base64-encoded audio bytes
                - bytes / bytearray: raw audio bytes
                - dict: {"audio_url": "..."} for URL-based loading
            max_duration: Maximum audio length in seconds, forwarded to the
                preprocessing helpers.

        Returns:
            Tuple of (audio_array, duration_seconds).

        Raises:
            AudioProcessingError: if the input type is unsupported, a string is
                neither a URL nor valid base64, or a dict lacks "audio_url".
                Errors raised by the preprocessing/download helpers propagate
                unchanged.
        """
        if isinstance(inputs, str):
            # Check for a URL first: b64decode silently discards non-alphabet
            # characters, so a URL could otherwise "decode" into garbage bytes.
            if inputs.startswith(("http://", "https://")):
                return download_and_preprocess_audio(inputs, max_duration=max_duration)
            try:
                audio_bytes = base64.b64decode(inputs)
            except ValueError as err:  # binascii.Error subclasses ValueError
                raise AudioProcessingError(
                    "Invalid input string: not base64 or URL"
                ) from err
            # Decoding succeeded: audio errors (bad format, too long, ...) now
            # surface with their own message instead of being masked.
            return preprocess_audio_from_bytes(audio_bytes, max_duration=max_duration)

        if isinstance(inputs, (bytes, bytearray)):
            # Raw bytes
            return preprocess_audio_from_bytes(bytes(inputs), max_duration=max_duration)

        if isinstance(inputs, dict):
            # URL-based input
            audio_url = inputs.get("audio_url")
            if not audio_url:
                raise AudioProcessingError("No audio_url provided in inputs")
            return download_and_preprocess_audio(audio_url, max_duration=max_duration)

        raise AudioProcessingError(f"Unsupported input type: {type(inputs)}")

    def _predictions_to_dict(self, preds: np.ndarray) -> Dict[str, float]:
        """Map the prediction vector onto PERCEPIANO_DIMENSIONS, index by index."""
        return {dim: float(preds[i]) for i, dim in enumerate(PERCEPIANO_DIMENSIONS)}