Spaces:
Sleeping
Sleeping
Commit · 30593bd
1 Parent(s): e17f3ba
fix: add missing scipy dep, TabPFN weight caching, better error logging
Browse files
- code/models/tabpfn_wrapper.py +19 -3
- webapp/benchmark.py +9 -1
- webapp/main.py +2 -0
- webapp/requirements.txt +1 -0
code/models/tabpfn_wrapper.py
CHANGED
|
@@ -144,10 +144,26 @@ class TabPFNWrapper(BaseModelWrapper):
|
|
| 144 |
|
| 145 |
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
|
| 146 |
|
| 147 |
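-
if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
|
| 148 |
-
self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
|
| 149 |
else:
|
| 150 |
-
self.model = TabPFNClassifier(device=actual_device)
|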
| 151 |
|
| 152 |
# Fit model
|
| 153 |
self.model.fit(X, y)
|
|
|
|
| 144 |
|
| 145 |
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
|
| 146 |
|
| 147 |
+
# Reuse a cached classifier if one was injected to avoid reloading weights
|
| 148 |
+
if hasattr(self, '_cached_classifier') and self._cached_classifier is not None:
|
| 149 |
+
self.model = self._cached_classifier
|
| 150 |
+
logger.info("Reusing cached TabPFN classifier (skipping weight reload)")
|
| 151 |
else:
|
| 152 |
+
if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
|
| 153 |
+
self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
|
| 154 |
+
else:
|
| 155 |
+
self.model = TabPFNClassifier(device=actual_device)
|
| 156 |
+
|
| 157 |
+
# Store in global cache for future reuse
|
| 158 |
+
try:
|
| 159 |
+
import benchmark
|
| 160 |
+
benchmark._TABPFN_CACHED_MODEL = self.model
|
| 161 |
+
except (ImportError, AttributeError):
|
| 162 |
+
try:
|
| 163 |
+
import webapp.benchmark as wb
|
| 164 |
+
wb._TABPFN_CACHED_MODEL = self.model
|
| 165 |
+
except (ImportError, AttributeError):
|
| 166 |
+
pass
|
| 167 |
|
| 168 |
# Fit model
|
| 169 |
self.model.fit(X, y)
|
webapp/benchmark.py
CHANGED
|
@@ -55,7 +55,15 @@ def _tabpfn(task):
|
|
| 55 |
if task != "classification":
|
| 56 |
raise ValueError("TabPFN only supports classification tasks")
|
| 57 |
from models.tabpfn_wrapper import TabPFNWrapper
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
class _SAPModel:
|
|
|
|
| 55 |
if task != "classification":
|
| 56 |
raise ValueError("TabPFN only supports classification tasks")
|
| 57 |
from models.tabpfn_wrapper import TabPFNWrapper
|
| 58 |
+
wrapper = TabPFNWrapper(task_type=task, random_state=RAND)
|
| 59 |
+
# Re-use the cached TabPFNClassifier model if available to avoid
|
| 60 |
+
# reloading weights on every CV fold (saves ~2s per fold + RAM)
|
| 61 |
+
global _TABPFN_CACHED_MODEL
|
| 62 |
+
if _TABPFN_CACHED_MODEL is not None:
|
| 63 |
+
wrapper._cached_classifier = _TABPFN_CACHED_MODEL
|
| 64 |
+
return wrapper
|
| 65 |
+
|
| 66 |
+
_TABPFN_CACHED_MODEL = None
|
| 67 |
|
| 68 |
|
| 69 |
class _SAPModel:
|
webapp/main.py
CHANGED
|
@@ -212,6 +212,8 @@ async def benchmark(
|
|
| 212 |
}
|
| 213 |
|
| 214 |
except Exception as e:
|
|
|
|
|
|
|
| 215 |
raise HTTPException(500, f"Benchmarking failed: {e}")
|
| 216 |
|
| 217 |
return JSONResponse(result)
|
|
|
|
| 212 |
}
|
| 213 |
|
| 214 |
except Exception as e:
|
| 215 |
+
import traceback
|
| 216 |
+
traceback.print_exc()
|
| 217 |
raise HTTPException(500, f"Benchmarking failed: {e}")
|
| 218 |
|
| 219 |
return JSONResponse(result)
|
webapp/requirements.txt
CHANGED
|
@@ -6,6 +6,7 @@ xgboost>=2.0.0
|
|
| 6 |
lightgbm>=4.0.0
|
| 7 |
catboost>=1.2.0
|
| 8 |
scikit-learn>=1.3.0
|
|
|
|
| 9 |
pandas>=2.0.0
|
| 10 |
numpy>=1.24.0
|
| 11 |
tabpfn>=7.1.1
|
|
|
|
| 6 |
lightgbm>=4.0.0
|
| 7 |
catboost>=1.2.0
|
| 8 |
scikit-learn>=1.3.0
|
| 9 |
+
scipy>=1.10.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
tabpfn>=7.1.1
|