# tabpfn-trainer / app.py — by JimmyBhoy
# Commit 1dd230d (verified): Fix CUDA OOM by batching TabPFN ensemble predictions
import inspect
import json
import math
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import List
import gradio as gr
import joblib
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder
from tabpfn import TabPFNRegressor
# Hugging Face dataset repo and the parquet export used for training.
DATASET_REPO = "JimmyBhoy/propertypeek-training-data"
DATASET_FILE = "exports/20260330-204118/sales_training.parquet"
# Where trained artifacts are persisted; override location via MODEL_DIR env var.
MODEL_DIR = Path(os.getenv("MODEL_DIR", "/data/models"))
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / "tabpfn_sales_model.joblib"  # joblib artifact (model + preprocessor + metadata)
STATUS_PATH = MODEL_DIR / "training_status.json"  # JSON status file polled by the UI
# Hard cap on rows used to fit any single TabPFN model (enforced in run_training).
MAX_TABPFN_SAMPLES = 10_000
DEFAULT_BATCH_SIZE = 10_000
# Rows per predict() call; small chunks keep peak CUDA memory bounded.
DEFAULT_PREDICTION_BATCH_SIZE = 100
class BatchedTabPFNRegressor:
    """Ensemble wrapper that averages multiple TabPFN regressors and runs
    prediction in small row batches to keep peak GPU memory bounded
    (the CUDA-OOM workaround this app exists for).
    """

    def __init__(
        self,
        models: List["TabPFNRegressor"],
        strategy: str,
        prediction_batch_size: int = DEFAULT_PREDICTION_BATCH_SIZE,
    ):
        """
        Args:
            models: Fitted regressors; each must expose ``predict(X)``.
            strategy: Label recording how the models were trained (metadata only).
            prediction_batch_size: Rows per ``predict`` call; clamped to >= 1.
        """
        self.models = models
        self.strategy = strategy
        self.prediction_batch_size = max(1, int(prediction_batch_size))

    def _predict_single_model_batched(self, model, X) -> np.ndarray:
        """Predict with one model in chunks; returns a flat float32 array."""
        n = len(X)
        # Fix: an empty input would make np.concatenate fail on an empty
        # list, so short-circuit with an empty array of the right dtype.
        if n == 0:
            return np.empty(0, dtype=np.float32)
        predictions = []
        for i in range(0, n, self.prediction_batch_size):
            pred = model.predict(X[i : i + self.prediction_batch_size])
            predictions.append(np.asarray(pred, dtype=np.float32).reshape(-1))
        return np.concatenate(predictions, axis=0)

    def predict(self, X) -> np.ndarray:
        """Return the mean of batched predictions across all ensemble members.

        Raises:
            RuntimeError: If the wrapper holds no trained models.
        """
        if not self.models:
            raise RuntimeError("No trained models available.")
        preds = [self._predict_single_model_batched(model, X) for model in self.models]
        if len(preds) == 1:
            return preds[0]
        return np.mean(np.vstack(preds), axis=0)
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _write_status(payload: dict) -> None:
    """Persist *payload* to STATUS_PATH as indented JSON, stamped with the
    write time (an existing 'updated_at' key in *payload* wins)."""
    stamped = {"updated_at": _now_iso()} | payload
    STATUS_PATH.write_text(json.dumps(stamped, indent=2), encoding="utf-8")
def _pick_target_column(df: pd.DataFrame) -> str:
candidates = ["sale_price", "sold_price", "price", "target", "y"]
for col in candidates:
if col in df.columns and pd.api.types.is_numeric_dtype(df[col]):
return col
numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
if not numeric_cols:
raise ValueError("No numeric target candidate found in parquet.")
return numeric_cols[-1]
def _build_preprocessor(X: pd.DataFrame) -> ColumnTransformer:
    """Build the feature preprocessor for TabPFN.

    Categoricals are mode-imputed then ordinal-encoded (unknowns map to -1);
    numerics are median-imputed; any remaining columns are dropped.
    """
    numeric_cols = [c for c in X.columns if pd.api.types.is_numeric_dtype(X[c])]
    categorical_cols = [c for c in X.columns if c not in numeric_cols]

    categorical_pipeline = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="most_frequent")),
            (
                "encoder",
                OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
            ),
        ]
    )
    numeric_pipeline = Pipeline(steps=[("imputer", SimpleImputer(strategy="median"))])

    return ColumnTransformer(
        transformers=[
            ("categorical", categorical_pipeline, categorical_cols),
            ("numeric", numeric_pipeline, numeric_cols),
        ],
        remainder="drop",
    )
def _make_strat_bins(y: np.ndarray, n_bins: int = 20) -> np.ndarray:
# For regression stratification: quantile bins with duplicate handling
series = pd.Series(y)
q = min(max(5, n_bins), max(2, len(series) // 1000))
try:
binned = pd.qcut(series, q=q, duplicates="drop")
return binned.astype(str).to_numpy()
except Exception:
# Fallback: coarse fixed-width bins
edges = np.linspace(float(np.min(y)), float(np.max(y)), num=6)
return np.digitize(y, edges[1:-1], right=False).astype(str)
def _stratified_sample_indices(y: np.ndarray, sample_size: int, random_state: int = 42) -> np.ndarray:
    """Pick *sample_size* row indices, stratified over quantile bins of y.

    Returns every index unchanged when the data already fits in sample_size.
    """
    total = len(y)
    if total <= sample_size:
        return np.arange(total)
    strata = _make_strat_bins(y)
    splitter = StratifiedShuffleSplit(
        n_splits=1, train_size=sample_size, random_state=random_state
    )
    chosen, _ = next(splitter.split(np.zeros((total, 1)), strata))
    return np.asarray(chosen)
def _fit_single_model(X_part: np.ndarray, y_part: np.ndarray, device: str):
    """Fit one TabPFNRegressor on the given rows and return the fitted model."""
    regressor = TabPFNRegressor(device=device)
    regressor.fit(X_part, y_part)
    return regressor
def _fit_with_tabpfn_subsample_if_available(X_train: np.ndarray, y_train: np.ndarray, device: str, log):
    """Train one model on the full matrix using TabPFN's built-in subsampling
    when the installed version exposes a ``subsample_size`` constructor arg;
    otherwise fit on a stratified 10k sample.

    Returns:
        A ``(models, strategy_used)`` pair.
    """
    supports_subsample = (
        "subsample_size" in inspect.signature(TabPFNRegressor.__init__).parameters
    )
    if not supports_subsample:
        log("Built-in subsampling arg not found. Falling back to stratified 10k sample.")
        sample_idx = _stratified_sample_indices(y_train, MAX_TABPFN_SAMPLES)
        fallback_model = _fit_single_model(X_train[sample_idx], y_train[sample_idx], device=device)
        return [fallback_model], "stratified_10k_fallback"
    log("Detected TabPFNRegressor(subsample_size=...) support; using built-in subsampling=10,000")
    model = TabPFNRegressor(device=device, subsample_size=MAX_TABPFN_SAMPLES)
    model.fit(X_train, y_train)
    return [model], "tabpfn_subsample"
def run_training(
    strategy,
    batch_size,
    max_models,
    random_state,
    prediction_batch_size,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one full training pass: download data, preprocess, fit TabPFN
    model(s) per *strategy*, evaluate on a held-out 20% split, and persist
    the artifact plus a JSON status file.

    Args:
        strategy: One of "stratified_10k", "ensemble_batches", "tabpfn_subsample".
        batch_size: Rows per fitted model; capped at MAX_TABPFN_SAMPLES.
        max_models: Cap on ensemble members (0 means train on all batches).
        random_state: Seed for the split, sampling, and batch shuffle.
        prediction_batch_size: Rows per predict() call (clamped to >= 1).
        progress: Gradio progress reporter (auto-injected by the UI).

    Returns:
        (log_text, model_path) on success; (log_text, "") on failure —
        errors are reported via the status file rather than raised.
    """
    logs = []
    def log(msg: str):
        # Append a timestamped line and mirror the whole log into the
        # status file so the UI can poll progress mid-run.
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        line = f"[{stamp}] {msg}"
        logs.append(line)
        _write_status({"status": "running", "log": logs})
    try:
        # Gradio Number components deliver floats; normalize to ints.
        batch_size = int(batch_size)
        max_models = int(max_models)
        random_state = int(random_state)
        prediction_batch_size = max(1, int(prediction_batch_size))
        if batch_size > MAX_TABPFN_SAMPLES:
            log(f"batch_size {batch_size} > 10,000; capping to 10,000")
            batch_size = MAX_TABPFN_SAMPLES
        log(f"Starting TabPFN training run | strategy={strategy}")
        parquet_path = hf_hub_download(
            repo_id=DATASET_REPO,
            repo_type="dataset",
            filename=DATASET_FILE,
        )
        log(f"Downloaded dataset file: {DATASET_FILE}")
        progress(0.08, desc="Loading parquet")
        df = pd.read_parquet(parquet_path)
        log(f"Loaded parquet rows={len(df):,}, columns={len(df.columns)}")
        target_col = _pick_target_column(df)
        log(f"Using target column: {target_col}")
        X = df.drop(columns=[target_col]).copy()
        y = df[target_col].astype(float)
        # Datetimes can't pass through the numeric pipeline directly;
        # convert to epoch seconds. NOTE(review): NaT values become a
        # sentinel int here rather than NaN — confirm that's acceptable.
        for col in X.columns:
            if pd.api.types.is_datetime64_any_dtype(X[col]):
                X[col] = X[col].astype("int64") // 10**9
        preprocessor = _build_preprocessor(X)
        progress(0.16, desc="Train/test split")
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=0.2,
            random_state=random_state,
        )
        progress(0.25, desc="Preprocessing")
        # Fit the preprocessor on train only; float32 keeps GPU memory down.
        X_train_p = np.asarray(preprocessor.fit_transform(X_train), dtype=np.float32)
        X_test_p = np.asarray(preprocessor.transform(X_test), dtype=np.float32)
        y_train_np = np.asarray(y_train, dtype=np.float32)
        y_test_np = np.asarray(y_test, dtype=np.float32)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log(f"Training on device: {device}")
        log(f"Train rows after split: {len(X_train_p):,}")
        trained_models = []
        strategy_used = strategy
        if strategy == "stratified_10k":
            # Single model on a stratified subsample of at most 10k rows.
            progress(0.45, desc="Training stratified 10k sample")
            idx = _stratified_sample_indices(y_train_np, min(batch_size, MAX_TABPFN_SAMPLES), random_state)
            log(f"Training single model on stratified sample size={len(idx):,}")
            trained_models = [_fit_single_model(X_train_p[idx], y_train_np[idx], device=device)]
        elif strategy == "ensemble_batches":
            # One model per shuffled batch of up to eff_batch rows; the
            # ensemble prediction averages them later.
            n = len(X_train_p)
            eff_batch = min(batch_size, MAX_TABPFN_SAMPLES)
            n_batches = int(math.ceil(n / eff_batch))
            if max_models > 0:
                n_batches = min(n_batches, max_models)
            rng = np.random.default_rng(random_state)
            perm = rng.permutation(n)
            log(f"Training ensemble across {n_batches} batches of up to {eff_batch:,} rows each")
            for i in range(n_batches):
                start = i * eff_batch
                end = min((i + 1) * eff_batch, n)
                batch_idx = perm[start:end]
                if len(batch_idx) == 0:
                    break
                # Map batch index onto the 0.35..0.80 progress window.
                p = 0.35 + 0.45 * ((i + 1) / max(n_batches, 1))
                progress(p, desc=f"Training batch model {i + 1}/{n_batches}")
                log(f"Fitting model {i + 1}/{n_batches} on rows={len(batch_idx):,}")
                trained_models.append(_fit_single_model(X_train_p[batch_idx], y_train_np[batch_idx], device=device))
        elif strategy == "tabpfn_subsample":
            # Prefer TabPFN's own subsampling if the installed version has it.
            progress(0.45, desc="Training with built-in/fallback subsampling")
            trained_models, strategy_used = _fit_with_tabpfn_subsample_if_available(
                X_train_p,
                y_train_np,
                device=device,
                log=log,
            )
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
        if not trained_models:
            raise RuntimeError("No models were trained.")
        progress(0.85, desc="Evaluating")
        ensemble_model = BatchedTabPFNRegressor(
            trained_models,
            strategy_used,
            prediction_batch_size=prediction_batch_size,
        )
        preds = ensemble_model.predict(X_test_p)
        # NOTE(review): mean_squared_error(squared=False) was removed in
        # scikit-learn 1.6 (use root_mean_squared_error) — confirm the
        # pinned sklearn version still accepts it.
        rmse = float(mean_squared_error(y_test_np, preds, squared=False))
        r2 = float(r2_score(y_test_np, preds))
        log(f"Validation RMSE: {rmse:,.4f}")
        log(f"Validation R2: {r2:,.4f}")
        log(f"Models trained: {len(trained_models)}")
        # Everything needed for later inference goes into one joblib artifact.
        artifact = {
            "model": ensemble_model,
            "models": trained_models,
            "preprocessor": preprocessor,
            "target_column": target_col,
            "feature_columns": list(X.columns),
            "metrics": {"rmse": rmse, "r2": r2},
            "row_count": int(len(df)),
            "train_rows": int(len(X_train_p)),
            "strategy_requested": strategy,
            "strategy_used": strategy_used,
            "batch_size": int(batch_size),
            "max_models": int(max_models),
            "num_models": int(len(trained_models)),
            "prediction_batch_size": int(prediction_batch_size),
            "trained_at_utc": _now_iso(),
        }
        joblib.dump(artifact, MODEL_PATH)
        log(f"Saved model artifact: {MODEL_PATH}")
        _write_status(
            {
                "status": "completed",
                "dataset_repo": DATASET_REPO,
                "dataset_file": DATASET_FILE,
                "rows_loaded": int(len(df)),
                "train_rows": int(len(X_train_p)),
                "target_column": target_col,
                "device": device,
                "strategy_requested": strategy,
                "strategy_used": strategy_used,
                "batch_size": int(batch_size),
                "max_models": int(max_models),
                "num_models": int(len(trained_models)),
                "prediction_batch_size": int(prediction_batch_size),
                "metrics": {"rmse": rmse, "r2": r2},
                "model_path": str(MODEL_PATH),
                "log": logs,
            }
        )
        progress(1.0, desc="Done")
        return "\n".join(logs), str(MODEL_PATH)
    except Exception as e:
        # Surface the failure through the UI/status file instead of raising,
        # so the Gradio handler always returns its two outputs.
        err = f"Training failed: {type(e).__name__}: {e}"
        logs.append(err)
        _write_status({"status": "failed", "error": err, "log": logs})
        return "\n".join(logs), ""
def read_latest_status():
    """Return the latest training-status JSON text, or an idle placeholder
    when no training run has written a status file yet."""
    if not STATUS_PATH.exists():
        return json.dumps({"status": "idle", "message": "No training run yet."}, indent=2)
    return STATUS_PATH.read_text(encoding="utf-8")
# --- Gradio UI: training controls, log output, and status viewer ---
with gr.Blocks() as demo:
    gr.Markdown("# TabPFN Trainer")
    gr.Markdown(f"Dataset: `{DATASET_REPO}` → `{DATASET_FILE}`")
    gr.Markdown("TabPFN fit-size workaround includes 10k stratified sample and batched ensembling.")
    # Training strategy; values match the branches in run_training.
    strategy = gr.Dropdown(
        choices=["ensemble_batches", "stratified_10k", "tabpfn_subsample"],
        value="ensemble_batches",
        label="Training Strategy",
    )
    with gr.Row():
        batch_size = gr.Number(value=10_000, label="Batch Size (max 10,000)", precision=0)
        max_models = gr.Number(value=8, label="Max Models (0 = all batches)", precision=0)
        random_state = gr.Number(value=42, label="Random Seed", precision=0)
        prediction_batch_size = gr.Number(value=100, label="Prediction Batch Size", precision=0)
    with gr.Row():
        train_btn = gr.Button("Start Training", variant="primary")
        status_btn = gr.Button("Refresh Status")
    logs_box = gr.Textbox(label="Training Logs", lines=18)
    model_box = gr.Textbox(label="Model Output Path")
    status_box = gr.Code(label="Latest Status JSON", language="json")
    # Wire buttons to handlers; api_name exposes them as API endpoints too.
    train_btn.click(
        run_training,
        inputs=[strategy, batch_size, max_models, random_state, prediction_batch_size],
        outputs=[logs_box, model_box],
        api_name="train",
    )
    status_btn.click(read_latest_status, outputs=[status_box], api_name="status")
if __name__ == "__main__":
    # queue() serializes runs so concurrent clicks don't train in parallel.
    demo.queue().launch()