"""
TabPFN Wrapper
==============

Sklearn-compatible wrapper for TabPFN (Tabular Pre-trained Transformers).
TabPFN is a pretrained model for tabular classification using in-context
learning (no training required).

Author: UW MSIM Team
Date: November 2025
"""

import inspect
import logging
import os
import time
from functools import wraps
from typing import Optional, Union

import numpy as np
import pandas as pd


def _accepts_keyword(func, name: str) -> bool:
    """Return True if ``func``'s signature declares a parameter called ``name``.

    Returns False when the signature cannot be introspected.
    """
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return False
    return name in params


# ── TabPFN non-interactive authentication ─────────────────────────────────────
# TabPFN v2 (PriorLabs) authenticates via the TABPFN_TOKEN environment
# variable; on a HF Space the user must add TABPFN_TOKEN as a Space secret.
# Nothing needs to be re-exported here: variables already present in
# os.environ are inherited by child processes.
#
# Cover all license-acceptance env var names across TabPFN versions so a
# headless run is not blocked by an interactive prompt.
_TABPFN_LICENSE_ENV = {
    "TABPFN_ACCEPT_LICENSE": "1",
    "TABPFN_LICENSE": "accept",
    "TABPFN_ACCEPT_TERMS": "1",
    "TABPFN_LICENSE_ACCEPTED": "1",
    "AGREE_TABPFN_LICENSE": "1",
}
os.environ.update(_TABPFN_LICENSE_ENV)

# ── Patch for old TabPFN compatibility with newer torch ──────────────────────
# Old TabPFN releases expect these typing names to be importable from
# torch.nn.modules.transformer (presumably via a star import); newer torch no
# longer defines them there, so inject any that are missing.
try:
    import typing

    import torch.nn.modules.transformer

    for _typing_name in ("Optional", "Any", "Tuple", "List"):
        if not hasattr(torch.nn.modules.transformer, _typing_name):
            setattr(
                torch.nn.modules.transformer,
                _typing_name,
                getattr(typing, _typing_name),
            )
except (ImportError, AttributeError):
    pass

# ── Patch for old TabPFN compatibility with newer sklearn ────────────────────
# Old TabPFN passes ``force_all_finite=`` to check_X_y / check_array; newer
# sklearn renamed that keyword to ``ensure_all_finite``.  Only patch when the
# installed sklearn actually accepts ``ensure_all_finite``. On older sklearn
# the rename would turn every legacy call into a TypeError.
try:
    import sklearn.utils.validation

    def _patch_validation(func):
        """Wrap ``func`` so a legacy ``force_all_finite`` kwarg is renamed."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            if 'force_all_finite' in kwargs:
                legacy_value = kwargs.pop('force_all_finite')
                # An explicitly passed modern keyword wins over the legacy one.
                kwargs.setdefault('ensure_all_finite', legacy_value)
            return func(*args, **kwargs)

        # Marker so a module reload does not wrap the wrapper again.
        wrapper._force_all_finite_patched = True
        return wrapper

    for _func_name in ('check_X_y', 'check_array'):
        _original_func = getattr(sklearn.utils.validation, _func_name)
        if getattr(_original_func, '_force_all_finite_patched', False):
            continue
        if _accepts_keyword(_original_func, 'ensure_all_finite'):
            setattr(
                sklearn.utils.validation,
                _func_name,
                _patch_validation(_original_func),
            )
except (ImportError, AttributeError):
    pass

from .base_wrapper import BaseModelWrapper

logger = logging.getLogger(__name__)


class TabPFNWrapper(BaseModelWrapper):
    """
    TabPFN (Tabular Prior-Fitted Networks) wrapper.

    TabPFN uses pretrained transformers for zero-shot tabular prediction.
    Works best on datasets with <1000 samples and <100 features; in ``fit``,
    larger inputs are subsampled to ``MAX_SAMPLES`` rows and truncated to the
    first ``MAX_FEATURES`` columns.

    Parameters
    ----------
    task_type : str, default='classification'
        Task type (only 'classification' supported by TabPFN)
    n_ensemble : int, default=1
        Number of ensemble members. Passed as ``N_ensemble_configurations``
        to TabPFN v0.1.x; not forwarded to other versions.
    device : str, default='auto'
        Device: 'cpu', 'cuda', or 'auto' (used by TabPFN v0.1.x only).
        'auto' selects CUDA when available. An explicit CUDA device falls
        back to CPU with a warning when CUDA is unavailable.
    random_state : int, default=42
        Random seed (used here to draw the training subsample)

    Notes
    -----
    Classifiers are cached at class level, one per configuration. Weights
    are therefore created once per process and shared by every wrapper with
    the same configuration. A shared classifier can only hold one training
    set at a time, so each wrapper keeps its own training data. Before
    predicting, it re-fits the shared classifier if another wrapper has
    fitted it in the meantime.
    """

    # Hard limits enforced in fit(); subclasses may override.
    MAX_SAMPLES = 1024
    MAX_FEATURES = 100

    # Most recently used cached classifier. Kept for backward compatibility:
    # setting it to None still resets the cache (see _get_shared_classifier).
    _shared_classifier = None
    # Cache key -> classifier. Weights are loaded once per key and shared
    # across ALL instances in the same process, so the weight files are not
    # reloaded on every CV fold. (Mutable class attribute on purpose.)
    _shared_classifiers = {}
    # id(classifier) -> fit token of the wrapper whose training data that
    # classifier currently holds.
    _shared_owner_tokens = {}

    def __init__(
        self,
        task_type: str = 'classification',
        n_ensemble: int = 1,
        device: str = 'auto',
        random_state: int = 42
    ):
        super().__init__(task_type=task_type, random_state=random_state)

        if task_type != 'classification':
            raise ValueError("TabPFN only supports classification tasks")

        self.n_ensemble = n_ensemble
        self.device = device

    def fit(self, X: Union[pd.DataFrame, np.ndarray],
            y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
        """
        Fit TabPFN (stores training data for in-context learning).

        Inputs with more than ``MAX_SAMPLES`` rows are randomly subsampled
        (seeded by ``random_state``). Inputs with more than ``MAX_FEATURES``
        columns keep only the first ``MAX_FEATURES`` columns; the same
        truncation is applied at prediction time.

        Returns
        -------
        self : TabPFNWrapper

        Raises
        ------
        ImportError
            If the ``tabpfn`` package is not installed.
        """
        self._validate_input(X, y)

        # Check TabPFN constraints
        n_samples = X.shape[0]
        if n_samples > self.MAX_SAMPLES:
            logger.warning(
                "TabPFN strictly requires <= %d samples to avoid Memory OOM. "
                "Subsampling %d to %d samples.",
                self.MAX_SAMPLES, n_samples, self.MAX_SAMPLES,
            )
            sample_idx = np.random.RandomState(self.random_state).choice(
                n_samples, self.MAX_SAMPLES, replace=False
            )
            X = self._take_rows(X, sample_idx)
            y = self._take_rows(y, sample_idx)

        n_features = X.shape[1]
        self.truncated_features_ = n_features > self.MAX_FEATURES
        if self.truncated_features_:
            logger.warning(
                "TabPFN strictly requires <= %d features. "
                "Truncating %d to %d features.",
                self.MAX_FEATURES, n_features, self.MAX_FEATURES,
            )
            X = self._take_columns(X, self.MAX_FEATURES)

        logger.info("Fitting TabPFN on %d samples...", X.shape[0])
        start_time = time.time()

        # Keep the "not installed" diagnosis limited to the tabpfn import.
        # ImportErrors raised later (e.g. from torch) surface unchanged.
        try:
            from tabpfn import TabPFNClassifier
            import tabpfn
        except ImportError as err:
            logger.error("TabPFN not installed")
            raise ImportError("Install TabPFN with: pip install tabpfn") from err

        try:
            self._classifier_key, self.model = self._get_shared_classifier(
                tabpfn, TabPFNClassifier
            )
            # Keep our own copy of the training data so the shared classifier
            # can be re-fitted if another wrapper overwrites its state.
            self._train_X = X
            self._train_y = y
            self._fit_token = object()
            self._fit_model()

            self.is_fitted = True
            self.fit_time = time.time() - start_time
            logger.info("TabPFN fitted in %.2f seconds", self.fit_time)
        except Exception as e:
            logger.error("Error fitting TabPFN: %s", e)
            raise

        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted class labels

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        self._validate_input(X)
        X = self._truncate_features(X)

        logger.info("Predicting on %d samples with TabPFN...", X.shape[0])

        try:
            self._ensure_model_holds_own_data()
            start_time = time.time()
            predictions = self.model.predict(X)
            self.predict_time = time.time() - start_time
            logger.info("Predictions complete in %.2f seconds", self.predict_time)
            return predictions
        except Exception as e:
            logger.error("Error during prediction: %s", e)
            raise

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predict class probabilities with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities
        """
        X = self._truncate_features(X)
        self._ensure_model_holds_own_data()
        return self.model.predict_proba(X)

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator."""
        params = super().get_params(deep)
        params.update({
            'n_ensemble': self.n_ensemble,
            'device': self.device
        })
        return params

    # ── internal helpers ─────────────────────────────────────────────────────

    def _resolve_device(self) -> str:
        """Map ``self.device`` to a concrete torch device string."""
        import torch

        if self.device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        if str(self.device).startswith('cuda') and not torch.cuda.is_available():
            logger.warning(
                "CUDA requested (device=%r) but not available; falling back to CPU.",
                self.device,
            )
            return 'cpu'
        return self.device

    def _get_shared_classifier(self, tabpfn_module, classifier_cls):
        """Return ``(cache_key, classifier)`` for this wrapper's configuration.

        The classifier is created on first use for a given key and reused
        afterwards, so weights are loaded once per configuration per process.
        """
        if TabPFNWrapper._shared_classifier is None:
            # Honour the legacy reset idiom (``_shared_classifier = None``).
            TabPFNWrapper._shared_classifiers.clear()

        # TabPFN v2: no device/N_ensemble args; token read from TABPFN_TOKEN env var.
        # TabPFN v0.1.x: needs device + N_ensemble_configurations.
        version = getattr(tabpfn_module, '__version__', '0')
        if version.startswith('0.1'):
            device = self._resolve_device()
            key = ('0.1', device, self.n_ensemble)

            def build():
                return classifier_cls(
                    device=device,
                    N_ensemble_configurations=self.n_ensemble,
                )
        else:
            key = ('default',)

            def build():
                # v2+: just instantiate — auth is via TABPFN_TOKEN env var
                return classifier_cls()

        classifier = TabPFNWrapper._shared_classifiers.get(key)
        if classifier is None:
            logger.info("Creating new TabPFNClassifier and caching at class level...")
            classifier = build()
            TabPFNWrapper._shared_classifiers[key] = classifier
        else:
            logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")

        TabPFNWrapper._shared_classifier = classifier
        return key, classifier

    def _fit_model(self) -> None:
        """Fit ``self.model`` on this wrapper's stored data and claim ownership."""
        owners = TabPFNWrapper._shared_owner_tokens
        model_id = id(self.model)
        # Release any previous claim first so a failed fit never leaves another
        # wrapper believing the (now partially overwritten) state is its own.
        owners.pop(model_id, None)
        # Fit — v0.1.x accepts overwrite_warning=True; v2+ does not.  Decide by
        # signature rather than by catching TypeError, which would also hide
        # genuine TypeErrors raised inside fit().
        if _accepts_keyword(self.model.fit, 'overwrite_warning'):
            self.model.fit(self._train_X, self._train_y, overwrite_warning=True)
        else:
            self.model.fit(self._train_X, self._train_y)
        owners[model_id] = self._fit_token

    def _ensure_model_holds_own_data(self) -> None:
        """Re-fit ``self.model`` if another wrapper has fitted it since our fit."""
        token = getattr(self, '_fit_token', None)
        if token is None:
            # Fitted without ownership tracking; nothing we can restore.
            return
        if TabPFNWrapper._shared_owner_tokens.get(id(self.model)) is not token:
            logger.info(
                "Shared TabPFN classifier holds another instance's data; "
                "re-fitting it on this instance's training data."
            )
            self._fit_model()

    def _truncate_features(self, X):
        """Apply the fit-time column truncation (if any) to prediction inputs."""
        if getattr(self, 'truncated_features_', False) and X.shape[1] > self.MAX_FEATURES:
            return self._take_columns(X, self.MAX_FEATURES)
        return X

    @staticmethod
    def _take_rows(data, idx):
        """Select rows ``idx`` positionally from a DataFrame/Series or array-like."""
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.iloc[idx]
        if not hasattr(data, 'shape'):
            # Plain sequences (e.g. lists) do not support fancy indexing.
            data = np.asarray(data)
        return data[idx]

    @staticmethod
    def _take_columns(X, n_columns: int):
        """Keep only the first ``n_columns`` columns of a 2-D DataFrame or array."""
        if isinstance(X, pd.DataFrame):
            return X.iloc[:, :n_columns]
        return X[:, :n_columns]