---
license: apache-2.0
tags:
  - tabular
  - in-context-learning
  - transformer
---

# TabDPT Checkpoints

Pre-trained [TabDPT](https://github.com/JesseCresswell/tfm-mia) model weights from three training runs with different random seeds. Each checkpoint was taken at epoch 2040 of production training.

## Files

| File | Training Seed |
|---|---|
| `production_seed42.safetensors` | 42 |
| `production_seed123.safetensors` | 123 |
| `production_seed456.safetensors` | 456 |

## Model Architecture

- **Embedding size:** 512
- **Attention heads:** 8
- **Layers:** 12
- **Hidden factor:** 2
- **Max features:** 100
- **Max classes:** 10

## Benchmark Results

### Classification: Breast Cancer (binary, 30 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 99.4% | 99.4% |
| seed123 | 98.8% | 98.8% |
| seed456 | 98.2% | 98.2% |
| HF default (Layer6/TabDPT) | — | 99.4% |

### Classification: Wine (3-class, 13 features)

| Checkpoint | Accuracy | Ensemble Accuracy |
|---|---|---|
| seed42 | 100% | 100% |
| seed123 | 100% | 100% |
| seed456 | 100% | 100% |
| HF default (Layer6/TabDPT) | — | 100% |

### Regression: Diabetes (10 features)

| Checkpoint | MSE | Correlation |
|---|---|---|
| seed42 | 2618.6 | 0.718 |
| seed123 | 2655.1 | 0.713 |
| seed456 | 2795.5 | 0.701 |
| HF default (Layer6/TabDPT) | 2673.1 | 0.711 |
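
To put the seed-to-seed spread in context, the Diabetes MSE values above can be summarized with a mean and sample standard deviation. This is a minimal sketch using only the numbers from the table; the `summarize` helper is a hypothetical name, and three seeds on one small dataset give only a rough indication of variance.

```python
from statistics import mean, stdev

# Diabetes MSE per seed, copied from the table above.
mse_by_seed = {"seed42": 2618.6, "seed123": 2655.1, "seed456": 2795.5}
hf_default_mse = 2673.1


def summarize(values):
    """Return (mean, sample standard deviation) of an iterable of floats."""
    values = list(values)
    return mean(values), stdev(values)


seed_mean, seed_std = summarize(mse_by_seed.values())
print(f"mean MSE over seeds: {seed_mean:.1f} (sample std {seed_std:.1f})")
print(f"HF default minus seed mean: {hf_default_mse - seed_mean:+.1f}")
```

The same helper applies to any other column in the tables, for example the correlation values.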

## Training Stats (from checkpoint metadata)

| Metric | seed42 | seed123 | seed456 |
|---|---|---|---|
| CC18 Accuracy | 0.877 | 0.878 | 0.879 |
| CC18 F1 | 0.870 | 0.872 | 0.873 |
| CC18 AUC | 0.927 | 0.927 | 0.928 |
| CTR Correlation | 0.830 | 0.830 | 0.827 |
| CTR R² | 0.726 | 0.730 | 0.725 |

## Format

These checkpoints were converted from PyTorch Lightning `.ckpt` files (which include optimizer state, ~295 MB each) to SafeTensors format (model weights only, ~103 MB each), the same format used by the official `Layer6/TabDPT` release. The `tabdpt` package loads SafeTensors files natively via the `model_weight_path` argument, so no further conversion is needed.
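
A conversion of this kind could look roughly like the sketch below. It is not the script used to produce these files: the `model.` key prefix and the function names are assumptions for illustration, and the actual key layout of the Lightning checkpoints may differ.

```python
def extract_model_weights(checkpoint, prefix="model."):
    """Return the weights from a Lightning-style checkpoint dict.

    Lightning stores weights under "state_dict". The `prefix` stripped from
    each key is an assumption; adjust it to match the real checkpoint.
    """
    state_dict = checkpoint["state_dict"]
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def convert(ckpt_path, out_path, prefix="model."):
    """Write only the model weights of a .ckpt file to a .safetensors file."""
    # Imported lazily so the helper above works without these packages.
    import torch
    from safetensors.torch import save_file

    # Lightning checkpoints contain non-tensor objects, hence
    # weights_only=False. Only load checkpoint files you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    weights = extract_model_weights(checkpoint, prefix)

    # save_file expects contiguous tensors; optimizer state is dropped
    # because only the "state_dict" entries are written.
    save_file({k: v.contiguous() for k, v in weights.items()}, out_path)
```

Calling `convert("run.ckpt", "run.safetensors")` (hypothetical file names) would then produce a weights-only file suitable for `model_weight_path`, provided the stripped keys match what `tabdpt` expects.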

## Usage

```python
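from huggingface_hub import hf_hub_download
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabdpt import TabDPTClassifier

# Illustrative data and split (not necessarily the benchmark setup above)
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Download once (cached afterwards)
path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")

# Use exactly like the default model
clf = TabDPTClassifier(model_weight_path=path)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)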
```

The checkpoints load identically with `TabDPTRegressor`.