Akshay4506 commited on
Commit
725b792
Β·
1 Parent(s): fc6a2fc

fix: revert to tabpfn v2, add TABPFN_TOKEN support, and update v2 API usage

Browse files
Dockerfile CHANGED
@@ -37,14 +37,6 @@ RUN pip install --no-cache-dir -r webapp/requirements.txt
37
  # Install SAP-RPT-1 OSS directly from GitHub (needed for the real model)
38
  RUN pip install --no-cache-dir git+https://github.com/SAP-samples/sap-rpt-1-oss.git
39
 
40
- # Pre-accept TabPFN license at build time by writing the marker file TabPFN
41
- # looks for in the cache directory (covers all versions of the check).
42
- RUN python -c "\
43
- import os; \
44
- os.makedirs(os.path.expanduser('~/.cache/tabpfn'), exist_ok=True); \
45
- open(os.path.expanduser('~/.cache/tabpfn/license_accepted'), 'w').write('accepted'); \
46
- print('TabPFN license pre-accepted.')"
47
-
48
  # Expose port 7860 (Hugging Face Spaces default port)
49
  EXPOSE 7860
50
 
 
37
  # Install SAP-RPT-1 OSS directly from GitHub (needed for the real model)
38
  RUN pip install --no-cache-dir git+https://github.com/SAP-samples/sap-rpt-1-oss.git
39
 
 
 
 
 
 
 
 
 
40
  # Expose port 7860 (Hugging Face Spaces default port)
41
  EXPOSE 7860
42
 
code/models/tabpfn_wrapper.py CHANGED
@@ -18,8 +18,14 @@ from typing import Optional, Union
18
  import numpy as np
19
  import pandas as pd
20
 
21
- # Automatically accept the TabPFN license (non-interactive / CI environments).
22
- # Different TabPFN versions check different env var names β€” set all of them.
 
 
 
 
 
 
23
  os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
24
  os.environ["TABPFN_LICENSE"] = "accept"
25
  os.environ["TABPFN_ACCEPT_TERMS"] = "1"
@@ -135,32 +141,34 @@ class TabPFNWrapper(BaseModelWrapper):
135
  try:
136
  from tabpfn import TabPFNClassifier
137
 
138
- import torch
139
  import tabpfn
140
 
141
- actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
142
-
143
  # Reuse class-level cached classifier so weights are only loaded ONCE
144
  # per process, not once per CV fold.
145
  if TabPFNWrapper._shared_classifier is None:
146
  logger.info("Creating new TabPFNClassifier and caching at class level...")
147
- if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
 
 
 
 
 
148
  TabPFNWrapper._shared_classifier = TabPFNClassifier(
149
  device=actual_device,
150
  N_ensemble_configurations=self.n_ensemble
151
  )
152
  else:
153
- TabPFNWrapper._shared_classifier = TabPFNClassifier(device=actual_device)
 
154
  else:
155
  logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
156
 
157
  self.model = TabPFNWrapper._shared_classifier
158
 
159
- # Fit β€” for v0.1.x, overwrite_warning=True suppresses the data size warning
160
  try:
161
  self.model.fit(X, y, overwrite_warning=True)
162
  except TypeError:
163
- # v2+ doesn't accept overwrite_warning kwarg
164
  self.model.fit(X, y)
165
 
166
  self.is_fitted = True
 
18
  import numpy as np
19
  import pandas as pd
20
 
21
+ # ── TabPFN non-interactive authentication ─────────────────────────────────────
22
+ # For TabPFN v2 (PriorLabs), set TABPFN_TOKEN from the HF Space secret.
23
+ # The user must add TABPFN_TOKEN as a secret in HF Space settings.
24
+ _tabpfn_token = os.environ.get("TABPFN_TOKEN", "")
25
+ if _tabpfn_token:
26
+ os.environ["TABPFN_TOKEN"] = _tabpfn_token # ensure it's set for child processes
27
+
28
+ # Cover all license-acceptance env var names across TabPFN versions.
29
  os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
30
  os.environ["TABPFN_LICENSE"] = "accept"
31
  os.environ["TABPFN_ACCEPT_TERMS"] = "1"
 
141
  try:
142
  from tabpfn import TabPFNClassifier
143
 
 
144
  import tabpfn
145
 
 
 
146
  # Reuse class-level cached classifier so weights are only loaded ONCE
147
  # per process, not once per CV fold.
148
  if TabPFNWrapper._shared_classifier is None:
149
  logger.info("Creating new TabPFNClassifier and caching at class level...")
150
+ # TabPFN v2: no device/N_ensemble args; token read from TABPFN_TOKEN env var.
151
+ # TabPFN v0.1.x: needs device + N_ensemble_configurations.
152
+ version = getattr(tabpfn, '__version__', '0')
153
+ if version.startswith('0.1'):
154
+ import torch
155
+ actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else 'cpu'
156
  TabPFNWrapper._shared_classifier = TabPFNClassifier(
157
  device=actual_device,
158
  N_ensemble_configurations=self.n_ensemble
159
  )
160
  else:
161
+ # v2+: just instantiate β€” auth is via TABPFN_TOKEN env var
162
+ TabPFNWrapper._shared_classifier = TabPFNClassifier()
163
  else:
164
  logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
165
 
166
  self.model = TabPFNWrapper._shared_classifier
167
 
168
+ # Fit β€” v0.1.x accepts overwrite_warning=True; v2+ does not.
169
  try:
170
  self.model.fit(X, y, overwrite_warning=True)
171
  except TypeError:
 
172
  self.model.fit(X, y)
173
 
174
  self.is_fitted = True
webapp/requirements.txt CHANGED
@@ -9,5 +9,5 @@ scikit-learn>=1.3.0
9
  scipy>=1.10.0
10
  pandas>=2.0.0
11
  numpy>=1.24.0
12
- tabpfn==0.1.11
13
  huggingface_hub
 
9
  scipy>=1.10.0
10
  pandas>=2.0.0
11
  numpy>=1.24.0
12
+ tabpfn>=2.0.0
13
  huggingface_hub