---
license: apache-2.0
tags:
- tabular
- foundation-model
- pretraining
- tabpfn
- schema-aware
- pytorch
datasets:
- avewright/tabula-pretraining-corpus-v2
language:
- en
---
# Tabula v1: Tabular Foundation Model (Pretrained)
A schema-aware tabular transformer pretrained on a large multi-source corpus
of real and synthetic tabular datasets.
## Model Architecture
| Property | Value |
|---|---|
| Architecture | TabularTransformer |
| d_model | 256 |
| Heads | 8 |
| Layers | 8 |
| FFN dim | 512 |
| FFN activation | SwiGLU |
| Normalization | RMSNorm |
| Pooling | CLS token |
| Numeric embedding | Periodic (k=16) |
| Max numeric features | 64 |
| Max categories | 128 |
| Parameters | **10,752,769** (~10.75M) |
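The `Periodic (k=16)` numeric embedding in the table presumably belongs to the sin/cos periodic-embedding family popularized for numerical features in tabular transformers. A minimal NumPy sketch of that idea, using random stand-in frequencies where the real model learns them:

```python
import numpy as np

def periodic_embed(x, freqs):
    """Map scalar features to [sin, cos] pairs at k frequencies.

    x:     (batch, n_features) raw numeric values
    freqs: (n_features, k) per-feature frequencies (learned in the real
           model; random here purely for illustration)
    Returns (batch, n_features, 2 * k).
    """
    angles = 2 * np.pi * x[..., None] * freqs[None, ...]  # (batch, n_features, k)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

rng = np.random.default_rng(0)
k = 16
x = rng.normal(size=(4, 64))        # a batch of 4 rows, 64 numeric slots
freqs = rng.normal(size=(64, k))    # stand-in for learned frequencies
emb = periodic_embed(x, freqs)
print(emb.shape)                    # (4, 64, 32)
```

With k=16 each scalar becomes a 32-dimensional vector, which the model would then project to `d_model`.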
## Pretraining
| Property | Value |
|---|---|
| Best checkpoint | Step 45,000 |
| Best val loss | 0.2295 |
| Rows seen at best | 23,040,000 |
| Final step | 61,825 |
| Total rows seen | 31,654,400 |
| Batch size | 512 |
| Learning rate | 3e-4 (cosine decay, 2K warmup) |
| AMP | fp16 |
| Hardware | NVIDIA RTX A4500 (20 GB) |
| Training time | ~3 hours |
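The learning-rate schedule above can be sketched as linear warmup followed by cosine decay. The `min_lr` floor and the decay horizon (here the final step) are assumptions; the card only states a 3e-4 peak, cosine decay, and 2K warmup steps:

```python
import math

def lr_at(step, base_lr=3e-4, warmup=2_000, total=61_825, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup:
        return base_lr * step / warmup
    t = (step - warmup) / (total - warmup)  # progress through decay, 0..1
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0))       # 0.0
print(lr_at(2_000))   # 0.0003 (warmup complete, at peak)
print(lr_at(61_825))  # 0.0 (fully decayed)
```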
Loss objective: multi-task MSE on target prediction from mixed numeric/categorical features,
normalized per-column (z-score). Each batch samples from a fixed-width (64-feature)
schema where unused slots are masked with NaN.
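The per-column z-scoring and NaN-padded fixed-width schema can be illustrated with a small sketch; `to_fixed_schema` is a hypothetical helper for illustration, not part of the released code:

```python
import numpy as np

MAX_FEATS = 64  # fixed schema width used during pretraining

def to_fixed_schema(X):
    """Z-score each column, then right-pad to 64 slots with NaN (unused)."""
    X = np.asarray(X, dtype=np.float32)
    mu = np.nanmean(X, axis=0)
    sd = np.nanstd(X, axis=0)
    Xz = (X - mu) / np.where(sd > 0, sd, 1.0)  # guard constant columns
    out = np.full((X.shape[0], MAX_FEATS), np.nan, dtype=np.float32)
    out[:, : X.shape[1]] = Xz
    return out

X = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # a 2-feature toy dataset
batch = to_fixed_schema(X)
print(batch.shape)                   # (3, 64)
print(np.isnan(batch[:, 2:]).all())  # True: slots 2..63 are unused
```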
## Pretraining Corpus
Trained on [`avewright/tabula-pretraining-corpus-v2`](https://huggingface.co/datasets/avewright/tabula-pretraining-corpus-v2):
| Source | OK Datasets | Status |
|---|---|---|
| PMLB | 422 | **Fully exhausted** (422 of 423 known datasets used; see notes) |
| OpenML | 2,949 | 4,886 attempted; 1,900 rejected (too few features), 37 download failures |
| HuggingFace | 0 | 67 attempted; all failed on format incompatibilities |
| **Synthetic** | (unlimited) | tree-prior, GMM, polynomial, SCM, regression, time-series, mixed-type |
**Total corpus:** 541 shards, ~160 GB parquet.
**Format:** `feat_0..feat_63` (Float32, NaN=unused), `target` (Float32), `_source_meta` (JSON).
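A row in this schema might be built like the sketch below; the keys inside `_source_meta` are assumptions for illustration:

```python
import json
import math

# One row in the corpus schema: 64 Float32 feature slots (NaN = unused),
# a Float32 target, and JSON provenance in `_source_meta`.
row = {f"feat_{i}": float("nan") for i in range(64)}
row.update({"feat_0": 0.52, "feat_1": -1.30})  # this source dataset had 2 features
row["target"] = 1.0
row["_source_meta"] = json.dumps({"source": "openml", "dataset": "example"})

used = [k for k in row if k.startswith("feat_") and not math.isnan(row[k])]
print(sorted(used))  # ['feat_0', 'feat_1']
```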
### Dataset Exhaustion Notes
- **PMLB: fully exhausted.** All 422 of 423 known datasets successfully processed
(1 download failure: `chess`). No new PMLB datasets can be added without an
upstream PMLB library update.
- **OpenML: largely exhausted.** 4,886 unique datasets attempted. 2,949 passed
the pipeline. The 1,900 `schema_fail` entries are almost entirely datasets with
only 1 output column and too few rows/features to be useful (e.g. `too small: (53, 1)`).
These are unrecoverable without lowering quality thresholds. There may be a small
tail of undiscovered OpenML datasets not yet paginated.
- **HuggingFace tabular:** 67 attempted from curated catalog. All failed due to
schema mismatches, missing splits, or download timeouts. Catalog needs expansion
with manually vetted datasets.
## Files
| File | Description |
|---|---|
| `best.pt` | Best validation checkpoint (step 45,000, val_loss=0.2295) |
| `latest.pt` | Final training checkpoint (step 61,825) |
| `config.json` | Model and training hyperparameters |
| `training_log.txt` | Full training run output |
## Usage
```python
import torch
from tabula.models.transformer import TabularTransformer
from tabula.config import ModelConfig
# Load checkpoint
ckpt = torch.load("best.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"].model
# Reconstruct model
model = TabularTransformer(
    d_model=cfg.d_model, n_heads=cfg.n_heads, n_layers=cfg.n_layers,
    d_ff=cfg.d_ff, dropout=cfg.dropout,
    num_numeric=64, num_categorical=0, num_text=0,
    output_dim=1,
    numeric_embedding=cfg.numeric_embedding,
    numeric_periodic_features=cfg.numeric_periodic_features,
    ffn_activation=cfg.ffn_activation, norm=cfg.norm, pooling=cfg.pooling,
)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
```
## Training Notes
The model uses a fixed-width schema (64 numeric slots) regardless of original
dataset width. Narrower datasets are padded to 64 slots, with unused slots set
to NaN. This forces the model to learn position-invariant feature representations
compatible with arbitrary tabular schemas.
Synthetic data fills gaps when the real-corpus buffer is empty, providing 100M+
rows per training session of controlled variation in feature distributions,
missingness patterns, and task types.
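One of the listed generator families (polynomial regression) could look roughly like this toy sketch. It is not the actual corpus pipeline; every parameter choice here is illustrative:

```python
import numpy as np

def synth_polynomial(n_rows, n_feats, degree=2, missing_p=0.1, seed=0):
    """Toy polynomial-regression generator: random features, a random
    low-degree polynomial target, and injected missingness (NaN)."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_rows, n_feats)).astype(np.float32)
    coefs = rng.normal(size=(degree, n_feats))
    # Target is a sum of per-degree monomial terms, computed before masking.
    y = sum((X ** (d + 1)) @ coefs[d] for d in range(degree))
    mask = rng.random(X.shape) < missing_p   # inject ~10% missing values
    X[mask] = np.nan
    return X, y.astype(np.float32)

X, y = synth_polynomial(1_000, 8)
print(X.shape, y.shape)   # (1000, 8) (1000,)
print(np.isnan(X).any())  # True: missingness was injected
```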