Akshay4506 commited on
Commit
fc6a2fc
·
1 Parent(s): 6f18cb2

fix: downgrade tabpfn to 0.1.11, class-level weight cache, fix ravel on list

Browse files
code/models/tabpfn_wrapper.py CHANGED
@@ -78,6 +78,10 @@ class TabPFNWrapper(BaseModelWrapper):
78
  Random seed
79
  """
80
 
 
 
 
 
81
  def __init__(
82
  self,
83
  task_type: str = 'classification',
@@ -96,18 +100,6 @@ class TabPFNWrapper(BaseModelWrapper):
96
  def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
97
  """
98
  Fit TabPFN (stores training data for in-context learning).
99
-
100
- Parameters
101
- ----------
102
- X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
103
- Training features (max 1000 samples, 100 features)
104
- y : pd.Series or np.ndarray, shape (n_samples,)
105
- Training target
106
-
107
- Returns
108
- -------
109
- self : TabPFNWrapper
110
- Fitted model
111
  """
112
  self._validate_input(X, y)
113
 
@@ -145,32 +137,31 @@ class TabPFNWrapper(BaseModelWrapper):
145
 
146
  import torch
147
  import tabpfn
148
-
149
  actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
150
-
151
- # Reuse a cached classifier if one was injected to avoid reloading weights
152
- if hasattr(self, '_cached_classifier') and self._cached_classifier is not None:
153
- self.model = self._cached_classifier
154
- logger.info("Reusing cached TabPFN classifier (skipping weight reload)")
155
- else:
156
  if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
157
- self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
 
 
 
158
  else:
159
- self.model = TabPFNClassifier(device=actual_device)
160
-
161
- # Store in global cache for future reuse
162
- try:
163
- import benchmark
164
- benchmark._TABPFN_CACHED_MODEL = self.model
165
- except (ImportError, AttributeError):
166
- try:
167
- import webapp.benchmark as wb
168
- wb._TABPFN_CACHED_MODEL = self.model
169
- except (ImportError, AttributeError):
170
- pass
171
-
172
- # Fit model
173
- self.model.fit(X, y)
174
 
175
  self.is_fitted = True
176
  self.fit_time = time.time() - start_time
 
78
  Random seed
79
  """
80
 
81
+ # Class-level cache: weights are loaded once and shared across ALL instances
82
+ # in the same process. This prevents reloading 103 weight files on every CV fold.
83
+ _shared_classifier = None
84
+
85
  def __init__(
86
  self,
87
  task_type: str = 'classification',
 
100
  def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
101
  """
102
  Fit TabPFN (stores training data for in-context learning).
 
 
 
 
 
 
 
 
 
 
 
 
103
  """
104
  self._validate_input(X, y)
105
 
 
137
 
138
  import torch
139
  import tabpfn
140
+
141
  actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
142
+
143
+ # Reuse class-level cached classifier so weights are only loaded ONCE
144
+ # per process, not once per CV fold.
145
+ if TabPFNWrapper._shared_classifier is None:
146
+ logger.info("Creating new TabPFNClassifier and caching at class level...")
 
147
  if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
148
+ TabPFNWrapper._shared_classifier = TabPFNClassifier(
149
+ device=actual_device,
150
+ N_ensemble_configurations=self.n_ensemble
151
+ )
152
  else:
153
+ TabPFNWrapper._shared_classifier = TabPFNClassifier(device=actual_device)
154
+ else:
155
+ logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
156
+
157
+ self.model = TabPFNWrapper._shared_classifier
158
+
159
+ # Fit — for v0.1.x, overwrite_warning=True suppresses the data size warning
160
+ try:
161
+ self.model.fit(X, y, overwrite_warning=True)
162
+ except TypeError:
163
+ # v2+ doesn't accept overwrite_warning kwarg
164
+ self.model.fit(X, y)
 
 
 
165
 
166
  self.is_fitted = True
167
  self.fit_time = time.time() - start_time
webapp/benchmark.py CHANGED
@@ -63,15 +63,10 @@ def _tabpfn(task):
63
  if task != "classification":
64
  raise ValueError("TabPFN only supports classification tasks")
65
  from models.tabpfn_wrapper import TabPFNWrapper
66
- wrapper = TabPFNWrapper(task_type=task, random_state=RAND)
67
- # Re-use the cached TabPFNClassifier model if available to avoid
68
- # reloading weights on every CV fold (saves ~2s per fold + RAM)
69
- global _TABPFN_CACHED_MODEL
70
- if _TABPFN_CACHED_MODEL is not None:
71
- wrapper._cached_classifier = _TABPFN_CACHED_MODEL
72
- return wrapper
73
-
74
- _TABPFN_CACHED_MODEL = None
75
 
76
 
77
  class _SAPModel:
 
63
  if task != "classification":
64
  raise ValueError("TabPFN only supports classification tasks")
65
  from models.tabpfn_wrapper import TabPFNWrapper
66
+ # TabPFNWrapper uses a class-level _shared_classifier so weights are only
67
+ # loaded once per process regardless of how many instances are created.
68
+ return TabPFNWrapper(task_type=task, random_state=RAND)
69
+
 
 
 
 
 
70
 
71
 
72
  class _SAPModel:
webapp/main.py CHANGED
@@ -245,7 +245,7 @@ async def predict(data: dict):
245
  X_test, _ = _prep(input_df, encoders=CHAMPION_INFO.get("encoders"))
246
 
247
  if CHAMPION_INFO["task"] == "classification":
248
- raw_pred = CHAMPION_MODEL.predict(X_test)
249
  # Flatten if nested (CatBoost/Sklearn sometimes return [[val]] or [val])
250
  pred_val = raw_pred.ravel()[0]
251
  pred_idx = int(pred_val)
@@ -266,7 +266,7 @@ async def predict(data: dict):
266
  "labels": CHAMPION_INFO["labels"]
267
  }
268
  else:
269
- raw_pred = CHAMPION_MODEL.predict(X_test)
270
  pred = float(raw_pred.ravel()[0])
271
  return {"prediction": pred}
272
 
 
245
  X_test, _ = _prep(input_df, encoders=CHAMPION_INFO.get("encoders"))
246
 
247
  if CHAMPION_INFO["task"] == "classification":
248
+ raw_pred = np.array(CHAMPION_MODEL.predict(X_test))
249
  # Flatten if nested (CatBoost/Sklearn sometimes return [[val]] or [val])
250
  pred_val = raw_pred.ravel()[0]
251
  pred_idx = int(pred_val)
 
266
  "labels": CHAMPION_INFO["labels"]
267
  }
268
  else:
269
+ raw_pred = np.array(CHAMPION_MODEL.predict(X_test))
270
  pred = float(raw_pred.ravel()[0])
271
  return {"prediction": pred}
272
 
webapp/requirements.txt CHANGED
@@ -9,5 +9,5 @@ scikit-learn>=1.3.0
9
  scipy>=1.10.0
10
  pandas>=2.0.0
11
  numpy>=1.24.0
12
- tabpfn>=7.1.1
13
  huggingface_hub
 
9
  scipy>=1.10.0
10
  pandas>=2.0.0
11
  numpy>=1.24.0
12
+ tabpfn==0.1.11
13
  huggingface_hub