Akshay4506 commited on
Commit
73d00fd
·
1 Parent(s): e7f46d8

Fix single-row prediction error and enable high-speed HF transfers

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -1
  2. matrix/benchmark.py +14 -2
  3. requirements.txt +2 -0
Dockerfile CHANGED
@@ -4,7 +4,8 @@ FROM python:3.11-slim
4
  RUN useradd -m -u 1000 user
5
  USER user
6
  ENV HOME=/home/user \
7
- PATH=/home/user/.local/bin:$PATH
 
8
 
9
  WORKDIR $HOME/app
10
 
 
4
  RUN useradd -m -u 1000 user
5
  USER user
6
  ENV HOME=/home/user \
7
+ PATH=/home/user/.local/bin:$PATH \
8
+ HF_HUB_ENABLE_HF_TRANSFER=1
9
 
10
  WORKDIR $HOME/app
11
 
matrix/benchmark.py CHANGED
@@ -116,13 +116,25 @@ class _SAPModel:
116
  return self
117
 
118
  def predict(self, X):
119
- preds = self._model.predict(X)
 
 
 
 
 
 
 
 
120
  if not self._real and self.task == "classification":
121
  preds = self._le.inverse_transform(preds)
122
  return preds
123
 
124
  def predict_proba(self, X):
125
- return self._model.predict_proba(X)
 
 
 
 
126
 
127
  @property
128
  def label(self):
 
116
  return self
117
 
118
  def predict(self, X):
119
+ # sap_rpt_oss fails on single-row prediction due to an internal concatenation bug.
120
+ # We work around this by doubling the row if len(X) == 1.
121
+ is_single = len(X) == 1
122
+ X_input = pd.concat([X, X]) if is_single else X
123
+
124
+ preds = self._model.predict(X_input)
125
+ if is_single:
126
+ preds = preds[:1]
127
+
128
  if not self._real and self.task == "classification":
129
  preds = self._le.inverse_transform(preds)
130
  return preds
131
 
132
  def predict_proba(self, X):
133
+ is_single = len(X) == 1
134
+ X_input = pd.concat([X, X]) if is_single else X
135
+
136
+ proba = self._model.predict_proba(X_input)
137
+ return proba[:1] if is_single else proba
138
 
139
  @property
140
  def label(self):
requirements.txt CHANGED
@@ -4,6 +4,8 @@ uvicorn[standard]>=0.29.0
4
  python-multipart>=0.0.9
5
  python-dotenv>=1.0.0
6
  jinja2>=3.1.3
 
 
7
 
8
  # Data & Science Stack
9
  numpy==1.26.4
 
4
  python-multipart>=0.0.9
5
  python-dotenv>=1.0.0
6
  jinja2>=3.1.3
7
+ hf-transfer
8
+ hf_xet
9
 
10
  # Data & Science Stack
11
  numpy==1.26.4