Spaces:
Sleeping
Sleeping
| """ | |
| TabPFN Wrapper | |
| ============== | |
Sklearn-compatible wrapper for TabPFN (Tabular Prior-data Fitted Network).
| TabPFN is a pretrained model for tabular classification using | |
| in-context learning (no training required). | |
| Author: UW MSIM Team | |
| Date: November 2025 | |
| """ | |
| import time | |
| import logging | |
| import os | |
| from typing import Optional, Union | |
| import numpy as np | |
| import pandas as pd | |
# ── TabPFN non-interactive authentication ─────────────────────────────────────
# TabPFN v2 (PriorLabs) reads its API token from the TABPFN_TOKEN environment
# variable; on a Hugging Face Space the user adds it as a Space secret. Anything
# already in os.environ is inherited by child processes, so the token does not
# need to be re-exported here.
_tabpfn_token = os.environ.get("TABPFN_TOKEN", "")

# Pre-accept the license under every env var name that different TabPFN
# versions may check. setdefault() leaves any value the operator configured
# explicitly untouched instead of clobbering it.
# NOTE(review): this accepts the TabPFN license on the deployer's behalf —
# confirm that is intended.
_TABPFN_LICENSE_ENV = {
    "TABPFN_ACCEPT_LICENSE": "1",
    "TABPFN_LICENSE": "accept",
    "TABPFN_ACCEPT_TERMS": "1",
    "TABPFN_LICENSE_ACCEPTED": "1",
    "AGREE_TABPFN_LICENSE": "1",
}
for _env_name, _env_value in _TABPFN_LICENSE_ENV.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value
# ── Patch for old TabPFN compatibility with newer torch ──────────────────────
# Old TabPFN releases (presumably 0.1.x) import typing helpers such as
# ``Optional`` from ``torch.nn.modules.transformer``; newer torch versions may
# not expose them there. Each missing name is restored individually from
# ``typing`` so a partially-missing set is still fully patched, and names that
# already exist are left alone.
try:
    import typing
    import torch.nn.modules.transformer

    for _name in ("Optional", "Any", "Tuple", "List"):
        if not hasattr(torch.nn.modules.transformer, _name):
            setattr(torch.nn.modules.transformer, _name, getattr(typing, _name))
    del _name
except (ImportError, AttributeError):
    # torch missing or laid out differently: the patch is best-effort only.
    pass
# ── Patch for old TabPFN compatibility with newer sklearn ────────────────────
# scikit-learn 1.6 renamed the ``force_all_finite`` keyword of ``check_array`` /
# ``check_X_y`` to ``ensure_all_finite``, while old TabPFN releases still pass
# ``force_all_finite``.
def _patch_validation(func):
    """
    Wrap an sklearn validation function so ``force_all_finite=`` is forwarded
    as ``ensure_all_finite=``.

    The keyword is only translated when ``func`` accepts ``ensure_all_finite``
    (explicitly or via ``**kwargs``), or when its signature cannot be
    inspected. On older sklearn versions that lack the new keyword, ``func`` is
    returned unchanged, because renaming the keyword there would raise
    ``TypeError``. An explicitly passed ``ensure_all_finite`` wins over the
    legacy keyword. Wrapping an already-patched function is a no-op, so the
    patch is safe to apply more than once (e.g. on module reload).

    Parameters
    ----------
    func : callable
        The validation function to wrap.

    Returns
    -------
    callable
        ``func`` itself, or a ``functools.wraps``-decorated wrapper around it.
    """
    import inspect
    from functools import wraps

    if getattr(func, "_tabpfn_force_all_finite_patch", False):
        return func

    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        accepts_new_name = True  # cannot introspect: keep the translating behaviour
    else:
        accepts_new_name = "ensure_all_finite" in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
    if not accepts_new_name:
        return func

    @wraps(func)
    def wrapper(*args, **kwargs):
        if "force_all_finite" in kwargs:
            # Do not override an ensure_all_finite the caller passed explicitly.
            kwargs.setdefault("ensure_all_finite", kwargs.pop("force_all_finite"))
        return func(*args, **kwargs)

    wrapper._tabpfn_force_all_finite_patch = True
    return wrapper


try:
    import sklearn.utils.validation

    sklearn.utils.validation.check_X_y = _patch_validation(sklearn.utils.validation.check_X_y)
    sklearn.utils.validation.check_array = _patch_validation(sklearn.utils.validation.check_array)
except (ImportError, AttributeError):
    # sklearn missing or laid out differently: the patch is best-effort only.
    pass
| from .base_wrapper import BaseModelWrapper | |
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
class TabPFNWrapper(BaseModelWrapper):
    """
    TabPFN (Tabular Prior-data Fitted Network) wrapper.

    TabPFN uses a pretrained transformer for tabular classification through
    in-context learning: ``fit`` hands the training data to the model as
    context instead of running gradient-based training. This wrapper caps the
    context at 1024 samples (random subsample) and 100 features (leading
    columns kept).

    Parameters
    ----------
    task_type : str, default='classification'
        Task type (only 'classification' supported by TabPFN).
    n_ensemble : int, default=1
        Number of ensemble members. Passed as ``N_ensemble_configurations`` to
        TabPFN 0.1.x; newer versions are constructed without it.
    device : str, default='auto'
        Device: 'cpu', 'cuda', or 'auto'. Only used with TabPFN 0.1.x. 'auto'
        selects CUDA when available; an explicit CUDA device falls back to CPU
        with a warning when CUDA is unavailable.
    random_state : int, default=42
        Random seed; also seeds the row subsample taken when more than 1024
        training samples are given.

    Notes
    -----
    The underlying ``TabPFNClassifier`` is cached at class level, so weights
    are loaded once per process (and per constructor configuration) instead of
    on every CV fold. Because fitting that shared classifier replaces its
    in-context training data, each wrapper remembers the data it was fitted on
    and re-fits the shared classifier before predicting if another wrapper has
    fitted it in the meantime.
    """

    # Class-level cache shared by ALL instances in the process, plus the
    # constructor configuration it was built with; rebuilt when that changes.
    _shared_classifier = None
    _shared_classifier_key = None

    # Size limits this wrapper enforces before handing data to TabPFN.
    _MAX_SAMPLES = 1024
    _MAX_FEATURES = 100

    def __init__(
        self,
        task_type: str = 'classification',
        n_ensemble: int = 1,
        device: str = 'auto',
        random_state: int = 42
    ):
        super().__init__(task_type=task_type, random_state=random_state)
        if task_type != 'classification':
            raise ValueError("TabPFN only supports classification tasks")
        self.n_ensemble = n_ensemble
        self.device = device

    def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper':
        """
        Fit TabPFN (stores training data for in-context learning).

        More than 1024 rows are randomly subsampled (seeded by
        ``random_state``); more than 100 features are truncated to the first
        100 columns, and the same truncation is applied at prediction time.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Training features.
        y : pd.Series or np.ndarray, shape (n_samples,)
            Training labels.

        Returns
        -------
        self : TabPFNWrapper

        Raises
        ------
        ImportError
            If the ``tabpfn`` package cannot be imported.
        """
        self._validate_input(X, y)
        X, y = self._cap_samples(X, y)

        self.truncated_features_ = X.shape[1] > self._MAX_FEATURES
        if self.truncated_features_:
            logger.warning(
                f"TabPFN strictly requires <= {self._MAX_FEATURES} features. "
                f"Truncating {X.shape[1]} to {self._MAX_FEATURES} features."
            )
            X = self._first_columns(X)

        logger.info(f"Fitting TabPFN on {X.shape[0]} samples...")
        start_time = time.time()

        # Import outside the generic handler below so that only a genuinely
        # missing tabpfn package is reported as "not installed".
        try:
            import tabpfn
            from tabpfn import TabPFNClassifier
        except ImportError as exc:
            logger.error("TabPFN not installed")
            raise ImportError("Install TabPFN with: pip install tabpfn") from exc

        try:
            self.model = self._get_shared_classifier(
                TabPFNClassifier, getattr(tabpfn, '__version__', '0')
            )
            # Keep this wrapper's context so predict() can restore it if
            # another wrapper re-fits the shared classifier afterwards.
            self._context_X, self._context_y = X, y
            self._context_token = object()
            self._fit_context()
        except Exception as e:
            logger.error(f"Error fitting TabPFN: {e}")
            raise

        self.is_fitted = True
        self.fit_time = time.time() - start_time
        logger.info(f"TabPFN fitted in {self.fit_time:.2f} seconds")
        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted class labels

        Raises
        ------
        ValueError
            If the wrapper has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")
        self._validate_input(X)
        X = self._limit_features(X)
        logger.info(f"Predicting on {X.shape[0]} samples with TabPFN...")
        start_time = time.time()
        try:
            self._ensure_context()
            predictions = self.model.predict(X)
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise
        self.predict_time = time.time() - start_time
        logger.info(f"Predictions complete in {self.predict_time:.2f} seconds")
        return predictions

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predict class probabilities with TabPFN.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities
        """
        X = self._limit_features(X)
        self._ensure_context()
        return self.model.predict_proba(X)

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator."""
        params = super().get_params(deep)
        params.update({
            'n_ensemble': self.n_ensemble,
            'device': self.device
        })
        return params

    # ── internal helpers ──────────────────────────────────────────────────────

    @staticmethod
    def _take_rows(data, idx):
        """Select rows ``idx`` positionally from a pandas object or an array."""
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.iloc[idx]
        return data[idx]

    @classmethod
    def _first_columns(cls, X):
        """Keep only the first ``_MAX_FEATURES`` columns (positional)."""
        if isinstance(X, pd.DataFrame):
            return X.iloc[:, :cls._MAX_FEATURES]
        return X[:, :cls._MAX_FEATURES]

    def _cap_samples(self, X, y):
        """Randomly subsample X/y to ``_MAX_SAMPLES`` rows when there are more."""
        if X.shape[0] <= self._MAX_SAMPLES:
            return X, y
        logger.warning(
            f"TabPFN strictly requires <= {self._MAX_SAMPLES} samples to avoid Memory OOM. "
            f"Subsampling {X.shape[0]} to {self._MAX_SAMPLES} samples."
        )
        sample_idx = np.random.RandomState(self.random_state).choice(
            len(X), self._MAX_SAMPLES, replace=False
        )
        return self._take_rows(X, sample_idx), self._take_rows(y, sample_idx)

    def _limit_features(self, X):
        """Apply the fit-time feature truncation to prediction inputs."""
        if getattr(self, 'truncated_features_', False) and X.shape[1] > self._MAX_FEATURES:
            return self._first_columns(X)
        return X

    def _resolve_device(self) -> str:
        """
        Map ``self.device`` to the device string passed to TabPFN 0.1.x.

        'auto' gives 'cuda' when ``torch.cuda.is_available()`` and 'cpu'
        otherwise. An explicit CUDA device is honoured when CUDA is available
        and falls back to 'cpu' with a warning when it is not; any other value
        is passed through unchanged.
        """
        import torch

        cuda_available = torch.cuda.is_available()
        if self.device == 'auto':
            return 'cuda' if cuda_available else 'cpu'
        if str(self.device).startswith('cuda') and not cuda_available:
            logger.warning(
                f"device={self.device!r} requested but CUDA is unavailable; "
                "falling back to 'cpu'."
            )
            return 'cpu'
        return self.device

    def _get_shared_classifier(self, classifier_cls, version: str):
        """
        Return the class-level classifier for this configuration.

        The classifier is built on first use, or rebuilt when this instance's
        configuration differs from the one the cached classifier was built with.
        """
        # TabPFN v0.1.x needs device + N_ensemble_configurations; v2+ takes
        # neither here and reads its token from the TABPFN_TOKEN env var.
        legacy = version.startswith('0.1')
        if legacy:
            device = self._resolve_device()
            key = ('0.1', device, self.n_ensemble)
        else:
            key = ('default',)

        if (TabPFNWrapper._shared_classifier is not None
                and TabPFNWrapper._shared_classifier_key == key):
            logger.info("Reusing cached TabPFN classifier (weights NOT reloaded).")
            return TabPFNWrapper._shared_classifier

        logger.info("Creating new TabPFNClassifier and caching at class level...")
        if legacy:
            classifier = classifier_cls(
                device=device,
                N_ensemble_configurations=self.n_ensemble
            )
        else:
            classifier = classifier_cls()
        TabPFNWrapper._shared_classifier = classifier
        TabPFNWrapper._shared_classifier_key = key
        return classifier

    def _fit_context(self) -> None:
        """
        Fit ``self.model`` on this wrapper's stored context and tag the model
        with this wrapper's token, so later predictions can tell whether
        another wrapper has since replaced the context.
        """
        import inspect

        # v0.1.x's fit() accepts overwrite_warning; v2+ does not. Inspect the
        # signature instead of retrying on TypeError, which would also swallow
        # genuine TypeErrors raised while fitting. Without introspection the
        # kwarg is omitted; data is already capped to TabPFN's limits above.
        try:
            fit_params = inspect.signature(self.model.fit).parameters
        except (TypeError, ValueError):
            fit_params = {}
        if 'overwrite_warning' in fit_params:
            self.model.fit(self._context_X, self._context_y, overwrite_warning=True)
        else:
            self.model.fit(self._context_X, self._context_y)
        self.model._tabpfn_wrapper_owner = self._context_token

    def _ensure_context(self) -> None:
        """Re-fit ``self.model`` with this wrapper's context if another wrapper fitted it since."""
        token = getattr(self, '_context_token', None)
        if token is None:
            # No stored context (e.g. fitted before this mechanism existed).
            return
        if getattr(self.model, '_tabpfn_wrapper_owner', None) is not token:
            logger.info(
                "Shared TabPFN classifier was re-fitted by another wrapper; "
                "restoring this wrapper's context."
            )
            self._fit_context()