|
|
"""
|
|
|
Custom Inference Handler for Hugging Face Inference Endpoints
|
|
|
|
|
|
This handler loads the trained RandomForest model and provides
|
|
|
prediction functionality for the Hugging Face Inference API.
|
|
|
"""
|
|
|
|
|
|
import joblib
|
|
|
import numpy as np
|
|
|
from typing import Dict, List, Any, Union
|
|
|
import os
|
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    This class is instantiated by the Inference API and handles incoming
    prediction requests. It loads a fitted classifier from ``model.joblib``
    once at construction time and converts request payloads into
    per-sample predictions with class probabilities.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler by loading the model.

        Args:
            path: Path to the model directory (provided by HF Inference API).
                An empty value resolves ``model.joblib`` relative to the
                current working directory.
        """
        model_path = os.path.join(path, "model.joblib") if path else "model.joblib"
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model artifacts that come from a trusted source.
        self.model = joblib.load(model_path)

        # Column order used to build a feature row from dict inputs.
        # Presumably this matches the column order used at training time;
        # verify against the training script.
        self.feature_names = [
            "SPC_D7", "SPC_D14", "SPC_D21",
            "TGN_D7", "TGN_D14", "TGN_D21",
        ]

        # Class order as reported by the model; used to pair each class
        # with its column of predict_proba output.
        self.class_names = list(self.model.classes_)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle prediction requests.

        Args:
            data: Input data dictionary. Supports multiple formats:
                - {"inputs": [[f1, f2, f3, f4, f5, f6], ...]}  # list of feature arrays
                - {"inputs": [f1, f2, f3, f4, f5, f6]}         # one flat feature array
                - {"inputs": {"SPC_D7": 4.5, ...}}             # dict with feature names
                - {"inputs": [{"SPC_D7": 4.5, ...}, ...]}      # list of dicts
                If the "inputs" key is absent, ``data`` itself is treated
                as the inputs.

        Returns:
            One dict per sample with the keys ``"label"`` (predicted class
            as a string), ``"score"`` (highest class probability) and
            ``"probabilities"`` (class name -> probability). An empty
            batch yields an empty list.

        Raises:
            ValueError: If the inputs cannot be arranged as rows of
                ``len(self.feature_names)`` numeric features.
        """
        inputs = data.get("inputs", data) if isinstance(data, dict) else data
        X = self._process_inputs(inputs)

        # Nothing to predict; skip the model call for an empty batch.
        if X.shape[0] == 0:
            return []

        predictions = self.model.predict(X)
        probabilities = self.model.predict_proba(X)

        # String keys keep the response JSON-serializable when the model's
        # classes are NumPy scalars, and match the str() used for "label".
        labels = [str(cls) for cls in self.class_names]

        return [
            {
                "label": str(pred),
                "score": float(np.max(probs)),
                "probabilities": {
                    label: float(prob) for label, prob in zip(labels, probs)
                },
            }
            for pred, probs in zip(predictions, probabilities)
        ]

    def _row_from_dict(self, sample: Dict[str, Any]) -> List[Any]:
        """Build one feature row from a name -> value dict.

        Features missing from ``sample`` default to 0.
        """
        if not isinstance(sample, dict):
            raise ValueError(
                "Expected every sample to be a dict of feature values, "
                f"got {type(sample).__name__}"
            )
        return [sample.get(feat, 0) for feat in self.feature_names]

    def _process_inputs(self, inputs: Union[List, Dict]) -> np.ndarray:
        """Process various input formats into a numpy array.

        Args:
            inputs: A dict of feature values, a list of such dicts, a flat
                sequence of feature values (one sample), or a sequence of
                feature rows.

        Returns:
            Float numpy array of shape (n_samples, n_features). Empty
            input yields shape (0, n_features).

        Raises:
            ValueError: If the values are not numeric, the rows are ragged,
                or the number of features per row does not match
                ``len(self.feature_names)``.
        """
        n_features = len(self.feature_names)

        if isinstance(inputs, dict):
            rows = [self._row_from_dict(inputs)]
        elif (
            isinstance(inputs, (list, tuple))
            and len(inputs) > 0
            and isinstance(inputs[0], dict)
        ):
            rows = [self._row_from_dict(sample) for sample in inputs]
        else:
            rows = inputs

        try:
            X = np.asarray(rows, dtype=float)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Inputs must be numeric rows of {n_features} features"
            ) from err

        if X.size == 0:
            return np.empty((0, n_features), dtype=float)

        # A flat sequence of values is a single sample.
        if X.ndim == 1:
            X = X.reshape(1, -1)

        # Validate instead of reshaping: reshape(-1, n_features) would
        # silently regroup values across samples when the width is wrong.
        if X.ndim != 2 or X.shape[1] != n_features:
            raise ValueError(
                f"Expected {n_features} features per sample, "
                f"got input of shape {X.shape}"
            )
        return X
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke run: push one payload of each supported input format
    # through the handler and print what comes back.
    print("Testing EndpointHandler locally...")

    try:
        handler = EndpointHandler()

        # (format label, request payload) pairs, exercised in order.
        smoke_cases = [
            (
                "list",
                {"inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]},
            ),
            (
                "dict",
                {
                    "inputs": {
                        "SPC_D7": 4.5, "SPC_D14": 5.2, "SPC_D21": 6.1,
                        "TGN_D7": 3.2, "TGN_D14": 4.0, "TGN_D21": 4.8,
                    }
                },
            ),
            (
                "batch",
                {
                    "inputs": [
                        [4.5, 5.2, 6.1, 3.2, 4.0, 4.8],
                        [2.0, 2.5, 3.0, 1.5, 2.0, 2.5],
                        [6.0, 7.0, 8.0, 5.0, 6.0, 7.0],
                    ]
                },
            ),
        ]

        for number, (label, payload) in enumerate(smoke_cases, start=1):
            # Call the handler before printing so a failing case emits
            # no partial report for itself.
            outcome = handler(payload)
            print(f"\nTest {number} ({label} format):")
            print(f" Input: {payload}")
            print(f" Output: {outcome}")

        print("\nAll tests passed!")

    except FileNotFoundError:
        # The model artifact has not been generated yet.
        print("Note: model.joblib not found. Run 'python prepare_model.py' first.")
|
|
|
|