| from __future__ import annotations |
|
|
| import importlib |
| import math |
| import traceback |
| import time |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Callable |
|
|
| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import plotly.express as px |
| import plotly.graph_objects as go |
| from plotly.subplots import make_subplots |
| from sklearn.base import clone |
| from sklearn.compose import ColumnTransformer |
| from sklearn.datasets import fetch_california_housing, fetch_openml, load_breast_cancer, load_diabetes, load_digits, load_iris, load_wine, make_classification |
| from sklearn.dummy import DummyClassifier, DummyRegressor |
| from sklearn.ensemble import ( |
| AdaBoostClassifier, |
| AdaBoostRegressor, |
| ExtraTreesClassifier, |
| ExtraTreesRegressor, |
| GradientBoostingClassifier, |
| GradientBoostingRegressor, |
| HistGradientBoostingClassifier, |
| HistGradientBoostingRegressor, |
| RandomForestClassifier, |
| RandomForestRegressor, |
| ) |
| from sklearn.impute import SimpleImputer |
| from sklearn.linear_model import BayesianRidge, LinearRegression, LogisticRegression, Ridge |
| from sklearn.metrics import ( |
| accuracy_score, |
| f1_score, |
| mean_absolute_error, |
| mean_squared_error, |
| r2_score, |
| roc_auc_score, |
| ) |
| from sklearn.model_selection import train_test_split |
| from sklearn.naive_bayes import GaussianNB |
| from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor |
| from sklearn.pipeline import Pipeline |
| from sklearn.preprocessing import OneHotEncoder, StandardScaler |
| from sklearn.svm import SVC, SVR |
|
|
|
|
# Title shown in the application UI.
APP_TITLE = "tabBench"
# Identifier of the TabFM checkpoint (presumably a Hugging Face Hub repo id — confirm).
TABFM_MODEL_ID = "google/tabfm-1.0.0-pytorch"
# Seed handed to the baseline estimators so repeated runs agree.
RANDOM_STATE = 42
# Raw CSV of FiveThirtyEight's Halloween candy power ranking.
CANDY_DATA_URL = (
    "https://raw.githubusercontent.com/fivethirtyeight/data/master/"
    "candy-power-ranking/candy-data.csv"
)
# Hex colour palette.
GOOGLE_COLORS = [
    "#4285F4",
    "#DB4437",
    "#F4B400",
    "#0F9D58",
    "#A142F4",
    "#00ACC1",
]
# Metric names selectable by the user (classification, regression, and timing).
METRIC_CHOICES = [
    "accuracy",
    "f1_weighted",
    "roc_auc",
    "rmse",
    "mae",
    "r2",
    "seconds",
]
# Named TabFM configurations, from cheapest to most thorough.
# NOTE(review): a max_num_rows of None presumably means "no row cap" — confirm
# against the TabFM documentation.
TABFM_PRESETS = {
    "Fast": {
        "n_estimators": 1,
        "max_num_rows": 256,
        "max_num_features": 64,
        "batch_size": 1,
        "enable_nnls": False,
        "n_feature_crosses": 0,
        "n_svd_features": 0,
        "max_eval_rows": 256,
    },
    "Balanced": {
        "n_estimators": 4,
        "max_num_rows": 512,
        "max_num_features": 128,
        "batch_size": 1,
        "enable_nnls": False,
        "n_feature_crosses": 0,
        "n_svd_features": 0,
        "max_eval_rows": 512,
    },
    "Default": {
        "n_estimators": 32,
        "max_num_rows": None,
        "max_num_features": 500,
        "batch_size": 1,
        "enable_nnls": False,
        "n_feature_crosses": 0,
        "n_svd_features": 0,
        "max_eval_rows": 1000,
    },
    "Ensemble": {
        "n_estimators": 32,
        "max_num_rows": None,
        "max_num_features": 500,
        "batch_size": 1,
        "enable_nnls": True,
        "n_feature_crosses": "sqrt",
        "n_svd_features": "sqrt",
        "max_eval_rows": 1000,
    },
}
|
|
|
|
@dataclass(frozen=True)
class DatasetSpec:
    """Immutable catalogue entry describing one benchmark dataset.

    ``loader`` is invoked as ``loader(sample_size, seed)`` and is expected to
    return a DataFrame that contains the ``target`` column.
    """

    # Display name; also the key used to look the entry up.
    name: str
    # Learning task; the catalogue uses "classification" or "regression".
    task: str
    # Label column inside the DataFrame returned by ``loader``.
    target: str
    # Human-readable provenance (OpenML id, KaggleHub slug, sklearn, ...).
    source: str
    # Advertised size of the upstream dataset (presumably informational only).
    rows: int
    # One-line summary of the task.
    description: str
    # Callable taking (row limit, random seed) and returning the data.
    loader: Callable[[int, int], pd.DataFrame]
|
|
|
|
| def _add_categorical_noise(df: pd.DataFrame, rng: np.random.Generator, prefix: str) -> pd.DataFrame: |
| df = df.copy() |
| df[f"{prefix}_segment"] = rng.choice(["A", "B", "C", "D"], len(df), p=[0.35, 0.25, 0.25, 0.15]) |
| df[f"{prefix}_region"] = rng.choice(["north", "south", "east", "west"], len(df)) |
| return df |
|
|
|
|
def sample_df(df: pd.DataFrame, limit: int, seed: int) -> pd.DataFrame:
    """Reproducibly draw up to *limit* rows of *df* without replacement.

    The returned frame carries a fresh 0..n-1 index.
    """
    n_rows = min(len(df), limit)
    picked = df.sample(n_rows, random_state=seed)
    return picked.reset_index(drop=True)
|
|
|
|
| def find_first_data_file(root: str | Path, suffixes: tuple[str, ...]) -> Path: |
| root = Path(root) |
| for suffix in suffixes: |
| matches = sorted(root.rglob(f"*{suffix}")) |
| if matches: |
| return matches[0] |
| raise FileNotFoundError(f"No data file with suffixes {suffixes} found in {root}") |
|
|
|
|
def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with lower-case, underscore-separated column names.

    Each label is stringified, trimmed and lower-cased. Every run of characters
    other than ``0-9a-z`` becomes a single ``_``, and leading or trailing
    underscores are removed. For example, ``"Temperature (C)"`` becomes
    ``"temperature_c"``.
    """
    labels = pd.Index(df.columns).astype(str).str.strip().str.lower()
    labels = labels.str.replace(r"[^0-9a-z]+", "_", regex=True).str.strip("_")
    renamed = df.copy()
    renamed.columns = labels
    return renamed
|
|
|
|
| def choose_col(df: pd.DataFrame, candidates: list[str], contains: list[str] | None = None) -> str: |
| normalized = {c.lower(): c for c in df.columns} |
| for candidate in candidates: |
| key = candidate.lower() |
| if key in normalized: |
| return normalized[key] |
| if contains: |
| for col in df.columns: |
| lower = str(col).lower() |
| if all(part in lower for part in contains): |
| return col |
| raise KeyError(f"None of {candidates} found in columns.") |
|
|
|
|
def kaggle_csv(dataset_id: str, preferred_names: tuple[str, ...] = ()) -> pd.DataFrame:
    """Download a KaggleHub dataset and read one of its CSV files.

    CSV files are considered in sorted path order. For each hint in
    *preferred_names*, taken in order, the first CSV whose file name contains
    the hint (case-insensitively) is returned. If no hint matches, the first
    CSV overall is used.

    Raises:
        FileNotFoundError: if the download contains no CSV file.
    """
    import kagglehub

    download_dir = Path(kagglehub.dataset_download(dataset_id))
    csv_files = sorted(download_dir.rglob("*.csv"))
    for hint in preferred_names or ():
        needle = hint.lower()
        hit = next((f for f in csv_files if needle in f.name.lower()), None)
        if hit is not None:
            return pd.read_csv(hit)
    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in Kaggle dataset {dataset_id}.")
    return pd.read_csv(csv_files[0])
|
|
|
|
def select_numeric_features(df: pd.DataFrame, target: str, max_features: int = 12) -> pd.DataFrame:
    """Keep the first *max_features* numeric feature columns plus *target*.

    Rows whose target is missing are dropped. The target column is always
    placed last.
    """
    feature_cols = [
        col for col in df.select_dtypes(include=np.number).columns if col != target
    ][:max_features]
    return df[feature_cols + [target]].dropna(subset=[target])
|
|
|
|
@lru_cache(maxsize=1)
def openml_titanic() -> pd.DataFrame:
    """Fetch the OpenML Titanic table (data_id=40945), memoized.

    Only the core passenger columns that exist are kept. Rows with a missing
    ``survived`` label are dropped.
    """
    wanted = ("pclass", "sex", "age", "sibsp", "parch", "fare", "embarked", "survived")
    frame = fetch_openml(data_id=40945, as_frame=True, parser="auto").frame.copy()
    present = [name for name in wanted if name in frame.columns]
    return frame[present].dropna(subset=["survived"])
|
|
|
|
@lru_cache(maxsize=1)
def openml_ames_housing() -> pd.DataFrame:
    """Fetch the OpenML Ames house-prices table (data_id=42165), memoized.

    Keeps a fixed set of quality, size and location features when they are
    present. The price column (``SalePrice``, or the first declared target) is
    renamed to ``sale_price`` and rows without a price are dropped.
    """
    bunch = fetch_openml(data_id=42165, as_frame=True, parser="auto")
    frame = bunch.frame.copy()
    price_col = "SalePrice" if "SalePrice" in frame.columns else bunch.target_names[0]
    wanted = (
        "OverallQual",
        "GrLivArea",
        "GarageCars",
        "GarageArea",
        "TotalBsmtSF",
        "FullBath",
        "YearBuilt",
        "Neighborhood",
        "HouseStyle",
        price_col,
    )
    present = [name for name in wanted if name in frame.columns]
    renamed = frame[present].rename(columns={price_col: "sale_price"})
    return renamed.dropna(subset=["sale_price"])
|
|
|
|
@lru_cache(maxsize=1)
def openml_adult_income() -> pd.DataFrame:
    """Fetch the OpenML Adult census table (data_id=1590), memoized.

    The label column (``class``, else ``income``) is renamed to
    ``income_gt_50k``. Rows without a label are dropped.
    """
    frame = fetch_openml(data_id=1590, as_frame=True, parser="auto").frame.copy()
    label = next((name for name in ("class", "income") if name in frame.columns), None)
    if label is not None:
        frame = frame.rename(columns={label: "income_gt_50k"})
    return frame.dropna(subset=["income_gt_50k"])
|
|
|
|
@lru_cache(maxsize=1)
def sklearn_california_housing() -> pd.DataFrame:
    """Fetch sklearn's California housing frame, memoized.

    The target ``MedHouseVal`` is renamed to ``median_house_value``.
    """
    bunch = fetch_california_housing(as_frame=True)
    return bunch.frame.rename(columns={"MedHouseVal": "median_house_value"})
|
|
|
|
@lru_cache(maxsize=1)
def fivethirtyeight_candy() -> pd.DataFrame:
    """Download FiveThirtyEight's candy ranking CSV, memoized.

    The free-text ``competitorname`` column is dropped when present, as are
    rows without a ``winpercent`` value.
    """
    raw = pd.read_csv(CANDY_DATA_URL)
    name_cols = ["competitorname"] if "competitorname" in raw.columns else []
    return raw.drop(columns=name_cols).dropna(subset=["winpercent"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_credit_card_fraud() -> pd.DataFrame:
    """Download the mlg-ulb credit-card fraud CSV via KaggleHub, memoized.

    The ``Class`` label is renamed to ``fraud``. Rows without a label are dropped.
    """
    import kagglehub

    download_dir = kagglehub.dataset_download("mlg-ulb/creditcardfraud")
    transactions = pd.read_csv(find_first_data_file(download_dir, (".csv",)))
    # rename() ignores keys that are absent, so no membership check is needed.
    transactions = transactions.rename(columns={"Class": "fraud"})
    return transactions.dropna(subset=["fraud"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_epirecipes() -> pd.DataFrame:
    """Download the Epicurious recipes dataset via KaggleHub, memoized.

    Keeps the nutrition columns, a handful of one-hot recipe tags
    (``dessert``, ``vegan``, ``cakeweek``, ...) and the ``rating`` target.
    Rows whose rating is missing are dropped.

    The tag columns live in the tabular CSV, so a ``.csv`` file is preferred.
    A ``.json`` file is only used when no CSV exists. Column names are passed
    through ``normalize_columns`` so punctuated tag spellings still match the
    names below.
    NOTE(review): the Kaggle tag matrix is believed to spell some tags with a
    leading ``#`` (e.g. ``#cakeweek``); normalization maps that to
    ``cakeweek`` — confirm against the downloaded file.

    Raises:
        FileNotFoundError: if the download contains neither a CSV nor a JSON file.
        ValueError: if the chosen file has no ``rating`` column.
    """
    import kagglehub

    path = kagglehub.dataset_download("hugodarwood/epirecipes")
    try:
        data_file = find_first_data_file(path, (".csv",))
    except FileNotFoundError:
        data_file = find_first_data_file(path, (".json",))
    if data_file.suffix.lower() == ".json":
        df = pd.read_json(data_file)
    else:
        df = pd.read_csv(data_file)
    df = normalize_columns(df)
    # Normalization can map two raw labels onto one name. Keep the first so
    # the selection below never returns duplicate columns.
    df = df.loc[:, ~df.columns.duplicated()]
    if "rating" not in df.columns:
        raise ValueError("Epicurious dataset does not include a rating column.")
    preferred = [
        "calories",
        "protein",
        "fat",
        "sodium",
        "dessert",
        "dinner",
        "breakfast",
        "healthy",
        "vegetarian",
        "vegan",
        "cakeweek",
        "rating",
    ]
    cols = [c for c in preferred if c in df.columns]
    return df[cols].dropna(subset=["rating"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_calcofi() -> pd.DataFrame:
    """Load the CalCOFI bottle table via KaggleHub, memoized.

    The temperature and salinity columns are located by name and renamed to
    ``water_temperature`` and ``salinity``. The chemistry readings that exist
    are kept, and rows without a temperature are dropped.
    """
    bottle = normalize_columns(kaggle_csv("sohier/calcofi", ("bottle",)))
    temp_col = choose_col(bottle, ["t_deg_c", "temperature"], ["t", "deg"])
    sal_col = choose_col(bottle, ["salnty", "salinity"], ["sal"])
    chemistry = ("depthm", "o2ml_l", "sio3um", "no3um", "po4um")
    keep = [c for c in (sal_col, *chemistry, temp_col) if c in bottle.columns]
    renamed = bottle[keep].rename(columns={temp_col: "water_temperature", sal_col: "salinity"})
    return renamed.dropna(subset=["water_temperature"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_szeged_weather() -> pd.DataFrame:
    """Load the Szeged hourly weather table via KaggleHub, memoized.

    The apparent-temperature column becomes the ``apparent_temperature``
    target. The listed weather features are kept when present, and rows
    without a target are dropped.
    """
    weather = normalize_columns(kaggle_csv("budincsevity/szeged-weather"))
    label = choose_col(weather, ["apparent_temperature_c", "apparent_temperature"], ["apparent", "temperature"])
    features = (
        "temperature_c",
        "humidity",
        "wind_speed_km_h",
        "wind_bearing_degrees",
        "visibility_km",
        "pressure_millibars",
        "summary",
        "precip_type",
    )
    keep = [c for c in (*features, label) if c in weather.columns]
    renamed = weather[keep].rename(columns={label: "apparent_temperature"})
    return renamed.dropna(subset=["apparent_temperature"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_weather_ww2() -> pd.DataFrame:
    """Load the WW2 weather summary table via KaggleHub, memoized.

    The daily maximum temperature becomes the ``max_temperature`` target. The
    listed weather and date fields are kept when present.
    """
    summary = normalize_columns(kaggle_csv("smid80/weatherww2", ("summary",)))
    label = choose_col(summary, ["maxtemp", "max_temp", "max"], ["max"])
    features = ("mintemp", "meantemp", "precip", "snowfall", "yr", "mo", "da")
    keep = [c for c in (*features, label) if c in summary.columns]
    renamed = summary[keep].rename(columns={label: "max_temperature"})
    return renamed.dropna(subset=["max_temperature"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_montreal_bike_lanes() -> pd.DataFrame:
    """Load the Montreal bike-lane counts via KaggleHub, memoized.

    The last numeric column becomes the ``rider_count`` target. The first
    (up to) eight numeric columns are the features.

    Raises:
        ValueError: if fewer than two numeric columns exist.
    """
    counts = normalize_columns(kaggle_csv("pablomonleon/montreal-bike-lanes"))
    numeric_cols = counts.select_dtypes(include=np.number).columns
    if len(numeric_cols) < 2:
        raise ValueError("Montreal bike lanes dataset needs at least two numeric count columns.")
    label = numeric_cols[-1]
    keep = list(numeric_cols[:8])
    if label not in keep:
        keep.append(label)
    renamed = counts[keep].rename(columns={label: "rider_count"})
    return renamed.dropna(subset=["rider_count"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_nyc_bike_crossings() -> pd.DataFrame:
    """Load NYC East River bicycle crossings via KaggleHub, memoized.

    The numeric total column becomes ``total_bike_crossings``. The features are
    the other columns among the first eight numeric ones.
    """
    crossings = normalize_columns(kaggle_csv("new-york-city/nyc-east-river-bicycle-crossings"))
    numeric = crossings.select_dtypes(include=np.number)
    label = choose_col(numeric, ["total", "total_bicycle_count"], ["total"])
    features = [c for c in numeric.columns[:8] if c != label]
    renamed = crossings[features + [label]].rename(columns={label: "total_bike_crossings"})
    return renamed.dropna(subset=["total_bike_crossings"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_uk_road_safety() -> pd.DataFrame:
    """Load UK 2016 road-accident records via KaggleHub, memoized.

    The casualty count becomes the ``casualty_count`` target. The listed
    accident-context fields are kept when present.
    """
    accidents = normalize_columns(kaggle_csv("bluehorseshoe/uk-2016-road-safety-data", ("accident",)))
    label = choose_col(accidents, ["number_of_casualties", "casualties"], ["casual"])
    features = (
        "number_of_vehicles",
        "day_of_week",
        "speed_limit",
        "light_conditions",
        "weather_conditions",
        "road_surface_conditions",
        "urban_or_rural_area",
    )
    keep = [c for c in (*features, label) if c in accidents.columns]
    renamed = accidents[keep].rename(columns={label: "casualty_count"})
    return renamed.dropna(subset=["casualty_count"])
|
|
|
|
@lru_cache(maxsize=1)
def kaggle_kcbs_bbq() -> pd.DataFrame:
    """Build a binary "won first place" table from the KCBS BBQ data, memoized.

    The ranking column is ``place`` or ``rank`` when present, otherwise the
    first numeric column. ``first_place`` is 1 where that column equals 1 and
    0 otherwise. The features are up to 10 other numeric columns plus up to 4
    non-numeric columns. Neither the ranking column nor ``first_place`` is
    ever used as a feature.

    Raises:
        ValueError: if no ranking column can be determined.
    """
    df = normalize_columns(kaggle_csv("jaysobel/kcbs-bbq"))
    numeric_cols = df.select_dtypes(include=np.number).columns
    if "place" in df.columns:
        target = "place"
    elif "rank" in df.columns:
        target = "rank"
    elif len(numeric_cols):
        target = numeric_cols[0]
    else:
        raise ValueError("KCBS BBQ dataset has no place/rank column and no numeric columns.")
    excluded = {target, "first_place"}
    feature_cols = [c for c in numeric_cols if c not in excluded][:10]
    # Pick categorical features from the source frame, not from the frame
    # with the derived label appended. Otherwise "first_place" could be
    # selected as a feature (target leakage) and appear twice in the output.
    categorical_cols = [c for c in df.columns if c not in numeric_cols and c not in excluded][:4]
    out = df.copy()
    out["first_place"] = pd.to_numeric(out[target], errors="coerce").eq(1).astype(int)
    return out[[*feature_cols, *categorical_cols, "first_place"]].dropna(subset=["first_place"])
|
|
|
|
def load_iris_df(limit: int, seed: int) -> pd.DataFrame:
    """Return up to *limit* rows of sklearn's Iris data, sampled with *seed*.

    The class label is stored as the species name in ``species``.
    """
    bunch = load_iris(as_frame=True)
    species_names = dict(enumerate(bunch.target_names))
    frame = bunch.frame.rename(columns={"target": "species"})
    frame["species"] = frame["species"].map(species_names)
    n_rows = min(limit, len(frame))
    return frame.sample(n_rows, random_state=seed)
|
|
|
|
def load_wine_df(limit: int, seed: int) -> pd.DataFrame:
    """Return up to *limit* rows of sklearn's Wine data, sampled with *seed*.

    The integer cultivar label is stored in ``wine_class``.
    """
    wine = load_wine(as_frame=True).frame
    wine = wine.rename(columns={"target": "wine_class"})
    n_rows = min(limit, len(wine))
    return wine.sample(n_rows, random_state=seed)
|
|
|
|
def load_breast_cancer_df(limit: int, seed: int) -> pd.DataFrame:
    """Return up to *limit* rows of sklearn's breast-cancer data, sampled with *seed*.

    The label in ``diagnosis`` is mapped 0 -> "malignant" and 1 -> "benign".
    """
    label_names = {0: "malignant", 1: "benign"}
    frame = load_breast_cancer(as_frame=True).frame.rename(columns={"target": "diagnosis"})
    frame["diagnosis"] = frame["diagnosis"].map(label_names)
    n_rows = min(limit, len(frame))
    return frame.sample(n_rows, random_state=seed)
|
|
|
|
def load_digits_df(limit: int, seed: int) -> pd.DataFrame:
    """Return up to *limit* rows of sklearn's 8x8 digits data, sampled with *seed*.

    The label is stored in ``digit``.
    """
    digits = load_digits(as_frame=True).frame
    digits = digits.rename(columns={"target": "digit"})
    n_rows = min(limit, len(digits))
    return digits.sample(n_rows, random_state=seed)
|
|
|
|
def load_diabetes_df(limit: int, seed: int) -> pd.DataFrame:
    """Return up to *limit* rows of sklearn's diabetes data, sampled with *seed*.

    The regression target is stored in ``disease_progression``.
    """
    diabetes = load_diabetes(as_frame=True).frame
    diabetes = diabetes.rename(columns={"target": "disease_progression"})
    n_rows = min(limit, len(diabetes))
    return diabetes.sample(n_rows, random_state=seed)
|
|
|
|
def load_california_housing_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample California housing rows, or synthetic housing data if loading fails.

    Any error from the real loader is swallowed on purpose, so the app still
    works offline. The synthetic target is renamed to ``median_house_value``.
    """
    try:
        housing = sklearn_california_housing()
        return sample_df(housing, limit, seed)
    except Exception:
        synthetic = load_synthetic_housing_df(limit, seed)
        return synthetic.rename(columns={"sale_price": "median_house_value"})
|
|
|
|
def load_ames_housing_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample Ames housing rows, or synthetic housing data if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        ames = openml_ames_housing()
        return sample_df(ames, limit, seed)
    except Exception:
        return load_synthetic_housing_df(limit, seed)
|
|
|
|
def load_synthetic_housing_df(limit: int, seed: int) -> pd.DataFrame:
    """Generate a synthetic house-price table with at most 12 000 rows.

    Price rises with square footage and bedroom count, falls with home age,
    and includes a per-zipcode premium plus Gaussian noise. The output is
    deterministic for a given *seed*.
    """
    zip_premium = {
        "94016": 260000,
        "98101": 140000,
        "10011": 210000,
        "60614": 80000,
        "78704": 110000,
        "30309": 70000,
    }
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 12000)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    bedrooms = gen.integers(1, 7, n_rows)
    sqft = gen.normal(1750, 650, n_rows).clip(450, 5200)
    age = gen.integers(0, 90, n_rows)
    zipcode = gen.choice(list(zip_premium), n_rows)
    price_per_sqft = gen.normal(230, 20, n_rows)
    price = 120000 + sqft * price_per_sqft + bedrooms * 18000 - age * 1400
    price = price + pd.Series(zipcode).map(zip_premium).to_numpy()
    price = price + gen.normal(0, 45000, n_rows)
    has_garage = gen.choice(["yes", "no"], n_rows, p=[0.72, 0.28])
    columns = {
        "sqft": sqft.round(0),
        "bedrooms": bedrooms,
        "home_age": age,
        "zipcode": zipcode,
        "has_garage": has_garage,
        "sale_price": price.round(0),
    }
    return pd.DataFrame(columns)
|
|
|
|
def load_titanic_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate a Titanic-like passenger table with at most 891 rows.

    Survival comes from a logistic model that favours female, first- and
    second-class, younger and higher-fare passengers. ``sibsp`` and ``parch``
    are independent noise. The output is deterministic for a given *seed*.
    """
    n_rows = min(limit, 891)
    gen = np.random.default_rng(seed)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    sex = gen.choice(["female", "male"], n_rows, p=[0.38, 0.62])
    pclass = gen.choice([1, 2, 3], n_rows, p=[0.24, 0.21, 0.55])
    age = gen.normal(30, 14, n_rows).clip(0.5, 78)
    fare = np.exp(gen.normal(3.1, 0.85, n_rows)) * (4 - pclass)
    embarked = gen.choice(["S", "C", "Q"], n_rows, p=[0.72, 0.19, 0.09])
    logit = 1.6 * (sex == "female") + 0.9 * (pclass == 1) + 0.25 * (pclass == 2) - 0.025 * age + 0.01 * fare - 1.1
    survived = gen.binomial(1, 1 / (1 + np.exp(-logit)))
    sibsp = gen.integers(0, 5, n_rows)
    parch = gen.integers(0, 4, n_rows)
    columns = {
        "pclass": pclass,
        "sex": sex,
        "age": age.round(1),
        "sibsp": sibsp,
        "parch": parch,
        "fare": fare.round(2),
        "embarked": embarked,
        "survived": survived,
    }
    return pd.DataFrame(columns)
|
|
|
|
def load_titanic_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample the OpenML Titanic table, or a simulated proxy if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        passengers = openml_titanic()
        return sample_df(passengers, limit, seed)
    except Exception:
        return load_titanic_proxy_df(limit, seed)
|
|
|
|
def load_credit_fraud_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate an imbalanced fraud table with at most 50 000 rows.

    Eighteen anonymised features ``v1``..``v18`` come from sklearn's
    ``make_classification`` with roughly 1.5% positives. A log-normal
    ``amount``, a random ``merchant_category`` and the ``fraud`` label are
    appended.
    """
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 50000)
    features, labels = make_classification(
        n_samples=n_rows,
        n_features=18,
        n_informative=8,
        n_redundant=4,
        weights=[0.985, 0.015],
        class_sep=1.6,
        random_state=seed,
    )
    feature_names = [f"v{i}" for i in range(1, 19)]
    frame = pd.DataFrame(features, columns=feature_names)
    frame["amount"] = np.exp(gen.normal(3.2, 1.0, n_rows)).round(2)
    frame["merchant_category"] = gen.choice(["travel", "grocery", "electronics", "fuel", "cash"], n_rows)
    frame["fraud"] = labels
    return frame
|
|
|
|
def load_credit_fraud_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample the Kaggle fraud table, or a simulated proxy if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        transactions = kaggle_credit_card_fraud()
        return sample_df(transactions, limit, seed)
    except Exception:
        return load_credit_fraud_proxy_df(limit, seed)
|
|
|
|
def load_epirecipes_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate a recipe-rating table with at most 20 000 rows.

    Nutrition values are gamma-distributed. The rating gets small boosts for
    desserts and Italian cuisine and a penalty for sodium, then Gaussian noise
    is added and the result is clipped to [0, 5].
    """
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 20000)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    calories = gen.gamma(4, 120, n_rows)
    protein = gen.gamma(2, 12, n_rows)
    fat = gen.gamma(2.5, 9, n_rows)
    sodium = gen.gamma(2.4, 180, n_rows)
    course = gen.choice(["main", "dessert", "side", "salad", "breakfast"], n_rows)
    cuisine = gen.choice(["american", "italian", "mexican", "asian", "mediterranean"], n_rows)
    noise = gen.normal(0, 0.65, n_rows)
    rating = 2.8 + 0.12 * (course == "dessert") + 0.18 * (cuisine == "italian") - 0.00035 * sodium + noise
    make_again = gen.choice(["yes", "no"], n_rows, p=[0.66, 0.34])
    columns = {
        "calories": calories.round(0),
        "protein": protein.round(1),
        "fat": fat.round(1),
        "sodium": sodium.round(0),
        "course": course,
        "cuisine": cuisine,
        "make_again": make_again,
        "rating": rating.clip(0, 5).round(2),
    }
    return pd.DataFrame(columns)
|
|
|
|
def load_epirecipes_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample the Epicurious table, or a simulated proxy if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        recipes = kaggle_epirecipes()
        return sample_df(recipes, limit, seed)
    except Exception:
        return load_epirecipes_proxy_df(limit, seed)
|
|
|
|
def load_epirecipes_cakeweek_df(limit: int, seed: int) -> pd.DataFrame:
    """Return recipes with an integer 0/1 ``cakeweek`` label.

    Real data is used when it has a ``cakeweek`` column; that column is
    coerced to int, with unparseable values set to 0. Otherwise, or on any
    error, the proxy table is used. There, ``cakeweek`` marks desserts rated
    at or above the median rating.
    """
    try:
        recipes = kaggle_epirecipes().copy()
        if "cakeweek" not in recipes.columns:
            raise KeyError("cakeweek")
        as_number = pd.to_numeric(recipes["cakeweek"], errors="coerce")
        recipes["cakeweek"] = as_number.fillna(0).astype(int)
        return sample_df(recipes, limit, seed)
    except Exception:
        proxy = load_epirecipes_proxy_df(limit, seed).copy()
        is_dessert = proxy["course"] == "dessert"
        well_rated = proxy["rating"] >= proxy["rating"].median()
        proxy["cakeweek"] = (is_dessert & well_rated).astype(int)
        return proxy
|
|
|
|
def load_candy_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate a candy-popularity table with at most 1 200 rows.

    ``winpercent`` increases with chocolate, caramel and sugar content and
    decreases with price. Gaussian noise is added and the value is clipped to
    [5, 95]. The remaining attribute flags are independent Bernoulli draws.
    """
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 1200)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    chocolate = gen.binomial(1, 0.45, n_rows)
    fruity = gen.binomial(1, 0.38, n_rows)
    caramel = gen.binomial(1, 0.24, n_rows)
    pricepercent = gen.beta(2, 4, n_rows)
    sugarpercent = gen.beta(3, 2, n_rows)
    noise = gen.normal(0, 8, n_rows)
    winpercent = 35 + 18 * chocolate + 8 * caramel + 9 * sugarpercent - 10 * pricepercent + noise
    flag_probs = (
        ("peanutyalmondy", 0.2),
        ("nougat", 0.14),
        ("crispedricewafer", 0.16),
        ("hard", 0.28),
        ("bar", 0.36),
    )
    extra_flags = {name: gen.binomial(1, prob, n_rows) for name, prob in flag_probs}
    return pd.DataFrame(
        {
            "chocolate": chocolate,
            "fruity": fruity,
            "caramel": caramel,
            **extra_flags,
            "sugarpercent": sugarpercent.round(3),
            "pricepercent": pricepercent.round(3),
            "winpercent": winpercent.clip(5, 95).round(2),
        }
    )
|
|
|
|
def load_candy_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample FiveThirtyEight's candy table, or a proxy if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        candy = fivethirtyeight_candy()
        return sample_df(candy, limit, seed)
    except Exception:
        return load_candy_proxy_df(limit, seed)
|
|
|
|
def load_candy_chocolate_df(limit: int, seed: int) -> pd.DataFrame:
    """Return the candy table for the "is it chocolate?" classification task.

    Raises:
        gr.Error: if the loaded table has no ``chocolate`` column.
    """
    candy = load_candy_df(limit, seed).copy()
    if "chocolate" in candy.columns:
        return candy
    raise gr.Error("Candy dataset does not include the chocolate target.")
|
|
|
|
def load_adult_income_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate an Adult-income-like table with at most 30 000 rows.

    ``income_gt_50k`` is drawn from a logistic model that increases with age,
    education and hours worked, with bonuses for exec and tech occupations.
    ``marital_status`` is independent noise.
    """
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 30000)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    education_num = gen.integers(6, 17, n_rows)
    hours = gen.normal(40, 12, n_rows).clip(1, 80)
    age = gen.normal(39, 13, n_rows).clip(18, 75)
    occupation = gen.choice(["tech", "sales", "ops", "admin", "service", "exec"], n_rows)
    logit = -6 + 0.16 * age + 0.36 * education_num + 0.035 * hours + 0.9 * (occupation == "exec") + 0.55 * (occupation == "tech")
    income = gen.binomial(1, 1 / (1 + np.exp(-logit)))
    marital_status = gen.choice(["single", "married", "divorced"], n_rows)
    columns = {
        "age": age.round(0),
        "education_num": education_num,
        "hours_per_week": hours.round(0),
        "occupation": occupation,
        "marital_status": marital_status,
        "income_gt_50k": income,
    }
    return pd.DataFrame(columns)
|
|
|
|
def load_adult_income_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample the OpenML Adult table, or a simulated proxy if loading fails.

    Any error from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        adult = openml_adult_income()
        return sample_df(adult, limit, seed)
    except Exception:
        return load_adult_income_proxy_df(limit, seed)
|
|
|
|
def load_bike_demand_proxy_df(limit: int, seed: int) -> pd.DataFrame:
    """Simulate hourly bike-rental demand with at most 15 000 rows.

    Demand peaks at commute hours (7-9 and 16-18), rises with temperature and
    on working days, and drops in rain and storms. The count is clipped at 0.
    """
    gen = np.random.default_rng(seed)
    n_rows = min(limit, 15000)
    # Draw order matters for reproducibility; keep every gen.* call in sequence.
    hour = gen.integers(0, 24, n_rows)
    temp = gen.normal(21, 9, n_rows).clip(-5, 40)
    workingday = gen.binomial(1, 0.69, n_rows)
    weather = gen.choice(["clear", "mist", "rain", "storm"], n_rows, p=[0.55, 0.28, 0.14, 0.03])
    morning_rush = (hour >= 7) & (hour <= 9)
    evening_rush = (hour >= 16) & (hour <= 18)
    commute_peak = morning_rush | evening_rush
    count = 80 + 115 * commute_peak + 5.5 * temp + 45 * workingday - 75 * (weather == "rain") - 130 * (weather == "storm")
    count = count + gen.normal(0, 45, n_rows)
    columns = {
        "hour": hour,
        "temp": temp.round(1),
        "workingday": workingday,
        "weather": weather,
        "rental_count": count.clip(0).round(0),
    }
    return pd.DataFrame(columns)
|
|
|
|
def load_calcofi_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample CalCOFI bottle data, or a small synthetic proxy if loading fails.

    The proxy has at most 12 000 rows. Its water temperature falls with depth
    and with salinity above 33.5, plus Gaussian noise. Any error from the real
    loader is swallowed on purpose (offline fallback).
    """
    try:
        bottle = kaggle_calcofi()
        return sample_df(bottle, limit, seed)
    except Exception:
        gen = np.random.default_rng(seed)
        n_rows = min(limit, 12000)
        salinity = gen.normal(33.5, 0.7, n_rows)
        depth = gen.gamma(2.0, 40.0, n_rows)
        noise = gen.normal(0, 1.8, n_rows)
        temperature = 23 - 0.38 * depth / 10 - 1.7 * (salinity - 33.5) + noise
        return pd.DataFrame({"salinity": salinity, "depthm": depth, "water_temperature": temperature})
|
|
|
|
def load_szeged_weather_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample Szeged weather data, or a synthetic proxy if loading fails.

    The proxy has at most 20 000 rows. Its apparent temperature is the air
    temperature minus a humidity term, plus Gaussian noise. Any error from the
    real loader is swallowed on purpose (offline fallback).
    """
    try:
        weather = kaggle_szeged_weather()
        return sample_df(weather, limit, seed)
    except Exception:
        gen = np.random.default_rng(seed)
        n_rows = min(limit, 20000)
        humidity = gen.beta(4, 2, n_rows)
        temperature = gen.normal(12, 10, n_rows)
        apparent = temperature - 5 * humidity + gen.normal(0, 2.5, n_rows)
        wind_speed = gen.gamma(2, 4, n_rows)
        columns = {
            "temperature_c": temperature,
            "humidity": humidity,
            "wind_speed_km_h": wind_speed,
            "apparent_temperature": apparent,
        }
        return pd.DataFrame(columns)
|
|
|
|
def load_weather_ww2_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample WW2 weather data, or a synthetic proxy if loading fails.

    The proxy has at most 15 000 rows, and its max temperature is the min
    temperature plus a noisy ~8 degree offset. Any error from the real loader
    is swallowed on purpose (offline fallback).
    """
    try:
        summary = kaggle_weather_ww2()
        return sample_df(summary, limit, seed)
    except Exception:
        gen = np.random.default_rng(seed)
        n_rows = min(limit, 15000)
        mintemp = gen.normal(15, 9, n_rows)
        precip = gen.gamma(1.5, 2, n_rows)
        max_temperature = mintemp + gen.normal(8, 3, n_rows)
        return pd.DataFrame({"mintemp": mintemp, "precip": precip, "max_temperature": max_temperature})
|
|
|
|
def load_montreal_bike_lanes_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample Montreal bike-lane counts, or the bike-demand proxy if loading fails.

    In the proxy, ``rental_count`` is renamed to ``rider_count``. Any error
    from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        counts = kaggle_montreal_bike_lanes()
        return sample_df(counts, limit, seed)
    except Exception:
        proxy = load_bike_demand_proxy_df(limit, seed)
        return proxy.rename(columns={"rental_count": "rider_count"})
|
|
|
|
def load_nyc_bike_crossings_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample NYC bridge crossings, or a derived proxy if loading fails.

    The proxy starts from the bike-demand table. It adds two bridge columns
    as fixed fractions of the rental count, and moves that count into
    ``total_bike_crossings``. Any error from the real loader is swallowed on
    purpose (offline fallback).
    """
    try:
        crossings = kaggle_nyc_bike_crossings()
        return sample_df(crossings, limit, seed)
    except Exception:
        proxy = load_bike_demand_proxy_df(limit, seed)
        rentals = proxy["rental_count"]
        proxy["brooklyn_bridge"] = (rentals * 0.32).round()
        proxy["manhattan_bridge"] = (rentals * 0.27).round()
        proxy["total_bike_crossings"] = rentals
        return proxy.drop(columns=["rental_count"])
|
|
|
|
def load_uk_road_safety_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample UK road-safety records, or a synthetic proxy if loading fails.

    The proxy has at most 30 000 rows. Casualties are Poisson-distributed with
    a rate that grows with vehicle count and speed limits above 50. Any error
    from the real loader is swallowed on purpose (offline fallback).
    """
    try:
        accidents = kaggle_uk_road_safety()
        return sample_df(accidents, limit, seed)
    except Exception:
        gen = np.random.default_rng(seed)
        n_rows = min(limit, 30000)
        vehicles = gen.integers(1, 5, n_rows)
        speed = gen.choice([20, 30, 40, 50, 60, 70], n_rows)
        casualties = gen.poisson(0.4 + vehicles * 0.28 + (speed > 50) * 0.25, n_rows)
        light = gen.choice(["daylight", "dark"], n_rows)
        columns = {
            "number_of_vehicles": vehicles,
            "speed_limit": speed,
            "light_conditions": light,
            "casualty_count": casualties,
        }
        return pd.DataFrame(columns)
|
|
|
|
def load_kcbs_bbq_df(limit: int, seed: int) -> pd.DataFrame:
    """Sample KCBS BBQ results, or a synthetic proxy if loading fails.

    The proxy has at most 8 000 rows. ``first_place`` is drawn with a
    probability that rises steeply as ``score`` passes 184. Any error from the
    real loader is swallowed on purpose (offline fallback).
    """
    try:
        results = kaggle_kcbs_bbq()
        return sample_df(results, limit, seed)
    except Exception:
        gen = np.random.default_rng(seed)
        n_rows = min(limit, 8000)
        score = gen.normal(165, 12, n_rows)
        win_prob = 1 / (1 + np.exp(-(score - 184) / 5))
        first = gen.binomial(1, win_prob)
        contest_size = gen.integers(10, 80, n_rows)
        category = gen.choice(["chicken", "ribs", "pork", "brisket"], n_rows)
        columns = {
            "score": score,
            "contest_size": contest_size,
            "category": category,
            "first_place": first,
        }
        return pd.DataFrame(columns)
|
|
|
|
# Registry of benchmark datasets; dataset_names() returns them in this order.
# Real-data loaders fall back to synthetic proxies when loading fails. The
# trailing entries use scikit-learn's bundled datasets.
DATASETS: list[DatasetSpec] = [
    DatasetSpec(name="Titanic Survival", task="classification", target="survived", rows=1309,
                source="OpenML data_id=40945", loader=load_titanic_df,
                description="Mixed categorical/numeric binary classification."),
    DatasetSpec(name="Ames Housing Prices", task="regression", target="sale_price", rows=1460,
                source="OpenML data_id=42165", loader=load_ames_housing_df,
                description="Ames real-estate regression with neighborhood and quality features."),
    DatasetSpec(name="California Housing", task="regression", target="median_house_value", rows=20640,
                source="sklearn California housing", loader=load_california_housing_df,
                description="Block-level California housing value regression."),
    DatasetSpec(name="Credit Card Fraud", task="classification", target="fraud", rows=284807,
                source="KaggleHub mlg-ulb/creditcardfraud", loader=load_credit_fraud_df,
                description="Large imbalanced binary fraud task."),
    DatasetSpec(name="Epicurious Recipes", task="regression", target="rating", rows=20000,
                source="KaggleHub hugodarwood/epirecipes", loader=load_epirecipes_df,
                description="Recipe nutrition and tags to rating."),
    DatasetSpec(name="Halloween Candy", task="regression", target="winpercent", rows=85,
                source="FiveThirtyEight GitHub CSV", loader=load_candy_df,
                description="Candy attributes to popularity score."),
    DatasetSpec(name="Candy Chocolate", task="classification", target="chocolate", rows=85,
                source="FiveThirtyEight GitHub CSV", loader=load_candy_chocolate_df,
                description="Predict whether a candy is chocolate from other candy attributes."),
    DatasetSpec(name="Epicurious Cakeweek", task="classification", target="cakeweek", rows=20000,
                source="KaggleHub hugodarwood/epirecipes", loader=load_epirecipes_cakeweek_df,
                description="Predict cakeweek recipes from nutrition and recipe tags."),
    DatasetSpec(name="CalCOFI Ocean Temperature", task="regression", target="water_temperature", rows=864863,
                source="KaggleHub sohier/calcofi", loader=load_calcofi_df,
                description="Predict ocean water temperature from salinity and chemistry readings."),
    DatasetSpec(name="Szeged Apparent Temperature", task="regression", target="apparent_temperature", rows=96453,
                source="KaggleHub budincsevity/szeged-weather", loader=load_szeged_weather_df,
                description="Predict apparent temperature from humidity, wind, pressure, and weather."),
    DatasetSpec(name="WW2 Max Temperature", task="regression", target="max_temperature", rows=119040,
                source="KaggleHub smid80/weatherww2", loader=load_weather_ww2_df,
                description="Predict daily maximum temperature from minimum temperature and weather fields."),
    DatasetSpec(name="Montreal Bike Lane Counts", task="regression", target="rider_count", rows=319,
                source="KaggleHub pablomonleon/montreal-bike-lanes", loader=load_montreal_bike_lanes_df,
                description="Predict rider counts on one Montreal bike path from other paths."),
    DatasetSpec(name="NYC Bike Crossings", task="regression", target="total_bike_crossings", rows=210,
                source="KaggleHub new-york-city/nyc-east-river-bicycle-crossings", loader=load_nyc_bike_crossings_df,
                description="Predict total East River bicycle crossings from bridge counts."),
    DatasetSpec(name="UK Road Casualties", task="regression", target="casualty_count", rows=136621,
                source="KaggleHub bluehorseshoe/uk-2016-road-safety-data", loader=load_uk_road_safety_df,
                description="Predict accident casualty count from road safety fields."),
    DatasetSpec(name="KCBS BBQ First Place", task="classification", target="first_place", rows=1559,
                source="KaggleHub jaysobel/kcbs-bbq", loader=load_kcbs_bbq_df,
                description="Predict whether a BBQ competition team wins first place."),
    DatasetSpec(name="Adult Income", task="classification", target="income_gt_50k", rows=48842,
                source="OpenML data_id=1590", loader=load_adult_income_df,
                description="Demographic and work attributes to income bucket."),
    DatasetSpec(name="Bike Demand", task="regression", target="rental_count", rows=15000,
                source="Kaggle-style proxy", loader=load_bike_demand_proxy_df,
                description="Weather and time features to rental demand."),
    DatasetSpec(name="Iris", task="classification", target="species", rows=150,
                source="sklearn", loader=load_iris_df,
                description="Classic multi-class flower classification."),
    DatasetSpec(name="Wine", task="classification", target="wine_class", rows=178,
                source="sklearn", loader=load_wine_df,
                description="Chemical analysis to cultivar class."),
    DatasetSpec(name="Breast Cancer", task="classification", target="diagnosis", rows=569,
                source="sklearn", loader=load_breast_cancer_df,
                description="Diagnostic measurements to benign/malignant label."),
    DatasetSpec(name="Digits", task="classification", target="digit", rows=1797,
                source="sklearn", loader=load_digits_df,
                description="Pixel features to handwritten digit class."),
    DatasetSpec(name="Diabetes", task="regression", target="disease_progression", rows=442,
                source="sklearn", loader=load_diabetes_df,
                description="Clinical variables to disease progression."),
]
|
|
|
|
def dataset_names() -> list[str]:
    """Return the display name of every registered dataset, in registry order."""
    names = [spec.name for spec in DATASETS]
    return names
|
|
|
|
def get_spec(name: str) -> DatasetSpec:
    """Return the registered DatasetSpec whose ``name`` equals *name*.

    Raises:
        gr.Error: if no dataset with that name is registered. A bare ``next()``
            would leak ``StopIteration`` instead, which Python turns into a
            ``RuntimeError`` when it escapes inside a generator.
    """
    spec = next((d for d in DATASETS if d.name == name), None)
    if spec is None:
        available = ", ".join(d.name for d in DATASETS)
        raise gr.Error(f"Unknown dataset '{name}'. Available datasets: {available}")
    return spec
|
|
|
|
def get_dataset(name: str, sample_size: int, seed: int) -> pd.DataFrame:
    """Load dataset *name* through its registered loader.

    The loader is called with ``(sample_size, seed)``. The result is returned
    with a fresh 0..n-1 index.
    """
    loader = get_spec(name).loader
    frame = loader(sample_size, seed)
    return frame.reset_index(drop=True)
|
|
|
|
def split_xy(df: pd.DataFrame, target: str) -> tuple[pd.DataFrame, pd.Series]:
    """Split *df* into a feature frame and the *target* series.

    Columns that are entirely missing are dropped before the split.

    Raises:
        gr.Error: if *target* is not a column of *df*, if the target column
            has no non-missing values, or if no feature column remains.
    """
    if target not in df.columns:
        raise gr.Error(f"Target column '{target}' was not found.")
    cleaned = df.dropna(axis=1, how="all").copy()
    if target not in cleaned.columns:
        # The column exists but was dropped above because every value is
        # missing. Report that instead of a misleading "not found".
        raise gr.Error(f"Target column '{target}' contains no values.")
    y = cleaned[target]
    x = cleaned.drop(columns=[target])
    if x.empty:
        raise gr.Error("Dataset must include at least one feature column.")
    return x, y
|
|
|
|
def coerce_numeric_target(y: pd.Series) -> pd.Series:
    """Convert *y* to numbers; values that cannot be parsed become missing.

    Integer, unsigned and float series go straight to ``pd.to_numeric``. Any
    other series is treated as text: whitespace is trimmed and "," thousands
    separators are removed before parsing.
    """
    if y.dtype.kind not in "ifu":
        y = y.astype("string").str.strip().str.replace(",", "", regex=False)
    return pd.to_numeric(y, errors="coerce")
|
|
|
|
def prepare_xy(df: pd.DataFrame, target: str, task: str | None) -> tuple[pd.DataFrame, pd.Series, str]:
    """Split *df* into features and target, dropping rows with an unusable target.

    Args:
        df: Full dataset including the target column.
        target: Name of the target column.
        task: ``"regression"`` or ``"classification"``. ``None`` or an empty
            value means the task is inferred with ``infer_task``.

    Returns:
        ``(x, y, task)``. ``x`` and ``y`` keep only rows with a usable target
        and both get a fresh 0..n-1 index. For regression, ``y`` is the numeric
        coercion of the raw target; otherwise it is the raw target.

    Raises:
        gr.Error: if fewer than two usable target values remain, or (via
            ``split_xy``) if the target or feature columns are missing.
    """
    x, raw_y = split_xy(df, target)
    inferred_task = task or infer_task(raw_y)

    if inferred_task == "regression":
        y = coerce_numeric_target(raw_y)
        # Nullable numeric results (Int64/Float64) can hold pd.NA. Passing
        # na_value lets them convert to a float ndarray; without it pandas
        # may raise instead.
        values = y.to_numpy(dtype=float, na_value=np.nan)
        valid_target = y.notna().to_numpy() & np.isfinite(values)
        if valid_target.sum() < 2:
            raise gr.Error("Regression target must contain at least two numeric values.")
    else:
        y = raw_y
        valid_target = raw_y.notna().to_numpy()
        if valid_target.sum() < 2:
            raise gr.Error("Classification target must contain at least two non-empty values.")

    # Positional boolean masks, so row selection does not depend on the index labels.
    x = x.loc[valid_target].reset_index(drop=True)
    y = y.loc[valid_target].reset_index(drop=True)

    if x.empty:
        raise gr.Error("Dataset must include at least one feature column.")
    return x, y, inferred_task
|
|
|
|
def infer_task(y: pd.Series) -> str:
    """Guess whether *y* is a regression or a classification target.

    Returns ``"regression"`` when *y* has more than 20 distinct values and is
    numeric. It counts as numeric either natively or because at least 90% of
    its non-missing entries parse as numbers. Otherwise returns
    ``"classification"``.
    """
    if y.dtype.kind in "ifu" and y.nunique(dropna=True) > 20:
        return "regression"
    parsed = coerce_numeric_target(y)
    present = y.notna().sum()
    if not present:
        return "classification"
    mostly_numeric = parsed.notna().sum() / present >= 0.9
    many_values = parsed.nunique(dropna=True) > 20
    return "regression" if mostly_numeric and many_values else "classification"
|
|
|
|
def make_preprocessor(x: pd.DataFrame, scale_numeric: bool = False) -> ColumnTransformer:
    """Build a ColumnTransformer for the columns of *x*.

    Numeric columns get median imputation, plus standard scaling when
    *scale_numeric* is true. Every other column gets most-frequent imputation
    followed by dense one-hot encoding (unknown categories ignored, at most 32
    categories per column). Columns of *x* outside these groups are dropped.
    """
    numeric_cols = x.select_dtypes(include=np.number).columns.tolist()
    categorical_cols = [col for col in x.columns if col not in numeric_cols]

    transformers: list[tuple[str, object, list[str]]] = []
    if numeric_cols:
        steps: list[tuple[str, object]] = [("impute", SimpleImputer(strategy="median"))]
        if scale_numeric:
            steps.append(("scale", StandardScaler()))
        transformers.append(("num", Pipeline(steps), numeric_cols))
    if categorical_cols:
        encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False, max_categories=32)
        categorical_pipeline = Pipeline(
            [
                ("impute", SimpleImputer(strategy="most_frequent")),
                ("encode", encoder),
            ]
        )
        transformers.append(("cat", categorical_pipeline, categorical_cols))
    return ColumnTransformer(transformers=transformers, remainder="drop", verbose_feature_names_out=False)
|
|
|
|
def _placeholder_prep(scale_numeric: bool = False) -> ColumnTransformer:
    """Return a stand-in preprocessor that records whether numeric scaling is wanted.

    ``rebuild_pipeline`` replaces this placeholder with a preprocessor fitted to the
    real training columns. It reads the scaling intent from the placeholder's first
    transformer. An empty frame would produce no transformers, and the intent would
    be lost, so a one-column numeric probe frame is used instead.
    """
    return make_preprocessor(pd.DataFrame({"__probe__": [0.0]}), scale_numeric=scale_numeric)


def _xgboost_estimator(task: str):
    """Build the XGBoost model for ``task``. The import is deferred because xgboost is optional."""
    from xgboost import XGBClassifier, XGBRegressor

    if task == "classification":
        # NOTE(review): recent XGBoost releases expect class labels encoded as 0..k-1.
        # String or non-contiguous labels likely make this baseline fail at fit time.
        # benchmark_frame records that as a per-model failure.
        return XGBClassifier(n_estimators=80, max_depth=4, learning_rate=0.08, eval_metric="logloss", random_state=RANDOM_STATE)
    return XGBRegressor(n_estimators=80, max_depth=4, learning_rate=0.08, random_state=RANDOM_STATE)


def _lightgbm_estimator(task: str):
    """Build the LightGBM model for ``task``. The import is deferred because lightgbm is optional."""
    from lightgbm import LGBMClassifier, LGBMRegressor

    if task == "classification":
        return LGBMClassifier(n_estimators=120, learning_rate=0.06, random_state=RANDOM_STATE, verbose=-1)
    return LGBMRegressor(n_estimators=120, learning_rate=0.06, random_state=RANDOM_STATE, verbose=-1)


def available_baselines(task: str) -> dict[str, object]:
    """Return the baseline pipelines for ``task``, keyed by display name.

    Each value is an unfitted ``Pipeline([("prep", placeholder), ("model", estimator)])``.
    The placeholder preprocessor is swapped for a dataset-specific one by
    ``rebuild_pipeline``. XGBoost and LightGBM are added only when their package is
    importable and the estimator can be constructed. Otherwise they are skipped with
    a log line, and the caller reports them as unavailable.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
    import importlib.util

    def pipe(estimator, scale: bool = False) -> Pipeline:
        return Pipeline([("prep", _placeholder_prep(scale)), ("model", estimator)])

    if task == "classification":
        models: dict[str, object] = {
            "Logistic": pipe(LogisticRegression(max_iter=800), scale=True),
            "RandomForest": pipe(RandomForestClassifier(n_estimators=80, min_samples_leaf=2, n_jobs=-1, random_state=RANDOM_STATE)),
            "ExtraTrees": pipe(ExtraTreesClassifier(n_estimators=120, min_samples_leaf=2, n_jobs=-1, random_state=RANDOM_STATE)),
            "GradientBoosting": pipe(GradientBoostingClassifier(n_estimators=100, learning_rate=0.06, random_state=RANDOM_STATE)),
            "HistGradientBoosting": pipe(HistGradientBoostingClassifier(max_iter=120, random_state=RANDOM_STATE)),
            "AdaBoost": pipe(AdaBoostClassifier(n_estimators=80, learning_rate=0.08, random_state=RANDOM_STATE)),
            "NaiveBayes": pipe(GaussianNB(), scale=True),
            "KNN": pipe(KNeighborsClassifier(n_neighbors=7), scale=True),
            "SVC": pipe(SVC(C=1.0, probability=True, random_state=RANDOM_STATE), scale=True),
            "Dummy": pipe(DummyClassifier(strategy="most_frequent")),
        }
    else:
        models = {
            "LinearRegression": pipe(LinearRegression(), scale=True),
            "Ridge": pipe(Ridge(alpha=1.0), scale=True),
            "BayesianRidge": pipe(BayesianRidge(), scale=True),
            "RandomForest": pipe(RandomForestRegressor(n_estimators=80, min_samples_leaf=2, n_jobs=-1, random_state=RANDOM_STATE)),
            "ExtraTrees": pipe(ExtraTreesRegressor(n_estimators=120, min_samples_leaf=2, n_jobs=-1, random_state=RANDOM_STATE)),
            "GradientBoosting": pipe(GradientBoostingRegressor(n_estimators=100, learning_rate=0.06, random_state=RANDOM_STATE)),
            "HistGradientBoosting": pipe(HistGradientBoostingRegressor(max_iter=120, random_state=RANDOM_STATE)),
            "AdaBoost": pipe(AdaBoostRegressor(n_estimators=80, learning_rate=0.08, random_state=RANDOM_STATE)),
            "KNN": pipe(KNeighborsRegressor(n_neighbors=7), scale=True),
            "SVR": pipe(SVR(C=1.0), scale=True),
            "Dummy": pipe(DummyRegressor(strategy="median")),
        }

    optional_builders = (
        ("XGBoost", "xgboost", _xgboost_estimator),
        ("LightGBM", "lightgbm", _lightgbm_estimator),
    )
    for name, module_name, build in optional_builders:
        if importlib.util.find_spec(module_name) is None:
            continue
        try:
            models[name] = pipe(build(task))
        except Exception as exc:
            # The package is installed but unusable (e.g. a missing native library).
            # Skip the model; the benchmark then lists it as unavailable.
            print(f"Skipping optional baseline {name}: {type(exc).__name__}: {exc}")
    return models
|
|
|
|
def rebuild_pipeline(model: Pipeline, x_train: pd.DataFrame) -> Pipeline:
    """Clone ``model`` and replace its ``prep`` step with a preprocessor built for ``x_train``.

    Numeric scaling is kept when the placeholder preprocessor contains a pipeline with
    a ``scale`` step. A placeholder built from an empty frame has no transformers, so
    it cannot express that intent. In that case scaling is enabled for the estimator
    types that the baseline table pairs with scaled inputs.
    """
    pipe = clone(model)
    placeholder = pipe.steps[0][1]
    estimator = pipe.steps[-1][1]
    transformers = list(getattr(placeholder, "transformers", None) or [])
    if transformers:
        wants_scale = any(
            isinstance(transformer, Pipeline) and "scale" in transformer.named_steps
            for _, transformer, *_ in transformers
        )
    else:
        scale_sensitive = (
            LogisticRegression,
            GaussianNB,
            KNeighborsClassifier,
            SVC,
            LinearRegression,
            Ridge,
            BayesianRidge,
            KNeighborsRegressor,
            SVR,
        )
        wants_scale = isinstance(estimator, scale_sensitive)
    pipe.steps[0] = ("prep", make_preprocessor(x_train, scale_numeric=bool(wants_scale)))
    return pipe
|
|
|
|
@lru_cache(maxsize=2)
def load_tabfm_model(task: str):
    """Load the TabFM foundation model for ``task``, memoised per task value.

    "classification" loads the classification variant. Any other task value loads
    the regression variant.
    """
    from tabfm import tabfm_v1_0_0_pytorch

    if task == "classification":
        return tabfm_v1_0_0_pytorch.load(model_type="classification")
    return tabfm_v1_0_0_pytorch.load(model_type="regression")
|
|
|
|
def resolve_tabfm_params(
    preset: str,
    n_estimators: int,
    max_num_rows: int,
    max_num_features: int,
    batch_size: int,
    enable_nnls: bool,
    n_feature_crosses: str,
    n_svd_features: str,
    max_eval_rows: int,
) -> dict[str, object]:
    """Turn the TabFM tuning widgets into a parameter dict.

    A named preset wins outright: a copy of its dict is returned. Otherwise the
    custom widget values are normalised:
    * a non-positive row cap means "no cap" (None);
    * the "0" radio choice becomes the integer 0, and any other choice is passed through;
    * enabling NNLS always clears ``max_num_rows``.

    ``max_eval_rows`` stays in the dict; the benchmark pops it before constructing TabFM.
    """
    if preset in TABFM_PRESETS:
        return dict(TABFM_PRESETS[preset])

    def cap_or_none(limit):
        return None if limit <= 0 else int(limit)

    def zero_or_choice(choice):
        return 0 if choice == "0" else choice

    context_rows = cap_or_none(max_num_rows)
    use_nnls = bool(enable_nnls)
    if use_nnls:
        context_rows = None
    return {
        "n_estimators": int(n_estimators),
        "max_num_rows": context_rows,
        "max_num_features": int(max_num_features),
        "batch_size": int(batch_size),
        "enable_nnls": use_nnls,
        "n_feature_crosses": zero_or_choice(n_feature_crosses),
        "n_svd_features": zero_or_choice(n_svd_features),
        "max_eval_rows": cap_or_none(max_eval_rows),
    }
|
|
|
|
def run_tabfm(task: str, x_train: pd.DataFrame, x_test: pd.DataFrame, y_train: pd.Series, tabfm_params: dict[str, object]):
    """Fit a TabFM estimator on the training rows and predict the test rows.

    Returns ``(pred, proba)``. ``proba`` is produced only for classification, when the
    estimator exposes ``predict_proba``, and is None otherwise.
    """
    from tabfm import TabFMClassifier, TabFMRegressor

    backbone = load_tabfm_model(task)
    is_classifier = task == "classification"
    wrapper = TabFMClassifier if is_classifier else TabFMRegressor
    estimator = wrapper(model=backbone, **tabfm_params)

    # The target is handed to TabFM as a plain NumPy array.
    estimator.fit(x_train, y_train.to_numpy())
    pred = estimator.predict(x_test)

    proba = None
    if is_classifier and hasattr(estimator, "predict_proba"):
        proba = estimator.predict_proba(x_test)
    return pred, proba
|
|
|
|
def clean_tabfm_features(x_train: pd.DataFrame, x_test: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Make the train/test feature frames safe for TabFM, keeping their columns aligned.

    Each column of ``x_train`` is processed with train statistics only:
    * inf/-inf become missing; columns missing in both frames are dropped;
    * booleans become float 0/1, with missing values set to 0.0;
    * numeric columns are filled with the train median (0.0 if it is undefined);
    * datetimes become "%Y-%m-%d %H:%M:%S" strings, with missing values set to "__missing__";
    * any other column becomes numeric (median-filled) when at least 90% of its
      non-blank train values parse as numbers via ``coerce_numeric_target``.
      Otherwise it becomes a stripped string column, with missing or blank values
      set to "__missing__".

    Both returned frames have a fresh RangeIndex.

    Raises:
        gr.Error: if no usable feature column remains.
    """
    train_columns: dict[object, pd.Series] = {}
    test_columns: dict[object, pd.Series] = {}

    def keep_numeric(name, train_numeric: pd.Series, test_numeric: pd.Series) -> None:
        # Impute with the *train* median so no test statistics leak into the features.
        fill_value = train_numeric.median()
        if pd.isna(fill_value):
            fill_value = 0.0
        train_columns[name] = train_numeric.fillna(fill_value)
        test_columns[name] = test_numeric.fillna(fill_value)

    for column in x_train.columns:
        train_series = x_train[column].reset_index(drop=True).replace([np.inf, -np.inf], np.nan)
        test_series = x_test[column].reset_index(drop=True).replace([np.inf, -np.inf], np.nan)
        if train_series.isna().all() and test_series.isna().all():
            continue

        if pd.api.types.is_bool_dtype(train_series):
            train_columns[column] = train_series.astype("float64").fillna(0.0)
            test_columns[column] = test_series.astype("float64").fillna(0.0)
        elif pd.api.types.is_numeric_dtype(train_series):
            keep_numeric(
                column,
                pd.to_numeric(train_series, errors="coerce"),
                pd.to_numeric(test_series, errors="coerce"),
            )
        elif pd.api.types.is_datetime64_any_dtype(train_series):
            stamp = "%Y-%m-%d %H:%M:%S"
            train_columns[column] = train_series.dt.strftime(stamp).fillna("__missing__")
            test_columns[column] = test_series.dt.strftime(stamp).fillna("__missing__")
        else:
            train_text = train_series.astype("string").str.strip()
            test_text = test_series.astype("string").str.strip()
            train_numeric = coerce_numeric_target(train_series)
            non_blank = (train_text.notna() & (train_text != "")).sum()
            numeric_ratio = train_numeric.notna().sum() / non_blank if non_blank else 0
            if numeric_ratio >= 0.9:
                keep_numeric(column, train_numeric, coerce_numeric_target(test_series))
            else:
                train_columns[column] = train_text.fillna("__missing__").replace("", "__missing__")
                test_columns[column] = test_text.fillna("__missing__").replace("", "__missing__")

    # Build each frame in one step. Inserting columns one at a time fragments the
    # DataFrame and triggers pandas PerformanceWarning on wide inputs.
    cleaned_train = pd.DataFrame(train_columns)
    cleaned_test = pd.DataFrame(test_columns)
    if cleaned_train.empty:
        raise gr.Error("TabFM needs at least one non-empty feature column.")
    return cleaned_train, cleaned_test
|
|
|
|
def tabfm_failure_note(exc: Exception) -> str:
    """Print the full TabFM traceback and return a short Markdown note for the run summary.

    ImportError (which includes ModuleNotFoundError) and OSError are reported as
    runtime/environment problems, with a hint about the Space setup. Any other
    exception is reported as a failure while processing the dataset.
    """
    detail = f"{type(exc).__name__}: {exc}"
    # Format the traceback attached to ``exc`` itself. traceback.format_exc() only
    # works while an exception is being handled and prints "NoneType: None" otherwise.
    trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    print("TabFM failed:\n" + trace)
    if isinstance(exc, (ImportError, OSError)):
        return f"TabFM did not run because the runtime could not load it: `{detail}`. On Spaces, keep Python 3.11 and allow the GitHub dependency plus model download for `{TABFM_MODEL_ID}`."
    return f"TabFM failed while processing this dataset: `{detail}`."
|
|
|
|
def score_predictions(task: str, y_true: pd.Series, pred, proba=None) -> dict[str, float]:
    """Score predictions and add the ``rank_score`` used to order the leaderboard.

    Classification reports accuracy, weighted F1 and ROC AUC. ROC AUC is computed
    from the second probability column only when probabilities are given and
    ``y_true`` has exactly two distinct values; otherwise it is NaN. ``rank_score``
    is the NaN-ignoring mean of the three metrics.

    Regression reports RMSE, MAE and R^2. ``rank_score`` is -RMSE, so higher is
    better for both tasks.
    """
    if task != "classification":
        rmse = math.sqrt(mean_squared_error(y_true, pred))
        return {
            "rmse": rmse,
            "mae": mean_absolute_error(y_true, pred),
            "r2": r2_score(y_true, pred),
            "rank_score": -rmse,
        }

    accuracy = accuracy_score(y_true, pred)
    f1 = f1_score(y_true, pred, average="weighted", zero_division=0)
    roc_auc = np.nan
    if proba is not None and len(np.unique(y_true)) == 2:
        try:
            roc_auc = roc_auc_score(y_true, proba[:, 1])
        except Exception:
            # AUC is best-effort (e.g. the probability matrix may lack a second column).
            roc_auc = np.nan
    return {
        "accuracy": accuracy,
        "f1_weighted": f1,
        "roc_auc": roc_auc,
        "rank_score": np.nanmean([accuracy, f1, roc_auc]),
    }
|
|
|
|
def _split_for_task(x: pd.DataFrame, y: pd.Series, task: str, test_size: float, seed: int):
    """Train/test split that stratifies classification targets when possible.

    Stratification is attempted when every class has at least two rows. If
    scikit-learn still rejects it (e.g. the test split has fewer rows than there
    are classes), the split is retried without stratification.
    """
    stratify = y if task == "classification" and y.value_counts().min() >= 2 else None
    try:
        return train_test_split(x, y, test_size=test_size, random_state=seed, stratify=stratify)
    except ValueError:
        if stratify is None:
            raise
        return train_test_split(x, y, test_size=test_size, random_state=seed)


def _tabfm_row(
    task: str,
    x_train: pd.DataFrame,
    x_test: pd.DataFrame,
    y_train: pd.Series,
    y_test: pd.Series,
    seed: int,
    tabfm_params: dict[str, object],
) -> tuple[dict[str, object], str | None]:
    """Benchmark TabFM once; return its leaderboard row and a failure note (or None).

    ``max_eval_rows``, if set, is removed from the model kwargs. TabFM is then
    scored on a seeded random subset of that many test rows, and the row's status
    records the subset size.
    """
    start = time.perf_counter()
    try:
        model_params = dict(tabfm_params)
        eval_rows = model_params.pop("max_eval_rows", None)
        if eval_rows is not None and len(x_test) > int(eval_rows):
            eval_idx = x_test.sample(int(eval_rows), random_state=seed).index
            x_eval = x_test.loc[eval_idx]
            y_eval = y_test.loc[eval_idx]
            status = f"ok ({len(x_eval):,}/{len(x_test):,} test rows)"
        else:
            x_eval, y_eval, status = x_test, y_test, "ok"
        x_tab_train, x_tab_eval = clean_tabfm_features(x_train, x_eval)
        pred, proba = run_tabfm(task, x_tab_train, x_tab_eval, y_train, model_params)
        metrics = score_predictions(task, y_eval, pred, proba)
        return {"model": "TabFM", "status": status, "seconds": time.perf_counter() - start, **metrics}, None
    except Exception as exc:
        # TabFM is optional: record the failure in the leaderboard instead of aborting the run.
        row = {"model": "TabFM", "status": f"unavailable: {exc}", "seconds": time.perf_counter() - start}
        return row, tabfm_failure_note(exc)


def benchmark_frame(
    df: pd.DataFrame,
    target: str,
    task: str | None,
    sample_size: int,
    test_size: float,
    seed: int,
    selected_models: list[str],
    include_tabfm: bool,
    tabfm_params: dict[str, object] | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame, str]:
    """Run the selected baselines (and optionally TabFM) on one dataset.

    Args:
        df: Full dataset; at most ``sample_size`` rows are sampled with ``seed``.
        target: Name of the target column.
        task: "classification", "regression", or None to let ``prepare_xy`` decide.
        sample_size: Maximum number of rows to use.
        test_size: Fraction of rows held out for evaluation.
        seed: Random seed for sampling and splitting.
        selected_models: Baseline names to run. Names not offered for the task are
            listed as unavailable.
        include_tabfm: Whether to run TabFM live; otherwise a "skipped" row is added.
        tabfm_params: TabFM settings; defaults to the "Fast" preset.

    Returns:
        ``(leaderboard, held_out_preview, summary_markdown)``.

    Raises:
        gr.Error: when a classification target has fewer than two classes, or from
            ``prepare_xy``.
    """
    df = df.sample(min(int(sample_size), len(df)), random_state=seed).reset_index(drop=True)
    x, y, task = prepare_xy(df, target, task)
    if task == "classification" and y.nunique(dropna=True) < 2:
        raise gr.Error("Classification needs at least two target classes.")
    x_train, x_test, y_train, y_test = _split_for_task(x, y, task, test_size, seed)

    rows: list[dict[str, object]] = []
    notes: list[str] = []
    selected_models = selected_models or []
    models = available_baselines(task)
    for name, model in models.items():
        if name not in selected_models:
            continue
        start = time.perf_counter()
        try:
            pipe = rebuild_pipeline(model, x_train)
            pipe.fit(x_train, y_train)
            pred = pipe.predict(x_test)
            proba = pipe.predict_proba(x_test) if task == "classification" and hasattr(pipe, "predict_proba") else None
            metrics = score_predictions(task, y_test, pred, proba)
            rows.append({"model": name, "status": "ok", "seconds": time.perf_counter() - start, **metrics})
        except Exception as exc:
            # One failing baseline must not sink the whole leaderboard.
            rows.append({"model": name, "status": f"failed: {exc}", "seconds": time.perf_counter() - start})
    for name in selected_models:
        if name not in models:
            rows.append({"model": name, "status": f"not compatible with {task} or unavailable", "seconds": 0.0, "rank_score": np.nan})

    tabfm_display_params: dict[str, object] = {}
    if include_tabfm:
        # Resolved before the TabFM run so the summary can always display the settings.
        tabfm_display_params = dict(tabfm_params or TABFM_PRESETS["Fast"])
        tabfm_result, failure_note = _tabfm_row(task, x_train, x_test, y_train, y_test, seed, tabfm_display_params)
        rows.append(tabfm_result)
        if failure_note:
            notes.append(failure_note)
    else:
        rows.append({"model": "TabFM", "status": "skipped - enable Run TabFM live", "seconds": 0.0, "rank_score": np.nan})
        notes.append("TabFM is listed as skipped because **Run TabFM live** is off. Enable it to benchmark TabFM; the first run may download large model weights.")

    results = pd.DataFrame(rows)
    metric_cols = [c for c in ["accuracy", "f1_weighted", "roc_auc", "rmse", "mae", "r2", "rank_score", "seconds"] if c in results.columns]
    if not results.empty and "rank_score" in results.columns:
        results = results.sort_values("rank_score", ascending=False, na_position="last").reset_index(drop=True)
    results.insert(0, "rank", np.arange(1, len(results) + 1))
    preview = pd.concat([x_test.reset_index(drop=True).head(12), y_test.reset_index(drop=True).head(12).rename(target)], axis=1)
    # "Rows used" counts the rows left after prepare_xy dropped unusable targets (= train + test).
    summary = (
        f"**Task:** {task}  \n"
        f"**Rows used:** {len(x):,} | **Train:** {len(x_train):,} | **Test:** {len(x_test):,} | **Features:** {x.shape[1]:,}  \n"
        f"**Primary rank:** {'higher accuracy/F1/AUC' if task == 'classification' else 'lower RMSE'}"
    )
    if notes:
        summary += "\n\n" + "\n".join(f"- {note}" for note in notes)
    if include_tabfm:
        summary += f"\n\n**TabFM params:** `{tabfm_display_params}`"
    return results[["rank", "model", "status", *metric_cols]], preview, summary
|
|
|
|
def _plottable_metric_columns(frame: pd.DataFrame, candidates) -> list[str]:
    """Return the candidate metric columns that exist in ``frame`` and hold at least one value."""
    return [c for c in candidates if c in frame.columns and frame[c].notna().any()]


def _radar_metric_figure(clean: pd.DataFrame, metric_cols: list[str]) -> go.Figure:
    """Radar chart with each metric min-max scaled to [0, 1], where 1 is always best."""
    normalized = clean[["model", *metric_cols]].copy()
    for metric in metric_cols:
        values = pd.to_numeric(normalized[metric], errors="coerce")
        lo, hi = values.min(), values.max()
        if pd.isna(lo) or pd.isna(hi) or hi == lo:
            normalized[metric] = 0.5
        elif metric in {"rmse", "mae", "seconds"}:
            # Lower-is-better metrics are flipped so larger radius still means better.
            normalized[metric] = 1 - ((values - lo) / (hi - lo))
        else:
            normalized[metric] = (values - lo) / (hi - lo)

    fig = go.Figure()
    theta = metric_cols + [metric_cols[0]]
    # Colour by row position, not index label, so colours cycle predictably.
    for position, (_, row) in enumerate(normalized.iterrows()):
        radii = [row[m] for m in metric_cols] + [row[metric_cols[0]]]
        fig.add_trace(
            go.Scatterpolar(
                r=radii,
                theta=theta,
                fill="toself",
                name=str(row["model"]),
                line=dict(color=GOOGLE_COLORS[position % len(GOOGLE_COLORS)], width=2),
                opacity=0.78,
            )
        )
    fig.update_layout(
        template="plotly_white",
        height=420,
        margin=dict(l=35, r=35, t=35, b=25),
        polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
        legend_title_text="Model",
        title="Normalized metric shape (higher is better)",
    )
    return fig


def _line_metric_figure(clean: pd.DataFrame, metric_cols: list[str]) -> go.Figure:
    """One line subplot per metric, sharing the model axis."""
    n_rows = len(metric_cols)
    # make_subplots rejects vertical_spacing larger than 1 / (rows - 1).
    spacing = 0.08 if n_rows < 2 else min(0.08, 1 / (n_rows - 1))
    fig = make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=spacing,
        subplot_titles=[metric.replace("_", " ").upper() for metric in metric_cols],
    )
    x_labels = clean["model"].astype(str).tolist()
    for idx, metric in enumerate(metric_cols, start=1):
        fig.add_trace(
            go.Scatter(
                x=x_labels,
                y=clean[metric],
                mode="lines+markers",
                name=metric,
                line=dict(color=GOOGLE_COLORS[(idx - 1) % len(GOOGLE_COLORS)], width=3, shape="spline"),
                marker=dict(size=9, line=dict(color="white", width=1.5)),
                hovertemplate=f"<b>%{{x}}</b><br>{metric}: %{{y:.4f}}<extra></extra>",
            ),
            row=idx,
            col=1,
        )
        if metric in {"accuracy", "f1_weighted", "roc_auc", "r2"}:
            fig.update_yaxes(range=[min(0, float(clean[metric].min()) - 0.05), 1.02], row=idx, col=1)
    fig.update_layout(
        template="plotly_white",
        height=max(320, 185 * n_rows),
        margin=dict(l=30, r=20, t=40, b=35),
        showlegend=False,
        hovermode="x unified",
    )
    return fig


def metric_chart(results: pd.DataFrame, selected_metrics: list[str] | None = None, chart_style: str = "Line") -> go.Figure:
    """Plot per-model metrics as stacked line subplots or as a normalised radar chart.

    Args:
        results: Leaderboard frame with a ``model`` column. It is ordered by
            ``rank`` when that column is present.
        selected_metrics: Metric columns to plot. If none of them have values,
            every populated column of METRIC_CHOICES is used instead.
        chart_style: "Radar" for the radar view; any other value draws line subplots.

    Returns:
        A Plotly figure. The figure is empty when there is nothing to plot.
    """
    if results is None or results.empty:
        return go.Figure()
    frame = results.copy()
    # NOTE(review): leaderboards round-tripped through gr.Dataframe may not keep
    # numeric dtypes. Coerce the known metric columns so min()/normalisation are numeric.
    for column in METRIC_CHOICES:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")

    metric_cols = _plottable_metric_columns(frame, selected_metrics or METRIC_CHOICES) or _plottable_metric_columns(frame, METRIC_CHOICES)
    if not metric_cols:
        return go.Figure()

    clean = frame.sort_values("rank") if "rank" in frame.columns else frame
    clean = clean.dropna(subset=metric_cols, how="all")
    if clean.empty:
        return go.Figure()
    if chart_style == "Radar":
        return _radar_metric_figure(clean, metric_cols)
    return _line_metric_figure(clean, metric_cols)
|
|
|
|
def bar_chart(results: pd.DataFrame, selected_metrics: list[str] | None = None) -> go.Figure:
    """Draw a grouped bar chart comparing the chosen metrics across models.

    Uses ``selected_metrics``, or METRIC_CHOICES when none are given. If none of
    those columns have values, it falls back to every populated METRIC_CHOICES
    column. Rows are ordered by ``rank`` when present, and models with no value
    for any plotted metric are left out. Returns an empty figure when there is
    nothing to draw.
    """
    if results is None or results.empty:
        return go.Figure()
    table = results.copy()

    def populated(names):
        return [name for name in names if name in table.columns and table[name].notna().any()]

    metric_cols = populated(selected_metrics or METRIC_CHOICES) or populated(METRIC_CHOICES)
    if not metric_cols:
        return go.Figure()

    ordered = table.sort_values("rank") if "rank" in table.columns else table.copy()
    ordered = ordered.dropna(subset=metric_cols, how="all")
    if ordered.empty:
        return go.Figure()

    tidy = ordered.melt(id_vars=["model"], value_vars=metric_cols, var_name="metric", value_name="score")
    figure = px.bar(
        tidy,
        x="model",
        y="score",
        color="metric",
        barmode="group",
        color_discrete_sequence=GOOGLE_COLORS,
        title="Grouped metric comparison",
    )
    figure.update_layout(
        template="plotly_white",
        height=360,
        margin=dict(l=25, r=20, t=45, b=35),
        legend_title_text="Metric",
        hovermode="x unified",
    )
    return figure
|
|
|
|
def time_chart(results: pd.DataFrame) -> go.Figure:
    """Scatter each model's runtime, with marker size showing relative leaderboard quality.

    ``rank_score`` is "higher is better" for both tasks; for regression it is
    -RMSE. Marker size is therefore the min-max normalised ``rank_score``, so a
    larger marker always means a better model. Using ``abs(rank_score)`` would
    make the worst regression models the largest markers. Rows without a score
    (failed or unavailable models) get the smallest marker. Rows whose status
    starts with "skipped" are not plotted.
    """
    if results is None or results.empty or "seconds" not in results:
        return go.Figure()
    frame = results.copy()
    if "status" in frame.columns:
        frame = frame[~frame["status"].astype(str).str.startswith("skipped")]
    if frame.empty:
        return go.Figure()
    frame = frame.reset_index(drop=True)

    if "rank_score" in frame.columns:
        scores = pd.to_numeric(frame["rank_score"], errors="coerce")
    else:
        scores = pd.Series(np.nan, index=frame.index, dtype="float64")
    lo, hi = scores.min(), scores.max()
    if pd.notna(lo) and hi > lo:
        relative = (scores - lo) / (hi - lo)
    else:
        # Zero or one distinct score: every scored model gets the same (full) size.
        relative = scores.where(scores.isna(), 1.0)
    marker_size = (0.25 + 0.75 * relative).fillna(0.1)

    fig = px.scatter(
        frame,
        x="seconds",
        y="model",
        size=marker_size.to_numpy(),
        color="model",
        color_discrete_sequence=px.colors.qualitative.Set2,
    )
    fig.update_layout(template="plotly_white", height=300, margin=dict(l=20, r=20, t=25, b=20), showlegend=False)
    return fig
|
|
|
|
def run_catalog(
    dataset_name: str,
    sample_size: int,
    test_percent: int,
    seed: int,
    selected_models: list[str],
    include_tabfm: bool,
    selected_metrics: list[str],
    chart_style: str,
    tabfm_preset: str,
    tabfm_n_estimators: int,
    tabfm_max_rows: int,
    tabfm_max_features: int,
    tabfm_batch_size: int,
    tabfm_enable_nnls: bool,
    tabfm_crosses: str,
    tabfm_svd: str,
    tabfm_max_eval_rows: int,
):
    """Benchmark a catalog dataset and return the Arena tab outputs.

    Returns ``(summary, leaderboard rounded to 4 d.p., metric chart, speed chart,
    grouped bar chart, held-out preview)``.
    """
    spec = get_spec(dataset_name)
    frame = get_dataset(dataset_name, sample_size, seed)
    params = resolve_tabfm_params(
        tabfm_preset,
        tabfm_n_estimators,
        tabfm_max_rows,
        tabfm_max_features,
        tabfm_batch_size,
        tabfm_enable_nnls,
        tabfm_crosses,
        tabfm_svd,
        tabfm_max_eval_rows,
    )
    results, preview, summary = benchmark_frame(
        frame,
        spec.target,
        spec.task,
        sample_size,
        test_percent / 100,
        seed,
        selected_models,
        include_tabfm,
        params,
    )
    leaderboard = results.round(4)
    metric_fig = metric_chart(results, selected_metrics, chart_style)
    speed_fig = time_chart(results)
    bars_fig = bar_chart(results, selected_metrics)
    return summary, leaderboard, metric_fig, speed_fig, bars_fig, preview
|
|
|
|
def run_upload(
    file,
    target: str,
    task: str,
    sample_size: int,
    test_percent: int,
    seed: int,
    selected_models: list[str],
    include_tabfm: bool,
    selected_metrics: list[str],
    chart_style: str,
    tabfm_preset: str,
    tabfm_n_estimators: int,
    tabfm_max_rows: int,
    tabfm_max_features: int,
    tabfm_batch_size: int,
    tabfm_enable_nnls: bool,
    tabfm_crosses: str,
    tabfm_svd: str,
    tabfm_max_eval_rows: int,
):
    """Benchmark a user-uploaded CSV and return the Upload tab outputs.

    ``task`` is "Auto" (let the benchmark infer it), "Classification" or "Regression".
    Returns the same six outputs as ``run_catalog``.

    Raises:
        gr.Error: if no file was uploaded, the CSV cannot be parsed, or it has no rows.
    """
    if file is None:
        raise gr.Error("Upload a CSV file first.")
    # Depending on the Gradio version, gr.File yields a path string (or str subclass)
    # or a tempfile-like object exposing `.name`.
    csv_path = file if isinstance(file, (str, Path)) else file.name
    try:
        df = pd.read_csv(csv_path)
    except (UnicodeDecodeError, pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
        raise gr.Error(f"Could not read the uploaded CSV: {exc}") from exc
    if df.empty:
        raise gr.Error("The uploaded CSV has no data rows.")
    # Tolerate accidental surrounding whitespace in the typed target column name.
    if isinstance(target, str) and target not in df.columns and target.strip() in df.columns:
        target = target.strip()

    selected_task = None if task == "Auto" else task.lower()
    tabfm_params = resolve_tabfm_params(tabfm_preset, tabfm_n_estimators, tabfm_max_rows, tabfm_max_features, tabfm_batch_size, tabfm_enable_nnls, tabfm_crosses, tabfm_svd, tabfm_max_eval_rows)
    results, preview, summary = benchmark_frame(df, target, selected_task, sample_size, test_percent / 100, seed, selected_models, include_tabfm, tabfm_params)
    return summary, results.round(4), metric_chart(results, selected_metrics, chart_style), time_chart(results), bar_chart(results, selected_metrics), preview
|
|
|
|
def redraw_metric_chart(results: pd.DataFrame, selected_metrics: list[str], chart_style: str):
    """Re-render the metric chart from the current leaderboard value (empty figure if there is none)."""
    has_rows = results is not None and len(results) > 0
    if not has_rows:
        return go.Figure()
    return metric_chart(pd.DataFrame(results), selected_metrics, chart_style)
|
|
|
|
def redraw_bar_chart(results: pd.DataFrame, selected_metrics: list[str]):
    """Re-render the grouped bar chart from the current leaderboard value (empty figure if there is none)."""
    has_rows = results is not None and len(results) > 0
    if not has_rows:
        return go.Figure()
    return bar_chart(pd.DataFrame(results), selected_metrics)
|
|
|
|
def catalog_table() -> pd.DataFrame:
    """Return one row per catalog dataset: name, task, target, rows, source and description."""
    fields = ("dataset", "task", "target", "rows", "source", "description")
    records = [
        dict(zip(fields, (spec.name, spec.task, spec.target, spec.rows, spec.source, spec.description)))
        for spec in DATASETS
    ]
    return pd.DataFrame(records)
|
|
|
|
# Every baseline name offered in the model pickers. Names that do not apply to the
# current task, or whose package is missing, are reported by the benchmark as
# "not compatible ... or unavailable" rather than being hidden.
DEFAULT_MODELS = [
    # Linear and probabilistic baselines.
    "Logistic", "LinearRegression", "Ridge", "BayesianRidge", "NaiveBayes",
    # Tree ensembles and boosting.
    "RandomForest", "ExtraTrees", "GradientBoosting", "HistGradientBoosting", "AdaBoost",
    # Distance- and kernel-based models.
    "KNN", "SVC", "SVR",
    # Optional third-party boosters.
    "XGBoost", "LightGBM",
    # Trivial reference baseline.
    "Dummy",
]
# Models pre-selected in both tabs.
DEFAULT_SELECTED_MODELS = ["HistGradientBoosting", "XGBoost", "LightGBM", "Dummy"]
|
|
|
|
def _tabfm_tuning_controls() -> list:
    """Create the collapsed "TabFM tuning" accordion and return its widgets.

    Order: preset, estimators, max context rows, max features, batch size, NNLS
    flag, feature crosses, SVD features, max TabFM test rows. This matches the
    trailing arguments of run_catalog/run_upload.
    """
    with gr.Accordion("TabFM tuning", open=False):
        return [
            gr.Dropdown(list(TABFM_PRESETS.keys()) + ["Custom"], value="Fast", label="Preset"),
            gr.Slider(1, 32, value=1, step=1, label="Estimators"),
            gr.Slider(0, 5000, value=256, step=64, label="Max context rows (0 = no cap)"),
            gr.Slider(8, 500, value=64, step=8, label="Max features"),
            gr.Slider(1, 8, value=1, step=1, label="Batch size"),
            gr.Checkbox(value=False, label="NNLS ensemble weights"),
            gr.Radio(["0", "sqrt"], value="0", label="Feature crosses"),
            gr.Radio(["0", "sqrt"], value="0", label="SVD features"),
            gr.Slider(0, 5000, value=256, step=64, label="Max TabFM test rows (0 = no cap)"),
        ]


def _benchmark_controls() -> tuple[list, object, object]:
    """Create the sampling, model, chart and TabFM controls shared by both benchmark tabs.

    Returns:
        ``(inputs, metric_toggles, chart_style)``. ``inputs`` lists every widget in
        the argument order run_catalog/run_upload expect after their
        dataset-specific leading arguments.
    """
    sample_size = gr.Slider(100, 50000, value=1000, step=100, label="Sample size")
    test_split = gr.Slider(10, 40, value=25, step=5, label="Test split (%)")
    seed = gr.Number(value=42, precision=0, label="Random seed")
    models = gr.Dropdown(DEFAULT_MODELS, value=DEFAULT_SELECTED_MODELS, multiselect=True, label="Models")
    include_tabfm = gr.Checkbox(value=False, label="Run TabFM live (adds TabFM row)")
    metric_toggles = gr.Dropdown(METRIC_CHOICES, value=["rmse", "mae", "r2", "accuracy", "f1_weighted", "roc_auc"], multiselect=True, label="Chart metrics")
    chart_style = gr.Radio(["Line", "Radar"], value="Line", label="Chart style")
    tabfm_tuning = _tabfm_tuning_controls()
    inputs = [sample_size, test_split, seed, models, include_tabfm, metric_toggles, chart_style, *tabfm_tuning]
    return inputs, metric_toggles, chart_style


def _results_panel(leaderboard_label: str) -> tuple:
    """Create the results column.

    Returns ``(summary, leaderboard, chart, speed, bars, preview)``, the order the
    run handlers return their outputs in.
    """
    summary = gr.Markdown()
    leaderboard = gr.Dataframe(label=leaderboard_label, interactive=False)
    bars = gr.Plot(label="Main grouped comparison")
    with gr.Row():
        chart = gr.Plot(label="Metric comparison")
        speed = gr.Plot(label="Speed")
    preview = gr.Dataframe(label="Held-out preview", interactive=False)
    return summary, leaderboard, chart, speed, bars, preview


def _wire_chart_redraws(leaderboard, metric_toggles, chart_style, chart, bars) -> None:
    """Redraw the charts from the current leaderboard whenever the chart controls change."""
    metric_toggles.change(redraw_metric_chart, [leaderboard, metric_toggles, chart_style], chart)
    metric_toggles.change(redraw_bar_chart, [leaderboard, metric_toggles], bars)
    chart_style.change(redraw_metric_chart, [leaderboard, metric_toggles, chart_style], chart)


def build_app() -> gr.Blocks:
    """Assemble the tabBench Gradio app.

    Tabs: Arena (catalog benchmark, also run on page load), Upload Dataset,
    Dataset Catalog and Implementation Notes.
    """
    css = """
    body, .gradio-container { background: #f7f9fd; color: #101828; }
    .shell { max-width: 1440px; margin: 0 auto; }
    .hero { background: linear-gradient(135deg, #ffffff 0%, #f6f9ff 56%, #fff7ed 100%); border: 1px solid #e9edf5; border-radius: 18px; padding: 26px 28px; box-shadow: 0 20px 55px rgba(15, 23, 42, 0.07); }
    .hero h1 { font-size: 36px; line-height: 1.05; margin: 0 0 8px; letter-spacing: 0; }
    .hero p { margin: 0; color: #667085; font-size: 15px; }
    .stat-card { background: #fff; border: 1px solid #edf0f5; border-radius: 14px; padding: 18px; box-shadow: 0 12px 35px rgba(15, 23, 42, 0.05); min-height: 118px; }
    .stat-card .label { color: #667085; font-size: 13px; }
    .stat-card .value { font-size: 28px; font-weight: 760; margin-top: 12px; }
    .stat-card .trend { display: inline-block; margin-left: 8px; font-size: 12px; color: #027a48; background: #ecfdf3; border-radius: 999px; padding: 2px 8px; }
    .panel { background: #fff; border: 1px solid #edf0f5; border-radius: 14px; padding: 14px; box-shadow: 0 12px 35px rgba(15, 23, 42, 0.04); }
    .control-panel { background: linear-gradient(180deg, #ffffff 0%, #fbfcff 100%); border-top: 4px solid #f97316; }
    .control-panel label span { color: #344054; font-weight: 720; }
    .control-panel input, .control-panel textarea, .control-panel select { border-radius: 10px !important; }
    .control-panel .wrap { gap: 9px !important; }
    .control-panel .token, .control-panel [data-testid="token"] { background: #fff7ed !important; color: #c2410c !important; border: 1px solid #fed7aa !important; border-radius: 999px !important; }
    .control-panel .checkbox label { border-radius: 999px !important; }
    .gr-button-primary { background: linear-gradient(135deg, #f97316, #ea4335) !important; border-color: #f97316 !important; box-shadow: 0 12px 24px rgba(249, 115, 22, 0.25) !important; border-radius: 12px !important; min-height: 46px !important; font-weight: 760 !important; }
    .plot-container, .table-wrap { border-radius: 14px !important; overflow: hidden; }
    footer { display: none !important; }
    """
    hero_html = """
    <div class="hero">
    <h1>tabBench</h1>
    <p>A clean arena for benchmarking <strong>google/tabfm-1.0.0-pytorch</strong> against practical tabular baselines across small, classic, imbalanced, and user-uploaded datasets.</p>
    </div>
    """
    theme = gr.themes.Soft(primary_hue="orange", secondary_hue="violet")

    with gr.Blocks(title=APP_TITLE, css=css, theme=theme) as demo:
        with gr.Column(elem_classes=["shell"]):
            gr.HTML(hero_html)
            with gr.Row():
                gr.HTML(f'<div class="stat-card"><div class="label">Benchmark catalog</div><div class="value">{len(DATASETS)} <span class="trend">mixed tasks</span></div><div class="label">Classification + regression</div></div>')
                gr.HTML('<div class="stat-card"><div class="label">Linked HF model</div><div class="value">TabFM <span class="trend">1.0</span></div><div class="label">google/tabfm-1.0.0-pytorch</div></div>')
                gr.HTML('<div class="stat-card"><div class="label">User datasets</div><div class="value">CSV <span class="trend">upload</span></div><div class="label">Pick target, task, sample size</div></div>')

            with gr.Tabs():
                with gr.Tab("Arena"):
                    with gr.Row():
                        with gr.Column(scale=1, elem_classes=["panel", "control-panel"]):
                            dataset = gr.Dropdown(dataset_names(), value="Titanic Survival", label="Dataset")
                            arena_controls, metric_toggles, chart_style = _benchmark_controls()
                            run_btn = gr.Button("Run benchmark", variant="primary")
                        with gr.Column(scale=3):
                            summary, leaderboard, chart, speed, bars, preview = _results_panel("Leaderboard")
                    run_inputs = [dataset, *arena_controls]
                    run_outputs = [summary, leaderboard, chart, speed, bars, preview]
                    run_btn.click(run_catalog, run_inputs, run_outputs)
                    _wire_chart_redraws(leaderboard, metric_toggles, chart_style, chart, bars)
                    # Populate the Arena with a default benchmark as soon as the page loads.
                    demo.load(run_catalog, run_inputs, run_outputs)

                with gr.Tab("Upload Dataset"):
                    with gr.Row():
                        with gr.Column(scale=1, elem_classes=["panel", "control-panel"]):
                            csv_file = gr.File(label="CSV file", file_types=[".csv"])
                            target = gr.Textbox(label="Target column")
                            task = gr.Radio(["Auto", "Classification", "Regression"], value="Auto", label="Task")
                            upload_controls, upload_metric_toggles, upload_chart_style = _benchmark_controls()
                            upload_btn = gr.Button("Run uploaded dataset", variant="primary")
                        with gr.Column(scale=3):
                            (
                                upload_summary,
                                upload_leaderboard,
                                upload_chart,
                                upload_speed,
                                upload_bars,
                                upload_preview,
                            ) = _results_panel("Upload leaderboard")
                    upload_btn.click(
                        run_upload,
                        [csv_file, target, task, *upload_controls],
                        [upload_summary, upload_leaderboard, upload_chart, upload_speed, upload_bars, upload_preview],
                    )
                    _wire_chart_redraws(upload_leaderboard, upload_metric_toggles, upload_chart_style, upload_chart, upload_bars)

                with gr.Tab("Dataset Catalog"):
                    gr.Dataframe(catalog_table(), interactive=False, label="Included benchmark catalog")
                    gr.Markdown(
                        """
                        Most catalog datasets are loaded from OpenML, KaggleHub, FiveThirtyEight GitHub data, or sklearn. Each remote loader has a fallback so the Space remains usable if an upstream dataset is temporarily unavailable.
                        """
                    )

                with gr.Tab("Implementation Notes"):
                    gr.Markdown(
                        """
                        This Space declares `models: google/tabfm-1.0.0-pytorch` in its README metadata, which is what Hugging Face uses to associate Spaces with model pages.

                        TabFM is attempted only when **Run TabFM live** is enabled because the first run downloads large model weights and CPU Basic inference can be slow. Use the **Fast** preset for a quick smoke test, then increase estimators/context rows for stronger but slower runs.

                        The TabFM integration follows the Google Research README pattern: load `tabfm_v1_0_0_pytorch`, wrap it with `TabFMClassifier` or `TabFMRegressor`, call `fit` for context preparation, then `predict`.
                        """
                    )
    return demo
|
|
|
|
if __name__ == "__main__":
    # Queue requests (at most 16 pending) so long benchmark runs do not block each other.
    app = build_app()
    app.queue(max_size=16)
    app.launch()
|
|