Akshay4506 committed on
Commit
30593bd
·
1 Parent(s): e17f3ba

fix: add missing scipy dep, TabPFN weight caching, better error logging

Browse files
code/models/tabpfn_wrapper.py CHANGED
@@ -144,10 +144,26 @@ class TabPFNWrapper(BaseModelWrapper):
144
 
145
  actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
146
 
147
- if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
148
- self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
 
 
149
  else:
150
- self.model = TabPFNClassifier(device=actual_device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  # Fit model
153
  self.model.fit(X, y)
 
144
 
145
  actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
146
 
147
+ # Reuse a cached classifier if one was injected to avoid reloading weights
148
+ if hasattr(self, '_cached_classifier') and self._cached_classifier is not None:
149
+ self.model = self._cached_classifier
150
+ logger.info("Reusing cached TabPFN classifier (skipping weight reload)")
151
  else:
152
+ if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
153
+ self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
154
+ else:
155
+ self.model = TabPFNClassifier(device=actual_device)
156
+
157
+ # Store in global cache for future reuse
158
+ try:
159
+ import benchmark
160
+ benchmark._TABPFN_CACHED_MODEL = self.model
161
+ except (ImportError, AttributeError):
162
+ try:
163
+ import webapp.benchmark as wb
164
+ wb._TABPFN_CACHED_MODEL = self.model
165
+ except (ImportError, AttributeError):
166
+ pass
167
 
168
  # Fit model
169
  self.model.fit(X, y)
webapp/benchmark.py CHANGED
@@ -55,7 +55,15 @@ def _tabpfn(task):
55
  if task != "classification":
56
  raise ValueError("TabPFN only supports classification tasks")
57
  from models.tabpfn_wrapper import TabPFNWrapper
58
- return TabPFNWrapper(task_type=task, random_state=RAND)
 
 
 
 
 
 
 
 
59
 
60
 
61
  class _SAPModel:
 
55
  if task != "classification":
56
  raise ValueError("TabPFN only supports classification tasks")
57
  from models.tabpfn_wrapper import TabPFNWrapper
58
+ wrapper = TabPFNWrapper(task_type=task, random_state=RAND)
59
+ # Re-use the cached TabPFNClassifier model if available to avoid
60
+ # reloading weights on every CV fold (saves ~2s per fold + RAM)
61
+ global _TABPFN_CACHED_MODEL
62
+ if _TABPFN_CACHED_MODEL is not None:
63
+ wrapper._cached_classifier = _TABPFN_CACHED_MODEL
64
+ return wrapper
65
+
66
+ _TABPFN_CACHED_MODEL = None
67
 
68
 
69
  class _SAPModel:
webapp/main.py CHANGED
@@ -212,6 +212,8 @@ async def benchmark(
212
  }
213
 
214
  except Exception as e:
 
 
215
  raise HTTPException(500, f"Benchmarking failed: {e}")
216
 
217
  return JSONResponse(result)
 
212
  }
213
 
214
  except Exception as e:
215
+ import traceback
216
+ traceback.print_exc()
217
  raise HTTPException(500, f"Benchmarking failed: {e}")
218
 
219
  return JSONResponse(result)
webapp/requirements.txt CHANGED
@@ -6,6 +6,7 @@ xgboost>=2.0.0
6
  lightgbm>=4.0.0
7
  catboost>=1.2.0
8
  scikit-learn>=1.3.0
 
9
  pandas>=2.0.0
10
  numpy>=1.24.0
11
  tabpfn>=7.1.1
 
6
  lightgbm>=4.0.0
7
  catboost>=1.2.0
8
  scikit-learn>=1.3.0
9
+ scipy>=1.10.0
10
  pandas>=2.0.0
11
  numpy>=1.24.0
12
  tabpfn>=7.1.1