Spaces:
Running
Running
| """ | |
| TabPFN Wrapper | |
| ============== | |
| Sklearn-compatible wrapper for TabPFN (Tabular Pre-trained Transformers). | |
| TabPFN is a pretrained model for tabular classification using | |
| in-context learning (no training required). | |
| Author: UW MSIM Team | |
| Date: November 2025 | |
| """ | |
| import time | |
| import logging | |
| import os | |
| from typing import Optional, Union | |
| import numpy as np | |
| import pandas as pd | |
| # Automatically accept the TabPFN license to prevent browser/socket crashes on Windows | |
| os.environ["TABPFN_LICENSE"] = "accept" | |
| os.environ["TABPFN_ACCEPT_LICENSE"] = "1" | |
| # ββ Patch for old TabPFN compatibility with newer torch ββββββββββββββββββββββ | |
| try: | |
| import torch.nn.modules.transformer | |
| if not hasattr(torch.nn.modules.transformer, 'Optional'): | |
| import typing | |
| torch.nn.modules.transformer.Optional = typing.Optional | |
| torch.nn.modules.transformer.Any = typing.Any | |
| torch.nn.modules.transformer.Tuple = typing.Tuple | |
| torch.nn.modules.transformer.List = typing.List | |
| except (ImportError, AttributeError): | |
| pass | |
| # ββ Patch for old TabPFN compatibility with newer sklearn ββββββββββββββββββββ | |
| try: | |
| import sklearn.utils.validation | |
| def _patch_validation(func): | |
| from functools import wraps | |
| def wrapper(*args, **kwargs): | |
| if 'force_all_finite' in kwargs: | |
| kwargs['ensure_all_finite'] = kwargs.pop('force_all_finite') | |
| return func(*args, **kwargs) | |
| return wrapper | |
| sklearn.utils.validation.check_X_y = _patch_validation(sklearn.utils.validation.check_X_y) | |
| sklearn.utils.validation.check_array = _patch_validation(sklearn.utils.validation.check_array) | |
| except (ImportError, AttributeError): | |
| pass | |
| from .base_wrapper import BaseModelWrapper | |
| logger = logging.getLogger(__name__) | |
| class TabPFNWrapper(BaseModelWrapper): | |
| """ | |
| TabPFN (Tabular Prior-Fitted Networks) wrapper. | |
| TabPFN uses pretrained transformers for zero-shot tabular prediction. | |
| Works best on datasets with <1000 samples and <100 features. | |
| Parameters | |
| ---------- | |
| task_type : str, default='classification' | |
| Task type (only 'classification' supported by TabPFN) | |
| n_ensemble : int, default=1 | |
| Number of ensemble members | |
| device : str, default='auto' | |
| Device: 'cpu', 'cuda', or 'auto' | |
| random_state : int, default=42 | |
| Random seed | |
| """ | |
| def __init__( | |
| self, | |
| task_type: str = 'classification', | |
| n_ensemble: int = 1, | |
| device: str = 'auto', | |
| random_state: int = 42 | |
| ): | |
| super().__init__(task_type=task_type, random_state=random_state) | |
| if task_type != 'classification': | |
| raise ValueError("TabPFN only supports classification tasks") | |
| self.n_ensemble = n_ensemble | |
| self.device = device | |
| def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'TabPFNWrapper': | |
| """ | |
| Fit TabPFN (stores training data for in-context learning). | |
| Parameters | |
| ---------- | |
| X : pd.DataFrame or np.ndarray, shape (n_samples, n_features) | |
| Training features (max 1000 samples, 100 features) | |
| y : pd.Series or np.ndarray, shape (n_samples,) | |
| Training target | |
| Returns | |
| ------- | |
| self : TabPFNWrapper | |
| Fitted model | |
| """ | |
| self._validate_input(X, y) | |
| # Check TabPFN constraints | |
| if X.shape[0] > 1024: | |
| logger.warning(f"TabPFN strictly requires <= 1024 samples to avoid Memory OOM. Subsampling {X.shape[0]} to 1024 samples.") | |
| sample_idx = np.random.RandomState(self.random_state).choice( | |
| len(X), 1024, replace=False | |
| ) | |
| if isinstance(X, pd.DataFrame): | |
| X = X.iloc[sample_idx] | |
| else: | |
| X = X[sample_idx] | |
| if isinstance(y, pd.Series): | |
| y = y.iloc[sample_idx] | |
| else: | |
| y = y[sample_idx] | |
| if X.shape[1] > 100: | |
| logger.warning(f"TabPFN strictly requires <= 100 features. Truncating {X.shape[1]} to 100 features.") | |
| if isinstance(X, pd.DataFrame): | |
| X = X.iloc[:, :100] | |
| else: | |
| X = X[:, :100] | |
| self.truncated_features_ = True | |
| else: | |
| self.truncated_features_ = False | |
| logger.info(f"Fitting TabPFN on {X.shape[0]} samples...") | |
| start_time = time.time() | |
| try: | |
| from tabpfn import TabPFNClassifier | |
| import torch | |
| import tabpfn | |
| actual_device = 'cuda' if (self.device == 'auto' and torch.cuda.is_available()) else ('cpu' if self.device == 'auto' else self.device) | |
| if hasattr(tabpfn, '__version__') and tabpfn.__version__.startswith('0.1'): | |
| self.model = TabPFNClassifier(device=actual_device, N_ensemble_configurations=self.n_ensemble) | |
| else: | |
| self.model = TabPFNClassifier(device=actual_device) | |
| # Fit model | |
| self.model.fit(X, y) | |
| self.is_fitted = True | |
| self.fit_time = time.time() - start_time | |
| logger.info(f"TabPFN fitted in {self.fit_time:.2f} seconds") | |
| except ImportError: | |
| logger.error("TabPFN not installed") | |
| raise ImportError("Install TabPFN with: pip install tabpfn") | |
| except Exception as e: | |
| logger.error(f"Error fitting TabPFN: {e}") | |
| raise | |
| return self | |
| def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray: | |
| """ | |
| Make predictions with TabPFN. | |
| Parameters | |
| ---------- | |
| X : pd.DataFrame or np.ndarray, shape (n_samples, n_features) | |
| Test features | |
| Returns | |
| ------- | |
| predictions : np.ndarray, shape (n_samples,) | |
| Predicted class labels | |
| """ | |
| if not self.is_fitted: | |
| raise ValueError("Model not fitted. Call fit() first.") | |
| self._validate_input(X) | |
| if getattr(self, 'truncated_features_', False) and X.shape[1] > 100: | |
| if isinstance(X, pd.DataFrame): | |
| X = X.iloc[:, :100] | |
| else: | |
| X = X[:, :100] | |
| logger.info(f"Predicting on {X.shape[0]} samples with TabPFN...") | |
| start_time = time.time() | |
| try: | |
| predictions = self.model.predict(X) | |
| self.predict_time = time.time() - start_time | |
| logger.info(f"Predictions complete in {self.predict_time:.2f} seconds") | |
| return predictions | |
| except Exception as e: | |
| logger.error(f"Error during prediction: {e}") | |
| raise | |
| def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray: | |
| """ | |
| Predict class probabilities with TabPFN. | |
| Parameters | |
| ---------- | |
| X : pd.DataFrame or np.ndarray, shape (n_samples, n_features) | |
| Test features | |
| Returns | |
| ------- | |
| probabilities : np.ndarray, shape (n_samples, n_classes) | |
| Class probabilities | |
| """ | |
| if getattr(self, 'truncated_features_', False) and X.shape[1] > 100: | |
| if isinstance(X, pd.DataFrame): | |
| X = X.iloc[:, :100] | |
| else: | |
| X = X[:, :100] | |
| return self.model.predict_proba(X) | |
| def get_params(self, deep: bool = True) -> dict: | |
| """Get parameters for this estimator.""" | |
| params = super().get_params(deep) | |
| params.update({ | |
| 'n_ensemble': self.n_ensemble, | |
| 'device': self.device | |
| }) | |
| return params | |