---
license: apache-2.0
tags:
- tabular
- in-context-learning
- transformer
---

# TabDPT Checkpoints

Pre-trained [TabDPT](https://github.com/JesseCresswell/tfm-mia) model weights from three training runs that differ only in their random seed. Each checkpoint is taken from epoch 2040 of production training.

## Files

| File | Training Seed |
|---|---|
| `production_seed42.safetensors` | 42 |
| `production_seed123.safetensors` | 123 |
| `production_seed456.safetensors` | 456 |

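The file names follow a single pattern, so a checkpoint can be selected by seed. The sketch below assumes the `dwahdany/TabDPT` repository ID that appears in the Usage section; `checkpoint_filename` and `download_checkpoint` are hypothetical helper names, not part of the `tabdpt` package.

```python
SEEDS = (42, 123, 456)       # training seeds listed in the Files table
REPO_ID = "dwahdany/TabDPT"  # repository ID taken from the Usage section


def checkpoint_filename(seed: int) -> str:
    """Return the checkpoint file name for one of the published seeds."""
    if seed not in SEEDS:
        raise ValueError(f"no checkpoint for seed {seed}; choose from {SEEDS}")
    return f"production_seed{seed}.safetensors"


def download_checkpoint(seed: int) -> str:
    """Download one checkpoint (cached after the first call) and return its local path."""
    # Imported here so the name helper above works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(REPO_ID, checkpoint_filename(seed))
```

Calling `download_checkpoint(123)` would then return a local path that can be passed as `model_weight_path`.
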
## Model Architecture

- **Embedding size:** 512
- **Attention heads:** 8
- **Layers:** 12
- **Hidden factor:** 2
- **Max features:** 100
- **Max classes:** 10

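For reference, these hyperparameters can be collected in a small config object. This is an illustrative sketch only: `TabDPTConfig` and its field names are hypothetical and do not mirror the `tabdpt` package's own configuration. The derived values assume common transformer conventions (the embedding is split evenly across heads, and the feed-forward width is the hidden factor times the embedding size), which this card does not confirm.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TabDPTConfig:
    """Hyperparameters from the Model Architecture list (field names are hypothetical)."""

    emb_size: int = 512
    num_heads: int = 8
    num_layers: int = 12
    hidden_factor: int = 2
    max_features: int = 100
    max_classes: int = 10

    @property
    def head_dim(self) -> int:
        # Assumes the embedding is split evenly across attention heads.
        return self.emb_size // self.num_heads

    @property
    def ffn_hidden_size(self) -> int:
        # Assumes "hidden factor" scales the feed-forward layer width.
        return self.hidden_factor * self.emb_size


cfg = TabDPTConfig()
print(cfg.head_dim, cfg.ffn_hidden_size)  # prints: 64 1024
```
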
## Benchmark Results

### Classification: Breast Cancer (binary, 30 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 99.4% | 99.4% |
| seed123 | 98.8% | 98.8% |
| seed456 | 98.2% | 98.2% |
| HF default (Layer6/TabDPT) | n/a | 99.4% |

### Classification: Wine (3-class, 13 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 100% | 100% |
| seed123 | 100% | 100% |
| seed456 | 100% | 100% |
| HF default (Layer6/TabDPT) | n/a | 100% |

### Regression: Diabetes (10 features)

| Checkpoint | MSE | Correlation |
|---|---|---|
| seed42 | 2618.6 | 0.718 |
| seed123 | 2655.1 | 0.713 |
| seed456 | 2795.5 | 0.701 |
| HF default (Layer6/TabDPT) | 2673.1 | 0.711 |

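The evaluation protocol (train/test split, preprocessing) is not stated in this card. As a rough guide to what the benchmark columns measure, the sketch below computes accuracy, MSE, and correlation with NumPy. It assumes "Correlation" means Pearson's r; the function names are illustrative and not taken from the benchmark code.

```python
import numpy as np


def accuracy(y_true, y_pred) -> float:
    """Fraction of predictions that match the labels."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


def mse(y_true, y_pred) -> float:
    """Mean squared error between targets and predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


def pearson_r(y_true, y_pred) -> float:
    """Pearson correlation coefficient (assumed meaning of "Correlation")."""
    return float(np.corrcoef(y_true, y_pred)[0, 1])
```

Given predictions from a fitted model, `accuracy(y_test, preds)` applies to the classification tables, and `mse` and `pearson_r` apply to the regression table.
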
## Training Stats (from checkpoint metadata)

| Metric | seed42 | seed123 | seed456 |
|---|---|---|---|
| CC18 Accuracy | 0.877 | 0.878 | 0.879 |
| CC18 F1 | 0.870 | 0.872 | 0.873 |
| CC18 AUC | 0.927 | 0.927 | 0.928 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
| CTR R² | 0.726 | 0.730 | 0.725 |

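To quantify how much the seeds differ, the training stats can be summarized as a mean and sample standard deviation per metric. The snippet below is a small sketch using only the standard library; the values are copied from the Training Stats table, and `summarize` is a hypothetical helper name.

```python
from statistics import mean, stdev

# Values copied from the Training Stats table, ordered seed42, seed123, seed456.
training_stats = {
    "CC18 Accuracy": (0.877, 0.878, 0.879),
    "CC18 F1": (0.870, 0.872, 0.873),
    "CC18 AUC": (0.927, 0.927, 0.928),
    "CTR Correlation": (0.830, 0.830, 0.827),
    "CTR R2": (0.726, 0.730, 0.725),
}


def summarize(values):
    """Return (mean, sample standard deviation), each rounded to four decimals."""
    return round(mean(values), 4), round(stdev(values), 4)


for metric, values in training_stats.items():
    avg, spread = summarize(values)
    print(f"{metric}: mean={avg:.4f}, std={spread:.4f}")
```
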
## Format

These checkpoints were converted from PyTorch Lightning `.ckpt` files (about 295 MB each, including optimizer state) to SafeTensors files (about 103 MB each, model weights only). This is the same format used by the official `Layer6/TabDPT` release. The `tabdpt` package loads SafeTensors files natively through the `model_weight_path` argument, so no further conversion is needed.

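A conversion of this kind can be sketched as follows. This is not the script used to produce these files: the `"model."` key prefix is an assumption about how the Lightning module names its network, and the helper names are hypothetical. The sketch also writes no configuration metadata, which a loader may expect to find in the file, so its output is not guaranteed to be loadable by `tabdpt`.

```python
def extract_model_weights(state_dict: dict, prefix: str = "model.") -> dict:
    """Keep entries whose key starts with `prefix` and strip the prefix.

    The "model." default is an assumption; adjust it to match the checkpoint.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def convert_checkpoint(ckpt_path: str, out_path: str, prefix: str = "model.") -> None:
    """Write the model weights of a Lightning checkpoint to a SafeTensors file."""
    # Heavy dependencies are imported here so the helper above works without them.
    import torch
    from safetensors.torch import save_file

    # A Lightning checkpoint is a pickled dict: "state_dict" holds the weights,
    # while optimizer state is stored under separate keys and is dropped here.
    # weights_only=False unpickles arbitrary objects, so use trusted files only.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    weights = extract_model_weights(ckpt["state_dict"], prefix)

    # save_file requires contiguous tensors.
    save_file({k: v.contiguous() for k, v in weights.items()}, out_path)
```

Dropping the optimizer state is what accounts for the size difference between the two formats described above.
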
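## Usage

```python
from huggingface_hub import hf_hub_download
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabdpt import TabDPTClassifier

# Example data; this split is illustrative, and any tabular X, y can be used instead
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Download once (cached afterwards)
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")

# Use exactly like the default model
clf = TabDPTClassifier(model_weight_path=path)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
```

`TabDPTRegressor` works identically and takes the same `model_weight_path` argument.
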
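For regression, a minimal sketch along the same lines is shown below. It assumes `TabDPTRegressor` accepts `model_weight_path` exactly as the classifier does, as stated above. The Diabetes split here is illustrative and is not the protocol behind the benchmark table; `predict_diabetes` is a hypothetical helper name.

```python
def predict_diabetes(weight_path: str, test_size: float = 0.3, seed: int = 0):
    """Fit TabDPTRegressor on a Diabetes train split and predict the held-out rows."""
    # Imported inside the function so that defining it does not require the packages.
    from sklearn.datasets import load_diabetes
    from sklearn.model_selection import train_test_split
    from tabdpt import TabDPTRegressor

    X, y = load_diabetes(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed
    )

    # Assumed to mirror the classifier: same constructor argument, same fit/predict calls.
    reg = TabDPTRegressor(model_weight_path=weight_path)
    reg.fit(X_train, y_train)
    return reg.predict(X_test), y_test
```

The returned predictions and targets can then be compared with any regression metric, for example mean squared error.
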