ModelMatrix / matrix /code /models /tabpfn_wrapper.py
Akshay4506's picture
Fix deployment entry point and merge requirements
c4ff02d
"""
TabPFN Wrapper
==============
Sklearn-compatible wrapper for TabPFN (Tabular Pre-trained Transformers).
TabPFN is a pretrained model for tabular classification using
in-context learning (no training required).
Author: UW MSIM Team
Date: November 2025
"""
import time
import logging
import os
from typing import Optional, Union
import numpy as np
import pandas as pd
# Automatically accept the TabPFN license to prevent browser/socket crashes on Windows
os.environ["TABPFN_LICENSE"] = "accept"
os.environ["TABPFN_ACCEPT_LICENSE"] = "1"
# ── Patch for old TabPFN compatibility with newer torch ──────────────────────
try:
import torch.nn.modules.transformer
if not hasattr(torch.nn.modules.transformer, 'Optional'):
import typing
torch.nn.modules.transformer.Optional = typing.Optional
torch.nn.modules.transformer.Any = typing.Any
torch.nn.modules.transformer.Tuple = typing.Tuple
torch.nn.modules.transformer.List = typing.List
except (ImportError, AttributeError):
pass
# ── Patch for old TabPFN compatibility with newer sklearn ────────────────────
try:
import sklearn.utils.validation
def _patch_validation(func):
from functools import wraps
@wraps(func)
def wrapper(*args, **kwargs):
if 'force_all_finite' in kwargs:
kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite')
return func(*args, **kwargs)
return wrapper
sklearn.utils.validation.check_X_y = _patch_validation(sklearn.utils.validation.check_X_y)
sklearn.utils.validation.check_array = _patch_validation(sklearn.utils.validation.check_array)
except (ImportError, AttributeError):
pass
from .base_wrapper import BaseModelWrapper
logger = logging.getLogger(__name__)
class TabPFNWrapper(BaseModelWrapper):
"""
TabPFN (Tabular Prior-Fitted Networks) wrapper.
TabPFN uses pretrained transformers for zero-shot tabular prediction.
Works best on datasets with <1000 samples and <100 features.
Parameters
----------
task_type : str, default='classification'
Task type (only 'classification' supported by TabPFN)
n_ensemble : int, default=1
Number of ensemble members
device : str, default='auto'
Device: 'cpu', 'cuda', or 'auto'
random_state : int, default=42
Random seed
"""
def __init__(
self,
task_type: str = 'classification',
n_ensemble: int = 1,
device: str = 'auto',
random_state: int = 42
):
super().__init__(task_type=task_type, random_state=random_state)
if task_type != 'classification':
raise ValueError("TabPFN only supports classification tasks")
self.n_ensemble = n_ensemble
self.device = device
def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
"""
Fit TabPFN (stores training data for in-context learning).
Parameters
----------
X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
Training features (max 1000 samples, 100 features)
y : pd.Series or np.ndarray, shape (n_samples,)
Training target
Returns
-------
self : TabPFNWrapper
Fitted model
"""
self._validate_input(X, y)
# Check TabPFN constraints
if X.shape[0] > 1024:
logger.warning(f"TabPFN strictly requires <= 1024 samples to avoid Memory OOM. Subsampling {X.shape[0]} to 1024 samples.")
sample_idx = np.random.RandomState(self.random_state).choice(
len(X), 1024, replace=False
)
if isinstance(X, pd.DataFrame):
X = X.iloc[sample_idx]
else:
X = X[sample_idx]
if isinstance(y, pd.Series):
y = y.iloc[sample_idx]
else:
y = y[sample_idx]
if X.shape[1] > 100:
logger.warning(f"TabPFN strictly requires <= 100 features. Truncating {X.shape[1]} to 100 features.")
if isinstance(X, pd.DataFrame):
X = X.iloc[:, :100]
else:
X = X[:, :100]
self.truncated_features_ = True
else:
self.truncated_features_ = False
logger.info(f"Fitting TabPFN on {X.shape[0]} samples...")
start_time = time.time()
try:
from tabpfn import TabPFNClassifier
import torch
import tabpfn
actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device)
if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'):
self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble)
else:
self.model = TabPFNClassifier(device=actual_device)
# Fit model
self.model.fit(X, y)
self.is_fitted = True
self.fit_time = time.time() - start_time
logger.info(f"TabPFN fitted in {self.fit_time:.2f} seconds")
except ImportError:
logger.error("TabPFN not installed")
raise ImportError("Install TabPFN with: pip install tabpfn")
except Exception as e:
logger.error(f"Error fitting TabPFN: {e}")
raise
return self
def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
"""
Make predictions with TabPFN.
Parameters
----------
X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
Test features
Returns
-------
predictions : np.ndarray, shape (n_samples,)
Predicted class labels
"""
if not self.is_fitted:
raise ValueError("Model not fitted. Call fit() first.")
self._validate_input(X)
if getattr(self, 'truncated_features_', False) and X.shape[1] > 100:
if isinstance(X, pd.DataFrame):
X = X.iloc[:, :100]
else:
X = X[:, :100]
logger.info(f"Predicting on {X.shape[0]} samples with TabPFN...")
start_time = time.time()
try:
predictions = self.model.predict(X)
self.predict_time = time.time() - start_time
logger.info(f"Predictions complete in {self.predict_time:.2f} seconds")
return predictions
except Exception as e:
logger.error(f"Error during prediction: {e}")
raise
def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
"""
Predict class probabilities with TabPFN.
Parameters
----------
X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
Test features
Returns
-------
probabilities : np.ndarray, shape (n_samples, n_classes)
Class probabilities
"""
if getattr(self, 'truncated_features_', False) and X.shape[1] > 100:
if isinstance(X, pd.DataFrame):
X = X.iloc[:, :100]
else:
X = X[:, :100]
return self.model.predict_proba(X)
def get_params(self, deep: bool = True) -> dict:
"""Get parameters for this estimator."""
params = super().get_params(deep)
params.update({
'n_ensemble': self.n_ensemble,
'device': self.device
})
return params