---
language:
- en
license: mit
tags:
- table-structure-recognition
- computer-vision
- pytorch
- document-ai
- table-detection
library_name: pytorch
pipeline_tag: image-classification
datasets:
- ds4sd/FinTabNet_OTSL
---

# TABLET Split Model - Table Structure Recognition

This repository contains the **Split Model** implementation from the paper [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1), trained for detecting row and column splits in table images.

## Model Description

The Split Model is a deep learning architecture designed to detect horizontal and vertical splits in table images, enabling accurate table structure recognition. The model processes a table image and predicts the positions of row and column boundaries.

### Architecture

The model consists of three main components:

1. **Modified ResNet-18 Backbone**
   - Removed max pooling layer for better spatial resolution
   - Channel dimensions halved for efficiency (32→256 instead of ResNet-18's standard 64→512)
   - Outputs features at 1/16 resolution (60×60 for 960×960 input)

2. **Feature Pyramid Network (FPN)**
   - Upsamples backbone features to 1/2 resolution (480×480)
   - Reduces channels to 128 dimensions

3. **Dual Transformer Branches**
   - **Horizontal Branch**: Detects row splits using 1D transformer
   - **Vertical Branch**: Detects column splits using 1D transformer
   - Each branch combines:
     - Global features: Learnable weighted averaging
     - Local features: Spatial pooling with 1×1 convolution
     - Positional embeddings: 1D learned embeddings
   - 3-layer transformer encoder with 8 attention heads
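
The per-branch design above can be sketched roughly as follows. This is a minimal illustration, not the code in `split_model.py`: the class name, the exact pooling, and the way local and global features are combined are assumptions, while the 480-length sequence, 128 channels, 3 encoder layers, and 8 heads come from the description above.

```python
import torch
import torch.nn as nn

class SplitBranch(nn.Module):
    """Sketch of one 1D split-detection branch (e.g. the horizontal/row branch).

    Combines local features (1x1 conv, pooled along the other axis), global
    features (learnable weighted averaging), and learned 1D positional
    embeddings, then scores each position with a transformer encoder.
    """

    def __init__(self, channels=128, length=480, heads=8, layers=3):
        super().__init__()
        self.local_proj = nn.Conv2d(channels, channels, kernel_size=1)
        self.global_weights = nn.Parameter(torch.zeros(1, 1, 1, length))
        self.pos_embed = nn.Parameter(torch.zeros(1, length, channels))
        enc_layer = nn.TransformerEncoderLayer(
            d_model=channels, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(channels, 1)

    def forward(self, feats):  # feats: [B, C, 480, 480] from the FPN
        # Local features: 1x1 conv, then average-pool across the width axis
        local = self.local_proj(feats).mean(dim=3)            # [B, C, 480]
        # Global features: learnable weighted average across the width axis
        w = torch.softmax(self.global_weights, dim=-1)        # [1, 1, 1, 480]
        global_ = (feats * w).sum(dim=3)                      # [B, C, 480]
        seq = (local + global_).transpose(1, 2) + self.pos_embed  # [B, 480, C]
        return torch.sigmoid(self.head(self.encoder(seq))).squeeze(-1)  # [B, 480]
```

The vertical branch would do the same with the axes swapped, pooling across the height axis to score the 480 column positions.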

### Training Details

- **Dataset**: Combination of FinTabNet and PubTabNet (OTSL format)
- **Input Size**: 960×960 pixels
- **Batch Size**: 32
- **Epochs**: 16
- **Optimizer**: AdamW (lr=3e-4, weight_decay=5e-4)
- **Loss Function**: Focal Loss (α=1.0, γ=2.0)
- **Ground Truth**: Dynamic gap-based split detection from OTSL annotations
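
The focal loss (α=1.0, γ=2.0) down-weights easy pixels so that the rare split pixels dominate the gradient. A minimal binary version consistent with those hyperparameters (a sketch, not the repository's training code) looks like:

```python
import torch

def binary_focal_loss(pred, target, alpha=1.0, gamma=2.0, eps=1e-6):
    """Binary focal loss on per-pixel split probabilities.

    pred:   predicted split probabilities in [0, 1]
    target: ground-truth split mask (0 or 1), same shape as pred
    The (1 - p_t)**gamma factor shrinks the loss on confidently correct
    pixels, countering the split / non-split class imbalance.
    """
    pred = pred.clamp(eps, 1 - eps)
    # p_t is the predicted probability of the true class at each pixel
    p_t = torch.where(target > 0.5, pred, 1 - pred)
    return (-alpha * (1 - p_t) ** gamma * p_t.log()).mean()
```

With α=1.0 this reduces to the standard focal loss term −(1−p_t)^γ·log(p_t); γ=2.0 is the value commonly used in the focal loss literature.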

## Installation

```bash
pip install torch torchvision pillow numpy
```

## Usage

### Basic Inference

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from split_model import SplitModel

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SplitModel().to(device)

# Load checkpoint
checkpoint = torch.load('split_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Prepare image
transform = transforms.Compose([
    transforms.Resize((960, 960)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open('table_image.png').convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)

# Predict
with torch.no_grad():
    h_pred, v_pred = model(image_tensor)  # Returns [1, 480] predictions

    # Upsample to 960 for visualization
    h_pred = h_pred.repeat_interleave(2, dim=1)  # [1, 960]
    v_pred = v_pred.repeat_interleave(2, dim=1)  # [1, 960]

    # Apply threshold
    h_splits = (h_pred > 0.5).float()
    v_splits = (v_pred > 0.5).float()

    # Count split regions: consecutive positive pixels belong to one
    # separator, so count rising (0 -> 1) edges rather than raw pixels
    def count_splits(mask):
        m = mask.squeeze(0)
        return int((m[1:] > m[:-1]).sum().item() + m[0].item())

    num_rows = count_splits(h_splits) + 1
    num_cols = count_splits(v_splits) + 1

    print(f"Detected {num_rows} rows and {num_cols} columns")
```
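
A thresholded prediction is a per-pixel mask, so a multi-pixel-wide separator shows up as a run of consecutive positives. If you need boundary coordinates (e.g. to crop individual cells), one option is to take the center of each run. This helper is an illustrative sketch, not part of the repository:

```python
import numpy as np

def split_centers(mask):
    """Return the center coordinate of each contiguous run of 1s.

    mask: 1D binary array, e.g. the 960-long thresholded h_splits vector.
    """
    mask = np.asarray(mask).astype(int)
    # Pad with zeros so runs touching either edge still produce both edges
    diff = np.diff(np.concatenate(([0], mask, [0])))
    starts = np.flatnonzero(diff == 1)    # first index of each run
    ends = np.flatnonzero(diff == -1)     # one past the last index of each run
    return [(s + e - 1) // 2 for s, e in zip(starts, ends)]
```

The number of rows is then `len(split_centers(h_mask)) + 1`, and the centers give y-coordinates for slicing the image into row bands.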

### Visualize Predictions

Use the included visualization script to test on your images:

```bash
python test_split_by_images_folder.py \
    --image-folder /path/to/images \
    --output-folder predictions_output \
    --model-path split_model.pth \
    --threshold 0.5
```

## Model Performance

The model was trained on combined FinTabNet and PubTabNet datasets:
- Training samples: ~250K table images
- Validation F1 scores typically exceed 0.90 for both horizontal and vertical splits
- Robust to various table styles, merged cells, and complex layouts

## Files in this Repository

- `split_model.py` - Model architecture and dataset classes
- `train_split_fixed.py` - Training script
- `test_split_by_images_folder.py` - Inference and visualization script
- `split_model.pth` - Trained model weights

## Key Features

- **Dynamic Gap Detection**: Automatically handles varying gap widths between cells
- **Overlap Handling**: Correctly processes tables with overlapping cell boundaries
- **Focal Loss Training**: Addresses class imbalance between split and non-split pixels
- **Transformer-based**: Captures long-range dependencies for complex table structures
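
To illustrate the gap-based ground truth idea (a simplified sketch; the actual logic lives in the training code and additionally handles overlapping cell boundaries), row-split labels can be derived by marking the horizontal pixel bands between cell bounding boxes:

```python
import numpy as np

def row_split_mask(cell_boxes, height):
    """Mark pixels in the gaps between cell rows as split ground truth.

    cell_boxes: list of (x0, y0, x1, y1) cell bounding boxes in pixels.
    height: image height; returns a 1D 0/1 mask of that length.
    Gap width adapts automatically to the table's actual spacing.
    """
    covered = np.zeros(height, dtype=bool)
    for _, y0, _, y1 in cell_boxes:
        covered[int(y0):int(y1)] = True
    mask = (~covered).astype(np.float32)
    # Ignore the margins above the first and below the last cell
    ys = np.flatnonzero(covered)
    if len(ys):
        mask[:ys[0]] = 0.0
        mask[ys[-1] + 1:] = 0.0
    return mask
```

Column-split labels follow the same pattern using the x-coordinates.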

## Citation

If you use this model, please cite the original TABLET paper:

```bibtex
@article{tablet2025,
  title={TABLET: Learning From Instructions For Tabular Data},
  author={[Authors from paper]},
  journal={arXiv preprint arXiv:2506.07015},
  year={2025}
}
```

## Paper Reference

This implementation is based on the Split Model described in Section 3.2 of:
[TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1)

## License

This model is released under the MIT license for research purposes. Please refer to the original paper for more details.

## Acknowledgments

- Original paper authors for the TABLET framework
- FinTabNet and PubTabNet datasets for training data
- PyTorch team for the deep learning framework