ModelMatrix / code /models /tabpfn_wrapper.py
Akshay4506's picture
fix: revert to tabpfn v2, add TABPFN_TOKEN support, and update v2 API usage
725b792
raw
history blame
9.13 kB
"""
TabPFN Wrapper
==============
Sklearn-compatible wrapper for TabPFN (Tabular Prior-data Fitted Network).
TabPFN is a pretrained model for tabular classification using
in-context learning (no training required).
Author: UW MSIM Team
Date: November 2025
"""
import time
import logging
import os
from typing import Optional, Union
import numpy as np
import pandas as pd
# ── TabPFN non-interactive authentication ─────────────────────────────────────
# Import-time environment setup so TabPFN can run without interactive prompts.
# For TabPFN v2 (PriorLabs), set TABPFN_TOKEN from the HF Space secret.
# The user must add TABPFN_TOKEN as a secret in the HF Space settings.
# An unset variable is read as the empty string, which skips the re-export below.
_tabpfn_token = os.environ.get("TABPFN_TOKEN", "")
if _tabpfn_token:
# NOTE(review): this writes back the value that was just read from the same
# variable, so the environment is unchanged by this assignment.
os.environ["TABPFN_TOKEN"] = _tabpfn_token # ensure it's set for child processes
# Cover all license-acceptance env var names across TabPFN versions.
# NOTE(review): indentation is not visible in this copy of the file; these
# assignments are presumed to run unconditionally (outside the `if`) — confirm.
# NOTE(review): which of these names a given TabPFN release actually reads is
# not verifiable from this file — confirm against the installed version.
os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
os.environ["TABPFN_LICENSE"] = "accept"
os.environ["TABPFN_ACCEPT_TERMS"] = "1"
os.environ["TABPFN_LICENSE_ACCEPTED"] = "1"
os.environ["AGREE_TABPFN_LICENSE"] = "1"
# ── Patch for old TabPFN compatibility with newer torch ──────────────────────
# Re-expose the `typing` aliases on `torch.nn.modules.transformer` when the
# installed torch does not provide them there. Each alias is checked on its
# own: a torch build may still expose `Optional` while lacking `Tuple`/`List`,
# and a single guard on `Optional` would leave those unpatched.
try:
    import typing

    import torch.nn.modules.transformer

    for _alias in ("Optional", "Any", "Tuple", "List"):
        if not hasattr(torch.nn.modules.transformer, _alias):
            setattr(
                torch.nn.modules.transformer, _alias, getattr(typing, _alias)
            )
except (ImportError, AttributeError):
    # Best-effort shim: without torch (or with an unexpected module layout)
    # there is nothing to patch, and importing this module must still succeed.
    pass
# ── Patch for old TabPFN compatibility with newer sklearn ────────────────────
def _patch_validation(func):
    """Wrap a validation helper so the legacy ``force_all_finite`` keyword works.

    The keyword is renamed to ``ensure_all_finite`` only when ``func`` actually
    accepts the new name. On a release that knows only ``force_all_finite``
    the call is passed through untouched, so the shim cannot break it.

    Parameters
    ----------
    func : callable
        Validation function to wrap (e.g. ``check_array`` / ``check_X_y``).

    Returns
    -------
    callable
        Wrapper with the same metadata as ``func`` (via ``functools.wraps``).
    """
    import inspect
    from functools import wraps

    try:
        rename = 'ensure_all_finite' in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to always renaming.
        rename = True

    @wraps(func)
    def wrapper(*args, **kwargs):
        if rename and 'force_all_finite' in kwargs:
            kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
        return func(*args, **kwargs)

    return wrapper


try:
    import sklearn.utils.validation

    sklearn.utils.validation.check_X_y = _patch_validation(
        sklearn.utils.validation.check_X_y
    )
    sklearn.utils.validation.check_array = _patch_validation(
        sklearn.utils.validation.check_array
    )
except (ImportError, AttributeError):
    # Best-effort shim: without sklearn (or with an unexpected module layout)
    # there is nothing to patch, and importing this module must still succeed.
    pass
from .base_wrapper import BaseModelWrapper
logger = logging.getLogger(__name__)
class TabPFNWrapper(BaseModelWrapper):
    """
    TabPFN (Tabular Prior-Fitted Networks) wrapper.

    TabPFN uses pretrained transformers for zero-shot tabular prediction.
    Works best on datasets with <1000 samples and <100 features. ``fit``
    enforces hard caps: inputs with more than 1024 rows are randomly
    subsampled (seeded by ``random_state``) and inputs with more than 100
    columns keep only their first 100 columns.

    Parameters
    ----------
    task_type : str, default='classification'
        Task type (only 'classification' supported by TabPFN)
    n_ensemble : int, default=1
        Number of ensemble members. Only forwarded to TabPFN v0.1.x
        (as ``N_ensemble_configurations``).
    device : str, default='auto'
        Device: 'cpu', 'cuda', or 'auto' ('auto' uses CUDA when available).
        Only forwarded to TabPFN v0.1.x.
    random_state : int, default=42
        Random seed

    Notes
    -----
    The underlying classifier object is cached on the class and shared by
    every wrapper in the process. Consequences:

    * it is built with the ``device`` / ``n_ensemble`` of the first wrapper
      that calls ``fit``; later wrappers reuse it as-is;
    * calling ``fit`` on any wrapper refits that shared object, so a wrapper
      fitted earlier will predict from the most recent fit.
    """

    # Class-level cache: weights are loaded once and shared across ALL instances
    # in the same process. This prevents reloading 103 weight files on every CV fold.
    _shared_classifier = None

    # Hard input caps applied by fit() (and, for features, by predict*()).
    _MAX_SAMPLES = 1024
    _MAX_FEATURES = 100

    def __init__(
        self,
        task_type: str = 'classification',
        n_ensemble: int = 1,
        device: str = 'auto',
        random_state: int = 42
    ):
        """Store the configuration; raise ValueError for non-classification tasks."""
        super().__init__(task_type=task_type, random_state=random_state)
        if task_type != 'classification':
            raise ValueError("TabPFN only supports classification tasks")
        self.n_ensemble = n_ensemble
        self.device = device

    @staticmethod
    def _take_rows(data, row_idx):
        """Return the rows of ``data`` at positions ``row_idx``."""
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.iloc[row_idx]
        # np.asarray is a no-op for ndarrays and lets list-like targets work.
        return np.asarray(data)[row_idx]

    def _truncate_features(self, X):
        """Return ``X`` limited to its first ``_MAX_FEATURES`` columns."""
        if X.shape[1] <= self._MAX_FEATURES:
            return X
        if isinstance(X, pd.DataFrame):
            return X.iloc[:, :self._MAX_FEATURES]
        return X[:, :self._MAX_FEATURES]

    def _resolve_device(self) -> str:
        """Map ``self.device`` to a concrete device name.

        An explicit device is honoured as given; only 'auto' probes CUDA.
        """
        if self.device != 'auto':
            return self.device
        import torch
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def _get_or_create_classifier(self, tabpfn_module, classifier_cls):
        """Return the process-wide classifier, building it on first use."""
        if TabPFNWrapper._shared_classifier is not None:
            logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
            return TabPFNWrapper._shared_classifier

        logger.info("Creating new TabPFNClassifier and caching at class level...")
        # TabPFN v2: no device/N_ensemble args; token read from TABPFN_TOKEN env var.
        # TabPFN v0.1.x: needs device + N_ensemble_configurations.
        version = getattr(tabpfn_module, '__version__', '0')
        if version.startswith('0.1'):
            TabPFNWrapper._shared_classifier = classifier_cls(
                device=self._resolve_device(),
                N_ensemble_configurations=self.n_ensemble
            )
        else:
            # v2+: just instantiate — auth is via TABPFN_TOKEN env var
            TabPFNWrapper._shared_classifier = classifier_cls()
        return TabPFNWrapper._shared_classifier

    def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
        """
        Fit TabPFN (stores training data for in-context learning).

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Training features
        y : pd.Series or np.ndarray, shape (n_samples,)
            Training labels

        Returns
        -------
        self : TabPFNWrapper

        Raises
        ------
        ImportError
            If the ``tabpfn`` package cannot be imported.
        """
        self._validate_input(X, y)

        # Check TabPFN constraints
        n_samples = X.shape[0]
        if n_samples > self._MAX_SAMPLES:
            logger.warning(
                "TabPFN strictly requires <= %d samples to avoid Memory OOM. "
                "Subsampling %d to %d samples.",
                self._MAX_SAMPLES, n_samples, self._MAX_SAMPLES,
            )
            sample_idx = np.random.RandomState(self.random_state).choice(
                n_samples, self._MAX_SAMPLES, replace=False
            )
            X = self._take_rows(X, sample_idx)
            y = self._take_rows(y, sample_idx)

        # Remember whether columns were dropped so predict*() can mirror it.
        self.truncated_features_ = X.shape[1] > self._MAX_FEATURES
        if self.truncated_features_:
            logger.warning(
                "TabPFN strictly requires <= %d features. Truncating %d to %d features.",
                self._MAX_FEATURES, X.shape[1], self._MAX_FEATURES,
            )
            X = self._truncate_features(X)

        logger.info("Fitting TabPFN on %d samples...", X.shape[0])
        start_time = time.time()

        # Only a failure to import tabpfn itself means "not installed"; keep
        # this try narrow so unrelated import errors are not mislabelled.
        try:
            import tabpfn
            from tabpfn import TabPFNClassifier
        except ImportError as exc:
            logger.error("TabPFN not installed")
            raise ImportError("Install TabPFN with: pip install tabpfn") from exc

        try:
            # Reuse class-level cached classifier so weights are only loaded ONCE
            # per process, not once per CV fold.
            self.model = self._get_or_create_classifier(tabpfn, TabPFNClassifier)
            # Fit — v0.1.x accepts overwrite_warning=True; v2+ does not.
            try:
                self.model.fit(X, y, overwrite_warning=True)
            except TypeError:
                self.model.fit(X, y)
        except Exception as e:
            logger.error("Error fitting TabPFN: %s", e)
            raise

        self.is_fitted = True
        self.fit_time = time.time() - start_time
        logger.info("TabPFN fitted in %.2f seconds", self.fit_time)
        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted class labels

        Raises
        ------
        ValueError
            If called before ``fit``.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")
        self._validate_input(X)

        # Mirror the column truncation applied during fit().
        if getattr(self, 'truncated_features_', False):
            X = self._truncate_features(X)

        logger.info("Predicting on %d samples with TabPFN...", X.shape[0])
        start_time = time.time()
        try:
            predictions = self.model.predict(X)
        except Exception as e:
            logger.error("Error during prediction: %s", e)
            raise
        self.predict_time = time.time() - start_time
        logger.info("Predictions complete in %.2f seconds", self.predict_time)
        return predictions

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predict class probabilities with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities
        """
        # Mirror the column truncation applied during fit().
        if getattr(self, 'truncated_features_', False):
            X = self._truncate_features(X)
        return self.model.predict_proba(X)

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator."""
        params = super().get_params(deep)
        params.update({
            'n_ensemble': self.n_ensemble,
            'device': self.device
        })
        return params