---
license: apache-2.0
tags:
- tabular
- in-context-learning
- transformer
---
# TabDPT Checkpoints
Pre-trained [TabDPT](https://github.com/JesseCresswell/tfm-mia) model weights from three training runs, each with a different random seed. Every checkpoint was taken at epoch 2040 of production training.
## Files
| File | Training Seed |
|---|---|
| `production_seed42.safetensors` | 42 |
| `production_seed123.safetensors` | 123 |
| `production_seed456.safetensors` | 456 |
## Model Architecture
- **Embedding size:** 512
- **Attention heads:** 8
- **Layers:** 12
- **Hidden factor:** 2
- **Max features:** 100
- **Max classes:** 10
## Benchmark Results
### Classification: Breast Cancer (binary, 30 features)
| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 99.4% | 99.4% |
| seed123 | 98.8% | 98.8% |
| seed456 | 98.2% | 98.2% |
| HF default (Layer6/TabDPT) | — | 99.4% |
### Classification: Wine (3-class, 13 features)
| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 100% | 100% |
| seed123 | 100% | 100% |
| seed456 | 100% | 100% |
| HF default (Layer6/TabDPT) | — | 100% |
### Regression: Diabetes (10 features)
| Checkpoint | MSE | Correlation |
|---|---|---|
| seed42 | 2618.6 | 0.718 |
| seed123 | 2655.1 | 0.713 |
| seed456 | 2795.5 | 0.701 |
| HF default (Layer6/TabDPT) | 2673.1 | 0.711 |
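
For readers who want to reproduce the two regression columns on their own split, the metrics can be computed from held-out targets and predictions roughly as follows. This is a minimal sketch, not the evaluation script behind the numbers above: it assumes that "Correlation" means the Pearson correlation between predictions and targets, and the helper names `mse` and `pearson` are hypothetical.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def pearson(y_true, y_pred):
    """Pearson correlation (assumed meaning of the "Correlation" column)."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov / math.sqrt(var_t * var_p)


# Toy values for illustration only, not taken from the Diabetes benchmark
y_true = [100.0, 150.0, 200.0]
y_pred = [110.0, 140.0, 210.0]
print(mse(y_true, y_pred), pearson(y_true, y_pred))
```

With a fitted regressor, `y_pred` would be the output of `predict` on the held-out rows.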
## Training Stats (from checkpoint metadata)
| Metric | seed42 | seed123 | seed456 |
|---|---|---|---|
| CC18 Accuracy | 0.877 | 0.878 | 0.879 |
| CC18 F1 | 0.870 | 0.872 | 0.873 |
| CC18 AUC | 0.927 | 0.927 | 0.928 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
| CTR R² | 0.726 | 0.730 | 0.725 |
## Format
These checkpoints were converted from PyTorch Lightning `.ckpt` files (which include optimizer state, ~295 MB each) to SafeTensors format (model weights only, ~103 MB each). This is the same format used by the official `Layer6/TabDPT` release. The `tabdpt` package loads SafeTensors files natively through the `model_weight_path` argument, so no further conversion is needed.
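
To check what a converted file contains without loading any weights, the SafeTensors header can be read with the standard library alone. The sketch below relies on the published SafeTensors layout (an 8-byte little-endian header length followed by a JSON header). The function names `read_safetensors_header` and `summarize` are hypothetical helpers, not part of `tabdpt` or `safetensors`.

```python
import json
import struct


def read_safetensors_header(path):
    """Return the JSON header of a SafeTensors file without loading any weights."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len).decode("utf-8"))


def summarize(header):
    """Map each tensor name to (dtype, shape), skipping the optional metadata entry."""
    return {
        name: (info["dtype"], info["shape"])
        for name, info in header.items()
        if name != "__metadata__"
    }
```

Calling `summarize(read_safetensors_header(path))` on a downloaded checkpoint should list model tensors only, with no optimizer state among the entries.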
## Usage
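```python
from huggingface_hub import hf_hub_download
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabdpt import TabDPTClassifier

# Example data; any tabular classification dataset works here
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Download once (cached afterwards)
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")

# Use exactly like the default model
clf = TabDPTClassifier(model_weight_path=path)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
```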
Works identically with `TabDPTRegressor`.
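
As an illustration of the regressor path, the sketch below picks one of the three published seeds and builds a `TabDPTRegressor` from it. It assumes only what is stated above, namely that `TabDPTRegressor` accepts `model_weight_path`. The helpers `checkpoint_filename` and `load_regressor` are hypothetical and not part of `tabdpt`.

```python
def checkpoint_filename(seed):
    """Build the checkpoint file name for one of the published seeds."""
    if seed not in (42, 123, 456):
        raise ValueError(f"no checkpoint published for seed {seed}")
    return f"production_seed{seed}.safetensors"


def load_regressor(seed=42):
    """Download the chosen checkpoint and wrap it in a TabDPTRegressor."""
    # Imported lazily so checkpoint_filename() works without these packages installed
    from huggingface_hub import hf_hub_download
    from tabdpt import TabDPTRegressor

    path = hf_hub_download("dwahdany/TabDPT", checkpoint_filename(seed))
    return TabDPTRegressor(model_weight_path=path)
```

From there, `reg = load_regressor(123)` followed by `reg.fit(X_train, y_train)` and `reg.predict(X_test)` mirrors the classifier example.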