Akshay4506 commited on
Commit
8d8acb9
·
1 Parent(s): 287f048

Fix SAP RPT-1 single-row prediction and add hf_xet

Browse files
Files changed (2) hide show
  1. matrix/benchmark.py +32 -2
  2. requirements.txt +2 -0
matrix/benchmark.py CHANGED
@@ -102,13 +102,43 @@ class _SAPModel:
102
  return self
103
 
104
  def predict(self, X):
105
- preds = self._model.predict(X)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if not self._real and self.task == "classification":
107
  preds = self._le.inverse_transform(preds)
108
  return preds
109
 
110
  def predict_proba(self, X):
111
- return self._model.predict_proba(X)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  @property
114
  def label(self):
 
102
  return self
103
 
104
  def predict(self, X):
105
+ # Defensively handle single-row predictions which can crash some versions of sap_rpt_oss
106
+ is_single = len(X) == 1
107
+ if is_single:
108
+ if isinstance(X, pd.DataFrame):
109
+ X_in = pd.concat([X, X])
110
+ else:
111
+ X_in = np.vstack([X, X])
112
+ else:
113
+ X_in = X
114
+
115
+ preds = self._model.predict(X_in)
116
+
117
+ if is_single:
118
+ # Take the first prediction only
119
+ if isinstance(preds, (list, np.ndarray)):
120
+ preds = preds[:1]
121
+
122
  if not self._real and self.task == "classification":
123
  preds = self._le.inverse_transform(preds)
124
  return preds
125
 
126
  def predict_proba(self, X):
127
+ is_single = len(X) == 1
128
+ if is_single:
129
+ if isinstance(X, pd.DataFrame):
130
+ X_in = pd.concat([X, X])
131
+ else:
132
+ X_in = np.vstack([X, X])
133
+ else:
134
+ X_in = X
135
+
136
+ proba = self._model.predict_proba(X_in)
137
+
138
+ if is_single:
139
+ if isinstance(proba, (list, np.ndarray)):
140
+ proba = proba[:1]
141
+ return proba
142
 
143
  @property
144
  def label(self):
requirements.txt CHANGED
@@ -4,6 +4,8 @@ uvicorn[standard]>=0.29.0
4
  python-multipart>=0.0.9
5
  python-dotenv>=1.0.0
6
  jinja2>=3.1.3
 
 
7
 
8
  # Data & Science Stack
9
  numpy==1.26.4
 
4
  python-multipart>=0.0.9
5
  python-dotenv>=1.0.0
6
  jinja2>=3.1.3
7
+ hf_xet
8
+ huggingface_hub[hf_xet]
9
 
10
  # Data & Science Stack
11
  numpy==1.26.4