Spaces:
Running
Running
Commit ·
8d8acb9
1
Parent(s): 287f048
Fix SAP RPT-1 single-row prediction and add hf_xet
Browse files- matrix/benchmark.py +32 -2
- requirements.txt +2 -0
matrix/benchmark.py
CHANGED
|
@@ -102,13 +102,43 @@ class _SAPModel:
|
|
| 102 |
return self
|
| 103 |
|
| 104 |
def predict(self, X):
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
if not self._real and self.task == "classification":
|
| 107 |
preds = self._le.inverse_transform(preds)
|
| 108 |
return preds
|
| 109 |
|
| 110 |
def predict_proba(self, X):
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
@property
|
| 114 |
def label(self):
|
|
|
|
| 102 |
return self
|
| 103 |
|
| 104 |
def predict(self, X):
|
| 105 |
+
# Defensively handle single-row predictions which can crash some versions of sap_rpt_oss
|
| 106 |
+
is_single = len(X) == 1
|
| 107 |
+
if is_single:
|
| 108 |
+
if isinstance(X, pd.DataFrame):
|
| 109 |
+
X_in = pd.concat([X, X])
|
| 110 |
+
else:
|
| 111 |
+
X_in = np.vstack([X, X])
|
| 112 |
+
else:
|
| 113 |
+
X_in = X
|
| 114 |
+
|
| 115 |
+
preds = self._model.predict(X_in)
|
| 116 |
+
|
| 117 |
+
if is_single:
|
| 118 |
+
# Take the first prediction only
|
| 119 |
+
if isinstance(preds, (list, np.ndarray)):
|
| 120 |
+
preds = preds[:1]
|
| 121 |
+
|
| 122 |
if not self._real and self.task == "classification":
|
| 123 |
preds = self._le.inverse_transform(preds)
|
| 124 |
return preds
|
| 125 |
|
| 126 |
def predict_proba(self, X):
|
| 127 |
+
is_single = len(X) == 1
|
| 128 |
+
if is_single:
|
| 129 |
+
if isinstance(X, pd.DataFrame):
|
| 130 |
+
X_in = pd.concat([X, X])
|
| 131 |
+
else:
|
| 132 |
+
X_in = np.vstack([X, X])
|
| 133 |
+
else:
|
| 134 |
+
X_in = X
|
| 135 |
+
|
| 136 |
+
proba = self._model.predict_proba(X_in)
|
| 137 |
+
|
| 138 |
+
if is_single:
|
| 139 |
+
if isinstance(proba, (list, np.ndarray)):
|
| 140 |
+
proba = proba[:1]
|
| 141 |
+
return proba
|
| 142 |
|
| 143 |
@property
|
| 144 |
def label(self):
|
requirements.txt
CHANGED
|
@@ -4,6 +4,8 @@ uvicorn[standard]>=0.29.0
|
|
| 4 |
python-multipart>=0.0.9
|
| 5 |
python-dotenv>=1.0.0
|
| 6 |
jinja2>=3.1.3
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# Data & Science Stack
|
| 9 |
numpy==1.26.4
|
|
|
|
| 4 |
python-multipart>=0.0.9
|
| 5 |
python-dotenv>=1.0.0
|
| 6 |
jinja2>=3.1.3
|
| 7 |
+
hf_xet
|
| 8 |
+
huggingface_hub[hf_xet]
|
| 9 |
|
| 10 |
# Data & Science Stack
|
| 11 |
numpy==1.26.4
|