---
license: apache-2.0
tags:
- tabular
- foundation-model
- pretraining
- tabpfn
- schema-aware
- pytorch
datasets:
- avewright/tabula-pretraining-corpus-v2
language:
- en
---
# Tabula v1 — Tabular Foundation Model (Pretrained)
A schema-aware tabular transformer pretrained on a large multi-source corpus
of real and synthetic tabular datasets.
## Model Architecture
| Property | Value |
|---|---|
| Architecture | TabularTransformer |
| d_model | 256 |
| Heads | 8 |
| Layers | 8 |
| FFN dim | 512 |
| FFN activation | SwiGLU |
| Normalization | RMSNorm |
| Pooling | CLS token |
| Numeric embedding | Periodic (k=16) |
| Max numeric features | 64 |
| Max categories | 128 |
| Parameters | **10,752,769** (~10.75M) |
## Pretraining
| Property | Value |
|---|---|
| Best checkpoint | Step 45,000 |
| Best val loss | 0.2295 |
| Rows seen at best | 23,040,000 |
| Final step | 61,825 |
| Total rows seen | 31,654,400 |
| Batch size | 512 |
| Learning rate | 3e-4 (cosine decay, 2K warmup) |
| AMP | fp16 |
| Hardware | NVIDIA RTX A4500 (20 GB) |
| Training time | ~3 hours |
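The learning-rate schedule in the table (3e-4 peak, 2K warmup steps, cosine decay) can be sketched as below. This is an illustrative reconstruction, not code from the training pipeline; in particular, the decay-to-zero floor and the use of the final step count as the schedule horizon are assumptions.

```python
import math

def lr_at(step, peak_lr=3e-4, warmup_steps=2_000, total_steps=61_825):
    """Linear warmup to peak_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at(1_000))   # halfway through warmup: 1.5e-4
print(lr_at(2_000))   # peak: 3e-4
```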
Loss objective: multi-task MSE on target prediction from mixed numeric/categorical features,
with values normalized per column (z-score). Each batch samples from a fixed-width
(64-feature) schema in which unused slots are masked with NaN.
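A minimal sketch of that objective, assuming NaN-masked slots are excluded from both the normalization statistics and the loss. The function names are illustrative, not from the tabula codebase:

```python
import numpy as np

def masked_zscore(cols):
    """Z-score each column using only its non-NaN entries."""
    mask = ~np.isnan(cols)
    mean = np.nanmean(cols, axis=0)
    std = np.nanstd(cols, axis=0)
    std = np.where(std == 0, 1.0, std)  # guard constant columns
    normed = (cols - mean) / std
    return np.where(mask, normed, np.nan), mask

def masked_mse(pred, target):
    """MSE over valid (non-NaN) target entries only."""
    valid = ~np.isnan(target)
    return float(np.mean((pred[valid] - target[valid]) ** 2))

# Batch of 4 rows in a 64-slot schema; only 2 slots are used.
batch = np.full((4, 64), np.nan, dtype=np.float32)
batch[:, 0] = [1.0, 2.0, 3.0, 4.0]
batch[:, 1] = [10.0, 10.0, 10.0, 10.0]  # constant column
normed, mask = masked_zscore(batch)
```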
## Pretraining Corpus
Trained on [`avewright/tabula-pretraining-corpus-v2`](https://huggingface.co/datasets/avewright/tabula-pretraining-corpus-v2):
| Source | OK Datasets | Status |
|---|---|---|
| PMLB | 422 | **Fully exhausted** (422 of 423 known datasets used; 1 download failure) |
| OpenML | 2,949 | 4,886 attempted — 1,900 rejected (too few features), 37 download failures |
| HuggingFace | 0 | 67 attempted — format incompatibilities |
| **Synthetic** | (unlimited) | tree-prior, GMM, polynomial, SCM, regression, time-series, mixed-type |
**Total corpus:** 541 shards, ~160 GB parquet.
**Format:** `feat_0..feat_63` (Float32, NaN=unused), `target` (Float32), `_source_meta` (JSON).
### Dataset Exhaustion Notes
- **PMLB: fully exhausted.** All 422 of 423 known datasets successfully processed
(1 download failure: `chess`). No new PMLB datasets can be added without an
upstream PMLB library update.
- **OpenML: largely exhausted.** 4,886 unique datasets attempted. 2,949 passed
the pipeline. The 1,900 `schema_fail` entries are almost entirely datasets with
only 1 output column and too few rows/features to be useful (e.g. `too small: (53, 1)`).
These are unrecoverable without lowering quality thresholds. There may be a small
tail of undiscovered OpenML datasets not yet paginated.
- **HuggingFace tabular:** 67 attempted from curated catalog. All failed due to
schema mismatches, missing splits, or download timeouts. Catalog needs expansion
with manually vetted datasets.
## Files
| File | Description |
|---|---|
| `best.pt` | Best validation checkpoint (step 45,000, val_loss=0.2295) |
| `latest.pt` | Final training checkpoint (step 61,825) |
| `config.json` | Model and training hyperparameters |
| `training_log.txt` | Full training run output |
## Usage
```python
import torch

from tabula.models.transformer import TabularTransformer
from tabula.config import ModelConfig

# Load the checkpoint. weights_only=False is required because the
# checkpoint stores a pickled config object alongside the weights.
ckpt = torch.load("best.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"].model

# Reconstruct the model from the saved hyperparameters
model = TabularTransformer(
    d_model=cfg.d_model, n_heads=cfg.n_heads, n_layers=cfg.n_layers,
    d_ff=cfg.d_ff, dropout=cfg.dropout,
    num_numeric=64, num_categorical=0, num_text=0,
    output_dim=1,
    numeric_embedding=cfg.numeric_embedding,
    numeric_periodic_features=cfg.numeric_periodic_features,
    ffn_activation=cfg.ffn_activation, norm=cfg.norm, pooling=cfg.pooling,
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
```
## Training Notes
The model uses a fixed-width schema (64 numeric slots) regardless of the original
dataset width. Narrower datasets are padded with NaN in their unused slots. This forces
the model to learn position-invariant feature representations compatible with arbitrary
tabular schemas.
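Preparing a narrower dataset for the 64-slot schema can be sketched in plain NumPy; the NaN-for-unused convention follows the corpus format above, and the helper name is illustrative:

```python
import numpy as np

MAX_FEATURES = 64

def pad_to_schema(x, max_features=MAX_FEATURES):
    """Pad an (n_rows, n_feats) array to the fixed 64-slot schema with NaN."""
    n_rows, n_feats = x.shape
    if n_feats > max_features:
        raise ValueError(f"dataset has {n_feats} features; max is {max_features}")
    out = np.full((n_rows, max_features), np.nan, dtype=np.float32)
    out[:, :n_feats] = x
    return out

narrow = np.random.randn(8, 5).astype(np.float32)  # a 5-feature dataset
padded = pad_to_schema(narrow)
print(padded.shape)  # (8, 64)
```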
When the real-corpus buffer is exhausted, synthetic data fills the gap, providing 100M+
rows per session with controlled variation in feature distributions, missingness patterns,
and task types.
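One of the listed synthetic priors, a GMM feature generator with a polynomial target, might look roughly like this. All parameters here are illustrative; the actual priors used by the training pipeline are not documented in this card:

```python
import numpy as np

def sample_gmm_dataset(n_rows=1_000, n_feats=8, n_components=3, seed=0):
    """Draw features from a random Gaussian mixture; derive a quadratic target."""
    rng = np.random.default_rng(seed)
    means = rng.normal(0, 3, size=(n_components, n_feats))
    comps = rng.integers(n_components, size=n_rows)
    x = means[comps] + rng.normal(size=(n_rows, n_feats))
    # Quadratic target with random coefficients plus observation noise.
    w1 = rng.normal(size=n_feats)
    w2 = rng.normal(size=n_feats)
    y = x @ w1 + (x ** 2) @ w2 + rng.normal(0, 0.1, size=n_rows)
    return x.astype(np.float32), y.astype(np.float32)

x, y = sample_gmm_dataset()
```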