File size: 5,426 Bytes
b6193b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
"""
Custom Inference Handler for Hugging Face Inference Endpoints
This handler loads the trained RandomForest model and provides
prediction functionality for the Hugging Face Inference API.
"""
import joblib
import numpy as np
from typing import Dict, List, Any, Union
import os
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    This class is automatically instantiated by the Inference API
    and handles incoming prediction requests.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model.

        Args:
            path: Path to the model directory (provided by HF Inference API).
                When empty, "model.joblib" is loaded from the current directory.
        """
        model_path = os.path.join(path, "model.joblib") if path else "model.joblib"
        self.model = joblib.load(model_path)
        # Feature names in the column order the model expects.
        self.feature_names = [
            "SPC_D7", "SPC_D14", "SPC_D21",
            "TGN_D7", "TGN_D14", "TGN_D21"
        ]
        # Class labels, in the order predict_proba reports probabilities.
        self.class_names = list(self.model.classes_)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle prediction requests.

        Args:
            data: Input data dictionary. Supports multiple formats:
                - {"inputs": [[f1, f2, f3, f4, f5, f6], ...]} # List of feature arrays
                - {"inputs": {"SPC_D7": 4.5, ...}} # Dict with feature names
                - {"inputs": [{"SPC_D7": 4.5, ...}, ...]} # List of dicts

        Returns:
            List of per-sample results, each with "label" (predicted class),
            "score" (top class probability), and "probabilities" (mapping of
            every class name to its probability).
        """
        # If there is no "inputs" key, treat the whole payload as the inputs
        # (e.g. a bare feature dict was posted).
        inputs = data.get("inputs", data)
        X = self._process_inputs(inputs)

        # Robustness: an empty batch would make scikit-learn raise an obscure
        # "Found array with 0 sample(s)" error — return an empty result instead.
        if X.shape[0] == 0:
            return []

        predictions = self.model.predict(X)
        probabilities = self.model.predict_proba(X)

        return [
            {
                "label": str(pred),
                "score": float(max(probs)),
                "probabilities": {
                    cls: float(prob)
                    for cls, prob in zip(self.class_names, probs)
                },
            }
            for pred, probs in zip(predictions, probabilities)
        ]

    def _process_inputs(self, inputs: Union[List, Dict]) -> np.ndarray:
        """
        Process various input formats into a numpy array.

        Args:
            inputs: A list of rows (lists/tuples/arrays), a list of feature
                dicts, a single flat row of values, or a single feature dict.

        Returns:
            Numpy array of shape (n_samples, n_features)
        """
        n_features = len(self.feature_names)

        if isinstance(inputs, list) and len(inputs) > 0:
            # Case 1: list of lists/arrays (rows of raw feature values)
            if isinstance(inputs[0], (list, tuple, np.ndarray)):
                return np.array(inputs).reshape(-1, n_features)
            # Case 2: list of dictionaries keyed by feature name.
            # NOTE(review): missing features silently default to 0 — assumed
            # to be an acceptable imputation for this model; confirm against
            # how the model was trained.
            if isinstance(inputs[0], dict):
                return np.array([
                    [sample.get(feat, 0) for feat in self.feature_names]
                    for sample in inputs
                ])
            # Case 3: single sample given as one flat list of values
            return np.array(inputs).reshape(1, -1)

        # Case 4: single dictionary keyed by feature name (missing -> 0, as above)
        if isinstance(inputs, dict):
            return np.array([[
                inputs.get(feat, 0) for feat in self.feature_names
            ]])

        # Fallback: try to convert directly (covers e.g. empty list, ndarray)
        return np.array(inputs).reshape(-1, n_features)
# For local testing
if __name__ == "__main__":
# Test the handler locally
print("Testing EndpointHandler locally...")
try:
handler = EndpointHandler()
# Test with list format
test_data_list = {
"inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]
}
result = handler(test_data_list)
print(f"\nTest 1 (list format):")
print(f" Input: {test_data_list}")
print(f" Output: {result}")
# Test with dict format
test_data_dict = {
"inputs": {
"SPC_D7": 4.5, "SPC_D14": 5.2, "SPC_D21": 6.1,
"TGN_D7": 3.2, "TGN_D14": 4.0, "TGN_D21": 4.8
}
}
result = handler(test_data_dict)
print(f"\nTest 2 (dict format):")
print(f" Input: {test_data_dict}")
print(f" Output: {result}")
# Test batch prediction
test_data_batch = {
"inputs": [
[4.5, 5.2, 6.1, 3.2, 4.0, 4.8],
[2.0, 2.5, 3.0, 1.5, 2.0, 2.5],
[6.0, 7.0, 8.0, 5.0, 6.0, 7.0]
]
}
result = handler(test_data_batch)
print(f"\nTest 3 (batch format):")
print(f" Input: {test_data_batch}")
print(f" Output: {result}")
print("\nAll tests passed!")
except FileNotFoundError:
print("Note: model.joblib not found. Run 'python prepare_model.py' first.")
|