---
license: apache-2.0
tags:
- tabular
- foundation-model
- pretraining
- tabpfn
- schema-aware
- pytorch
datasets:
- avewright/tabula-pretraining-corpus-v2
language:
- en
---

# Tabula v1 – Tabular Foundation Model (Pretrained)

A schema-aware tabular transformer pretrained on a large multi-source corpus
of real and synthetic tabular datasets.

## Model Architecture

| Property | Value |
|---|---|
| Architecture | TabularTransformer |
| d_model | 256 |
| Heads | 8 |
| Layers | 8 |
| FFN dim | 512 |
| FFN activation | SwiGLU |
| Normalization | RMSNorm |
| Pooling | CLS token |
| Numeric embedding | Periodic (k=16) |
| Max numeric features | 64 |
| Max categories | 128 |
| Parameters | **10,752,769** (~10.75M) |

## Pretraining

| Property | Value |
|---|---|
| Best checkpoint | Step 45,000 |
| Best val loss | 0.2295 |
| Rows seen at best | 23,040,000 |
| Final step | 61,825 |
| Total rows seen | 31,654,400 |
| Batch size | 512 |
| Learning rate | 3e-4 (cosine decay, 2K warmup) |
| AMP | fp16 |
| Hardware | NVIDIA RTX A4500 (20 GB) |
| Training time | ~3 hours |

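The learning-rate schedule in the table can be sketched as follows (a minimal sketch: the linear warmup shape, the zero decay floor, and the use of the final step as the decay horizon are assumptions, not details stated above):

```python
import math

BASE_LR, WARMUP, TOTAL_STEPS = 3e-4, 2_000, 61_825  # values from the table above

def lr_at(step: int) -> float:
    """Linear warmup to the base LR, then cosine decay toward 0."""
    if step < WARMUP:
        return BASE_LR * step / WARMUP
    progress = (step - WARMUP) / (TOTAL_STEPS - WARMUP)
    return 0.5 * BASE_LR * (1.0 + math.cos(math.pi * progress))
```

With this shape, the LR peaks at exactly 3e-4 at step 2,000 and decays smoothly for the remainder of the run.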
Loss objective: multi-task MSE on target prediction from mixed numeric/categorical
features, normalized per column (z-score). Each batch samples from a fixed-width
(64-feature) schema where unused slots are masked with NaN.

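The fixed-width batching described above can be illustrated like this (a sketch with hypothetical helper names, not the actual pipeline code):

```python
import numpy as np

MAX_FEATURES = 64  # fixed schema width used during pretraining

def pad_to_schema(x: np.ndarray) -> np.ndarray:
    """Place an (n_rows, n_feats) table into the 64-slot schema; unused slots stay NaN."""
    n_rows, n_feats = x.shape
    out = np.full((n_rows, MAX_FEATURES), np.nan, dtype=np.float32)
    out[:, :n_feats] = x
    return out

def zscore_target(y: np.ndarray) -> np.ndarray:
    """Per-column z-score normalization of the regression target."""
    return (y - y.mean()) / (y.std() + 1e-8)

rows = np.random.rand(8, 10).astype(np.float32)  # a 10-feature source dataset
batch = pad_to_schema(rows)
# batch.shape == (8, 64); 54 of the 64 slots in each row are NaN-masked
```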
## Pretraining Corpus

Trained on [`avewright/tabula-pretraining-corpus-v2`](https://huggingface.co/datasets/avewright/tabula-pretraining-corpus-v2):

| Source | OK Datasets | Status |
|---|---|---|
| PMLB | 422 | **Fully exhausted** (422 of 423 known datasets used) |
| OpenML | 2,949 | 4,886 attempted → 1,900 rejected (too few features), 37 download failures |
| HuggingFace | 0 | 67 attempted → format incompatibilities |
| **Synthetic** | (unlimited) | tree-prior, GMM, polynomial, SCM, regression, time-series, mixed-type |

**Total corpus:** 541 shards, ~160 GB Parquet.
**Format:** `feat_0..feat_63` (Float32, NaN = unused), `target` (Float32), `_source_meta` (JSON).

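A single corpus row in the format above might look like this (values are illustrative, and the `_source_meta` keys are assumptions, not the documented schema):

```python
import json
import numpy as np

# Build one row in the documented shard format: 64 Float32 feature slots
# (NaN = unused), a Float32 target, and a JSON metadata string.
row = {f"feat_{i}": np.float32("nan") for i in range(64)}
for i, value in enumerate([0.12, -1.4, 3.3]):  # a 3-feature source dataset
    row[f"feat_{i}"] = np.float32(value)
row["target"] = np.float32(0.7)
row["_source_meta"] = json.dumps({"source": "pmlb", "dataset": "example"})  # keys assumed
```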
### Dataset Exhaustion Notes

- **PMLB: fully exhausted.** 422 of 423 known datasets successfully processed
  (1 download failure: `chess`). No new PMLB datasets can be added without an
  upstream PMLB library update.

- **OpenML: largely exhausted.** 4,886 unique datasets attempted; 2,949 passed
  the pipeline. The 1,900 `schema_fail` entries are almost entirely datasets with
  only 1 output column and too few rows/features to be useful (e.g. `too small: (53, 1)`).
  These are unrecoverable without lowering quality thresholds. There may be a small
  tail of OpenML datasets not yet reached by pagination.

- **HuggingFace tabular:** 67 attempted from the curated catalog. All failed due to
  schema mismatches, missing splits, or download timeouts. The catalog needs expansion
  with manually vetted datasets.

## Files

| File | Description |
|---|---|
| `best.pt` | Best validation checkpoint (step 45,000, val_loss=0.2295) |
| `latest.pt` | Final training checkpoint (step 61,825) |
| `config.json` | Model and training hyperparameters |
| `training_log.txt` | Full training run output |

## Usage

```python
import torch
from tabula.models.transformer import TabularTransformer
from tabula.config import ModelConfig

# Load checkpoint (weights_only=False: the checkpoint stores a full config object)
ckpt = torch.load("best.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"].model

# Reconstruct the model from the saved hyperparameters
model = TabularTransformer(
    d_model=cfg.d_model, n_heads=cfg.n_heads, n_layers=cfg.n_layers,
    d_ff=cfg.d_ff, dropout=cfg.dropout,
    num_numeric=64, num_categorical=0, num_text=0,
    output_dim=1,
    numeric_embedding=cfg.numeric_embedding,
    numeric_periodic_features=cfg.numeric_periodic_features,
    ffn_activation=cfg.ffn_activation, norm=cfg.norm, pooling=cfg.pooling,
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
```

## Training Notes

The model uses a fixed-width schema (64 numeric slots) regardless of the original
dataset width. Narrower datasets are padded with NaN-masked slots. This forces the
model to learn position-invariant feature representations compatible with arbitrary
tabular schemas.

Synthetic data fills gaps when the real-corpus buffer is empty, providing 100M+ rows
per session of controlled variation in feature distributions, missingness patterns,
and task types.
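As an illustration of the synthetic generators listed earlier, a GMM-style table could be produced along these lines (a hypothetical sketch under assumed parameters, not the actual generator):

```python
import numpy as np

def synth_gmm_table(n_rows: int = 1024, n_feats: int = 10,
                    n_components: int = 3, seed: int = 0):
    """Sample features from a Gaussian mixture and derive a noisy linear target."""
    rng = np.random.default_rng(seed)
    means = rng.normal(0.0, 2.0, size=(n_components, n_feats))
    component = rng.integers(n_components, size=n_rows)   # cluster assignment per row
    x = means[component] + rng.normal(size=(n_rows, n_feats))
    w = rng.normal(size=n_feats)                          # random linear readout
    y = x @ w + 0.1 * rng.normal(size=n_rows)             # mild label noise
    return x.astype(np.float32), y.astype(np.float32)

features, target = synth_gmm_table()
```

Such generators give unlimited rows with controllable cluster structure, which is why the synthetic source is listed as "(unlimited)" in the corpus table.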