Spaces:
Running
Running
Commit ·
fc6a2fc
1
Parent(s): 6f18cb2
fix: downgrade tabpfn to 0.1.11, class-level weight cache, fix ravel on list
Browse files- code/models/tabpfn_wrapper.py +26 -35
- webapp/benchmark.py +4 -9
- webapp/main.py +2 -2
- webapp/requirements.txt +1 -1
code/models/tabpfn_wrapper.py
CHANGED
|
@@ -78,6 +78,10 @@ class TabPFNWrapper(BaseModelWrapper):
|
|
| 78 |
Random seed
|
| 79 |
"""
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def __init__(
|
| 82 |
self,
|
| 83 |
task_type: str = 'classification',
|
|
@@ -96,18 +100,6 @@ class TabPFNWrapper(BaseModelWrapper):
|
|
| 96 |
def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
|
| 97 |
"""
|
| 98 |
Fit TabPFN (stores training data for in-context learning).
|
| 99 |
-
|
| 100 |
-
Parameters
|
| 101 |
-
----------
|
| 102 |
-
X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
|
| 103 |
-
Training features (max 1000 samples, 100 features)
|
| 104 |
-
y : pd.Series or np.ndarray, shape (n_samples,)
|
| 105 |
-
Training target
|
| 106 |
-
|
| 107 |
-
Returns
|
| 108 |
-
-------
|
| 109 |
-
self : TabPFNWrapper
|
| 110 |
-
Fitted model
|
| 111 |
"""
|
| 112 |
self._validate_input(X, y)
|
| 113 |
|
|
@@ -145,32 +137,31 @@ class TabPFNWrapper(BaseModelWrapper):
|
|
| 145 |
|
| 146 |
import torch
|
| 147 |
import tabpfn
|
| 148 |
-
|
| 149 |
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
|
| 150 |
-
|
| 151 |
-
# Reuse
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
logger.info("
|
| 155 |
-
else:
|
| 156 |
if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
| 158 |
else:
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
# Fit model
|
| 173 |
-
self.model.fit(X, y)
|
| 174 |
|
| 175 |
self.is_fitted = True
|
| 176 |
self.fit_time = time.time() - start_time
|
|
|
|
| 78 |
Random seed
|
| 79 |
"""
|
| 80 |
|
| 81 |
+
# Class-level cache: weights are loaded once and shared across ALL instances
|
| 82 |
+
# in the same process. This prevents reloading 103 weight files on every CV fold.
|
| 83 |
+
_shared_classifier = None
|
| 84 |
+
|
| 85 |
def __init__(
|
| 86 |
self,
|
| 87 |
task_type: str = 'classification',
|
|
|
|
| 100 |
def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
|
| 101 |
"""
|
| 102 |
Fit TabPFN (stores training data for in-context learning).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"""
|
| 104 |
self._validate_input(X, y)
|
| 105 |
|
|
|
|
| 137 |
|
| 138 |
import torch
|
| 139 |
import tabpfn
|
| 140 |
+
|
| 141 |
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
|
| 142 |
+
|
| 143 |
+
# Reuse class-level cached classifier so weights are only loaded ONCE
|
| 144 |
+
# per process, not once per CV fold.
|
| 145 |
+
if TabPFNWrapper._shared_classifier is None:
|
| 146 |
+
logger.info("Creating new TabPFNClassifier and caching at class level...")
|
|
|
|
| 147 |
if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
|
| 148 |
+
TabPFNWrapper._shared_classifier = TabPFNClassifier(
|
| 149 |
+
device=actual_device,
|
| 150 |
+
N_ensemble_configurations=self.n_ensemble
|
| 151 |
+
)
|
| 152 |
else:
|
| 153 |
+
TabPFNWrapper._shared_classifier = TabPFNClassifier(device=actual_device)
|
| 154 |
+
else:
|
| 155 |
+
logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
|
| 156 |
+
|
| 157 |
+
self.model = TabPFNWrapper._shared_classifier
|
| 158 |
+
|
| 159 |
+
# Fit — for v0.1.x, overwrite_warning=True suppresses the data size warning
|
| 160 |
+
try:
|
| 161 |
+
self.model.fit(X, y, overwrite_warning=True)
|
| 162 |
+
except TypeError:
|
| 163 |
+
# v2+ doesn't accept overwrite_warning kwarg
|
| 164 |
+
self.model.fit(X, y)
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
self.is_fitted = True
|
| 167 |
self.fit_time = time.time() - start_time
|
webapp/benchmark.py
CHANGED
|
@@ -63,15 +63,10 @@ def _tabpfn(task):
|
|
| 63 |
if task != "classification":
|
| 64 |
raise ValueError("TabPFN only supports classification tasks")
|
| 65 |
from models.tabpfn_wrapper import TabPFNWrapper
|
| 66 |
-
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
if _TABPFN_CACHED_MODEL is not None:
|
| 71 |
-
wrapper._cached_classifier = _TABPFN_CACHED_MODEL
|
| 72 |
-
return wrapper
|
| 73 |
-
|
| 74 |
-
_TABPFN_CACHED_MODEL = None
|
| 75 |
|
| 76 |
|
| 77 |
class _SAPModel:
|
|
|
|
| 63 |
if task != "classification":
|
| 64 |
raise ValueError("TabPFN only supports classification tasks")
|
| 65 |
from models.tabpfn_wrapper import TabPFNWrapper
|
| 66 |
+
# TabPFNWrapper uses a class-level _shared_classifier so weights are only
|
| 67 |
+
# loaded once per process regardless of how many instances are created.
|
| 68 |
+
return TabPFNWrapper(task_type=task, random_state=RAND)
|
| 69 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
class _SAPModel:
|
webapp/main.py
CHANGED
|
@@ -245,7 +245,7 @@ async def predict(data: dict):
|
|
| 245 |
X_test, _ = _prep(input_df, encoders=CHAMPION_INFO.get("encoders"))
|
| 246 |
|
| 247 |
if CHAMPION_INFO["task"] == "classification":
|
| 248 |
-
raw_pred = CHAMPION_MODEL.predict(X_test)
|
| 249 |
# Flatten if nested (CatBoost/Sklearn sometimes return [[val]] or [val])
|
| 250 |
pred_val = raw_pred.ravel()[0]
|
| 251 |
pred_idx = int(pred_val)
|
|
@@ -266,7 +266,7 @@ async def predict(data: dict):
|
|
| 266 |
"labels": CHAMPION_INFO["labels"]
|
| 267 |
}
|
| 268 |
else:
|
| 269 |
-
raw_pred = CHAMPION_MODEL.predict(X_test)
|
| 270 |
pred = float(raw_pred.ravel()[0])
|
| 271 |
return {"prediction": pred}
|
| 272 |
|
|
|
|
| 245 |
X_test, _ = _prep(input_df, encoders=CHAMPION_INFO.get("encoders"))
|
| 246 |
|
| 247 |
if CHAMPION_INFO["task"] == "classification":
|
| 248 |
+
raw_pred = np.array(CHAMPION_MODEL.predict(X_test))
|
| 249 |
# Flatten if nested (CatBoost/Sklearn sometimes return [[val]] or [val])
|
| 250 |
pred_val = raw_pred.ravel()[0]
|
| 251 |
pred_idx = int(pred_val)
|
|
|
|
| 266 |
"labels": CHAMPION_INFO["labels"]
|
| 267 |
}
|
| 268 |
else:
|
| 269 |
+
raw_pred = np.array(CHAMPION_MODEL.predict(X_test))
|
| 270 |
pred = float(raw_pred.ravel()[0])
|
| 271 |
return {"prediction": pred}
|
| 272 |
|
webapp/requirements.txt
CHANGED
|
@@ -9,5 +9,5 @@ scikit-learn>=1.3.0
|
|
| 9 |
scipy>=1.10.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
-
tabpfn
|
| 13 |
huggingface_hub
|
|
|
|
| 9 |
scipy>=1.10.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
+
tabpfn==0.1.11
|
| 13 |
huggingface_hub
|