Spaces:
Sleeping
Sleeping
Commit ·
725b792
1
Parent(s): fc6a2fc
fix: revert to tabpfn v2, add TABPFN_TOKEN support, and update v2 API usage
Browse files
- Dockerfile +0 -8
- code/models/tabpfn_wrapper.py +17 -9
- webapp/requirements.txt +1 -1
Dockerfile
CHANGED
|
@@ -37,14 +37,6 @@ RUN pip install --no-cache-dir -r webapp/requirements.txt
|
|
| 37 |
# Install SAP-RPT-1 OSS directly from GitHub (needed for the real model)
|
| 38 |
RUN pip install --no-cache-dir git+https://github.com/SAP-samples/sap-rpt-1-oss.git
|
| 39 |
|
| 40 |
-
# Pre-accept TabPFN license at build time by writing the marker file TabPFN
|
| 41 |
-
# looks for in the cache directory (covers all versions of the check).
|
| 42 |
-
RUN python -c "\
|
| 43 |
-
import os; \
|
| 44 |
-
os.makedirs(os.path.expanduser('~/.cache/tabpfn'), exist_ok=True); \
|
| 45 |
-
open(os.path.expanduser('~/.cache/tabpfn/license_accepted'), 'w').write('accepted'); \
|
| 46 |
-
print('TabPFN license pre-accepted.')"
|
| 47 |
-
|
| 48 |
# Expose port 7860 (Hugging Face Spaces default port)
|
| 49 |
EXPOSE 7860
|
| 50 |
|
|
|
|
| 37 |
# Install SAP-RPT-1 OSS directly from GitHub (needed for the real model)
|
| 38 |
RUN pip install --no-cache-dir git+https://github.com/SAP-samples/sap-rpt-1-oss.git
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Expose port 7860 (Hugging Face Spaces default port)
|
| 41 |
EXPOSE 7860
|
| 42 |
|
code/models/tabpfn_wrapper.py
CHANGED
|
@@ -18,8 +18,14 @@ from typing import Optional, Union
|
|
| 18 |
import numpy as np
|
| 19 |
import pandas as pd
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
|
| 24 |
os.environ["TABPFN_LICENSE"] = "accept"
|
| 25 |
os.environ["TABPFN_ACCEPT_TERMS"] = "1"
|
|
@@ -135,32 +141,34 @@ class TabPFNWrapper(BaseModelWrapper):
|
|
| 135 |
try:
|
| 136 |
from tabpfn import TabPFNClassifier
|
| 137 |
|
| 138 |
-
import torch
|
| 139 |
import tabpfn
|
| 140 |
|
| 141 |
-
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
|
| 142 |
-
|
| 143 |
# Reuse class-level cached classifier so weights are only loaded ONCE
|
| 144 |
# per process, not once per CV fold.
|
| 145 |
if TabPFNWrapper._shared_classifier is None:
|
| 146 |
logger.info("Creating new TabPFNClassifier and caching at class level...")
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
TabPFNWrapper._shared_classifier = TabPFNClassifier(
|
| 149 |
device=actual_device,
|
| 150 |
N_ensemble_configurations=self.n_ensemble
|
| 151 |
)
|
| 152 |
else:
|
| 153 |
-
|
|
|
|
| 154 |
else:
|
| 155 |
logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
|
| 156 |
|
| 157 |
self.model = TabPFNWrapper._shared_classifier
|
| 158 |
|
| 159 |
-
# Fit —
|
| 160 |
try:
|
| 161 |
self.model.fit(X, y, overwrite_warning=True)
|
| 162 |
except TypeError:
|
| 163 |
-
# v2+ doesn't accept overwrite_warning kwarg
|
| 164 |
self.model.fit(X, y)
|
| 165 |
|
| 166 |
self.is_fitted = True
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
import pandas as pd
|
| 20 |
|
| 21 |
+
# ── TabPFN non-interactive authentication ─────────────────────────────────────
|
| 22 |
+
# For TabPFN v2 (PriorLabs), set TABPFN_TOKEN from the HF Space secret.
|
| 23 |
+
# The user must add TABPFN_TOKEN as a secret in HF Space settings.
|
| 24 |
+
_tabpfn_token = os.environ.get("TABPFN_TOKEN", "")
|
| 25 |
+
if _tabpfn_token:
|
| 26 |
+
os.environ["TABPFN_TOKEN"] = _tabpfn_token # ensure it's set for child processes
|
| 27 |
+
|
| 28 |
+
# Cover all license-acceptance env var names across TabPFN versions.
|
| 29 |
os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
|
| 30 |
os.environ["TABPFN_LICENSE"] = "accept"
|
| 31 |
os.environ["TABPFN_ACCEPT_TERMS"] = "1"
|
|
|
|
| 141 |
try:
|
| 142 |
from tabpfn import TabPFNClassifier
|
| 143 |
|
|
|
|
| 144 |
import tabpfn
|
| 145 |
|
|
|
|
|
|
|
| 146 |
# Reuse class-level cached classifier so weights are only loaded ONCE
|
| 147 |
# per process, not once per CV fold.
|
| 148 |
if TabPFNWrapper._shared_classifier is None:
|
| 149 |
logger.info("Creating new TabPFNClassifier and caching at class level...")
|
| 150 |
+
# TabPFN v2: no device/N_ensemble args; token read from TABPFN_TOKEN env var.
|
| 151 |
+
# TabPFN v0.1.x: needs device + N_ensemble_configurations.
|
| 152 |
+
version = getattr(tabpfn, '__version__', '0')
|
| 153 |
+
if version.startswith('0.1'):
|
| 154 |
+
import torch
|
| 155 |
+
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else 'cpu'
|
| 156 |
TabPFNWrapper._shared_classifier = TabPFNClassifier(
|
| 157 |
device=actual_device,
|
| 158 |
N_ensemble_configurations=self.n_ensemble
|
| 159 |
)
|
| 160 |
else:
|
| 161 |
+
# v2+: just instantiate — auth is via TABPFN_TOKEN env var
|
| 162 |
+
TabPFNWrapper._shared_classifier = TabPFNClassifier()
|
| 163 |
else:
|
| 164 |
logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
|
| 165 |
|
| 166 |
self.model = TabPFNWrapper._shared_classifier
|
| 167 |
|
| 168 |
+
# Fit — v0.1.x accepts overwrite_warning=True; v2+ does not.
|
| 169 |
try:
|
| 170 |
self.model.fit(X, y, overwrite_warning=True)
|
| 171 |
except TypeError:
|
|
|
|
| 172 |
self.model.fit(X, y)
|
| 173 |
|
| 174 |
self.is_fitted = True
|
webapp/requirements.txt
CHANGED
|
@@ -9,5 +9,5 @@ scikit-learn>=1.3.0
|
|
| 9 |
scipy>=1.10.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
-
tabpfn=
|
| 13 |
huggingface_hub
|
|
|
|
| 9 |
scipy>=1.10.0
|
| 10 |
pandas>=2.0.0
|
| 11 |
numpy>=1.24.0
|
| 12 |
+
tabpfn>=2.0.0
|
| 13 |
huggingface_hub
|