---
language:
- en
license: mit
tags:
- table-structure-recognition
- computer-vision
- pytorch
- document-ai
- table-detection
library_name: pytorch
pipeline_tag: image-classification
datasets:
- ds4sd/FinTabNet_OTSL
---
# TABLET Split Model - Table Structure Recognition
This repository contains the **Split Model** implementation from the paper [TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1), trained to detect row and column splits in table images.
## Model Description
The Split Model is a deep learning architecture that detects horizontal and vertical splits in table images, a key step in table structure recognition. Given a table image, it predicts the positions of row and column boundaries.
### Architecture
The model consists of three main components:
1. **Modified ResNet-18 Backbone**
   - Max pooling layer removed for better spatial resolution
   - Channel dimensions halved for efficiency (32→256 instead of the standard 64→512)
   - Outputs features at 1/16 resolution (60×60 for a 960×960 input)
2. **Feature Pyramid Network (FPN)**
   - Upsamples backbone features to 1/2 resolution (480×480)
   - Reduces channels to 128 dimensions
3. **Dual Transformer Branches** (sketched below)
   - **Horizontal Branch**: detects row splits with a 1D transformer
   - **Vertical Branch**: detects column splits with a 1D transformer
   - Each branch combines:
     - Global features: learnable weighted averaging
     - Local features: spatial pooling with a 1×1 convolution
     - Positional embeddings: 1D learned embeddings
   - 3-layer transformer encoder with 8 attention heads
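To make the branch design concrete, here is a minimal, hypothetical sketch of how one branch (the horizontal one) could combine these pieces. The class name `SplitBranch` and all wiring details are assumptions for illustration; the actual module is defined in `split_model.py`:

```python
import torch
import torch.nn as nn

class SplitBranch(nn.Module):
    """Hypothetical sketch of one 1-D branch. Input: FPN features [B, 128, 480, 480]."""

    def __init__(self, dim=128, length=480, heads=8, layers=3):
        super().__init__()
        self.global_weights = nn.Parameter(torch.zeros(length))  # learnable averaging weights
        self.local_proj = nn.Conv1d(dim, dim, kernel_size=1)     # 1x1 conv on pooled features
        self.pos_embed = nn.Parameter(torch.zeros(length, dim))  # 1-D learned positions
        enc = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=layers)
        self.head = nn.Linear(dim, 1)

    def forward(self, feats):                      # feats: [B, 128, 480, 480]
        w = torch.softmax(self.global_weights, 0)  # normalized weights over the pooled axis
        global_feat = (feats * w.view(1, 1, 1, -1)).sum(-1)  # weighted average -> [B, 128, 480]
        local_feat = self.local_proj(feats.mean(-1))         # pooled + 1x1 conv -> [B, 128, 480]
        x = (global_feat + local_feat).transpose(1, 2) + self.pos_embed  # [B, 480, 128]
        x = self.encoder(x)
        return torch.sigmoid(self.head(x)).squeeze(-1)       # [B, 480] split probabilities
```

The vertical branch would be identical except that it pools and predicts along the other spatial axis.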
### Training Details
- **Dataset**: Combination of FinTabNet and PubTabNet (OTSL format)
- **Input Size**: 960×960 pixels
- **Batch Size**: 32
- **Epochs**: 16
- **Optimizer**: AdamW (lr=3e-4, weight_decay=5e-4)
- **Loss Function**: Focal Loss (α=1.0, γ=2.0); see the sketch after this list
- **Ground Truth**: Dynamic gap-based split detection from OTSL annotations
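For reference, here is a minimal sketch of binary focal loss with the stated α and γ. The version actually used for training lives in `train_split_fixed.py`, so treat this only as an illustration of the formula:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=1.0, gamma=2.0):
    """Binary focal loss: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t).

    Down-weights well-classified pixels so the rare split pixels dominate
    the gradient, addressing the split / non-split class imbalance.
    """
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true class
    return (alpha * (1.0 - p_t) ** gamma * ce).mean()
```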
## Installation
```bash
pip install torch torchvision pillow numpy
```
## Usage
### Basic Inference
```python
import torch
from PIL import Image
import torchvision.transforms as transforms

from split_model import SplitModel

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SplitModel().to(device)

# Load checkpoint
checkpoint = torch.load('split_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Prepare image (ImageNet normalization, matching training)
transform = transforms.Compose([
    transforms.Resize((960, 960)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open('table_image.png').convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)

# Predict
with torch.no_grad():
    h_pred, v_pred = model(image_tensor)  # each of shape [1, 480]

# Upsample to 960 for visualization
h_pred = h_pred.repeat_interleave(2, dim=1)  # [1, 960]
v_pred = v_pred.repeat_interleave(2, dim=1)  # [1, 960]

# Apply threshold
h_splits = (h_pred > 0.5).float()
v_splits = (v_pred > 0.5).float()

# Count separators: each contiguous band of above-threshold pixels is one
# split, so count rising edges instead of summing individual pixels
def count_bands(mask):
    m = mask.squeeze(0)
    return int(((m[1:] > m[:-1]).sum() + m[0]).item())

num_rows = count_bands(h_splits) + 1
num_cols = count_bands(v_splits) + 1
print(f"Detected {num_rows} rows and {num_cols} columns")
```
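If you need the split positions themselves (e.g. to crop individual cells), one simple option is to take the center of each detected band and scale it back to the source image. The helper `band_centers` below is a hypothetical sketch, not part of this repository:

```python
import torch

def band_centers(mask, original_extent, model_extent=960):
    """Center of each split band, scaled to the original image size.

    mask: [1, model_extent] binary tensor from the thresholding step above.
    original_extent: original image height (for rows) or width (for columns).
    """
    idx = torch.nonzero(mask.squeeze(0), as_tuple=False).squeeze(1)
    if idx.numel() == 0:
        return []
    # break the indices into contiguous runs and take each run's midpoint
    breaks = (torch.nonzero(idx[1:] - idx[:-1] > 1, as_tuple=False).squeeze(1) + 1).tolist()
    runs = torch.tensor_split(idx, breaks)
    return [float(r.float().mean()) * original_extent / model_extent for r in runs]

row_lines = band_centers(h_splits, image.height)  # y-coordinates of row separators
col_lines = band_centers(v_splits, image.width)   # x-coordinates of column separators
```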
### Visualize Predictions
Use the included visualization script to test on your images:
```bash
python test_split_by_images_folder.py \
--image-folder /path/to/images \
--output-folder predictions_output \
--model-path split_model.pth \
--threshold 0.5
```
## Model Performance
The model was trained on combined FinTabNet and PubTabNet datasets:
- Training samples: ~250K table images
- Validation F1 scores typically exceed 0.90 for both horizontal and vertical splits
- Robust to various table styles, merged cells, and complex layouts
## Files in this Repository
- `split_model.py` - Model architecture and dataset classes
- `train_split_fixed.py` - Training script
- `test_split_by_images_folder.py` - Inference and visualization script
- `split_model.pth` - Trained model weights
## Key Features
- **Dynamic Gap Detection**: Automatically handles varying gap widths between cells (a sketch of the idea follows this list)
- **Overlap Handling**: Correctly processes tables with overlapping cell boundaries
- **Focal Loss Training**: Addresses class imbalance between split and non-split pixels
- **Transformer-based**: Captures long-range dependencies for complex table structures
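The gap-based labelling idea can be illustrated as follows. This is a hypothetical sketch under the assumption that split pixels are those not covered by any cell box along an axis; the actual ground-truth generation lives in `split_model.py`:

```python
import numpy as np

def gap_ground_truth(cell_intervals, extent=960):
    """Label pixels lying in gaps between cells along one axis as splits.

    cell_intervals: (start, end) pairs, e.g. the (y0, y1) of every cell,
    already scaled to the model's input resolution.
    """
    covered = np.zeros(extent, dtype=bool)
    for start, end in cell_intervals:
        covered[int(start):int(end)] = True
    gt = (~covered).astype(np.float32)  # 1 where no cell covers the pixel
    # zero out the margins so only inter-cell gaps count as splits
    inside = np.flatnonzero(covered)
    if inside.size:
        gt[:inside[0]] = 0.0
        gt[inside[-1] + 1:] = 0.0
    return gt
```

Under this formulation, overlapping cell boxes simply merge into one covered region, which is one plausible way the overlap handling above could work.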
## Citation
If you use this model, please cite the original TABLET paper:
```bibtex
@article{tablet2025,
  title={TABLET: Learning From Instructions For Tabular Data},
  author={[Authors from paper]},
  journal={arXiv preprint arXiv:2506.07015},
  year={2025}
}
```
## Paper Reference
This implementation is based on the Split Model described in Section 3.2 of:
[TABLET: Learning From Instructions For Tabular Data](https://arxiv.org/pdf/2506.07015v1)
## License
This model is released under the MIT license for research purposes. Please refer to the original paper for more details.
## Acknowledgments
- Original paper authors for the TABLET framework
- FinTabNet and PubTabNet datasets for training data
- PyTorch team for the deep learning framework