Upload 20 files
Browse files- .gitattributes +36 -35
- README.md +217 -0
- data/dataset.py +202 -0
- data/transform.py +334 -0
- image/architecture.png +3 -0
- model/htr_convtext.py +446 -0
- model/layer.py +75 -0
- model/resnet18.py +411 -0
- model/tcm_head.py +133 -0
- requirements.txt +8 -0
- run/iam.sh +1 -0
- run/lam.sh +1 -0
- run/read2016.sh +1 -0
- run/vnondb.sh +1 -0
- test.py +140 -0
- train.py +441 -0
- utils/option.py +235 -0
- utils/sam.py +63 -0
- utils/utils.py +276 -0
- valid.py +77 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,36 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
image/architecture.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,220 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- vi
|
| 6 |
+
pipeline_tag: image-to-text
|
| 7 |
+
model-index:
|
| 8 |
+
- name: HTR-ConvText
|
| 9 |
+
results:
|
| 10 |
+
- task:
|
| 11 |
+
type: image-to-text
|
| 12 |
+
name: Handwritten Text Recognition
|
| 13 |
+
dataset:
|
| 14 |
+
name: IAM
|
| 15 |
+
type: iam
|
| 16 |
+
split: test
|
| 17 |
+
metrics:
|
| 18 |
+
- type: cer
|
| 19 |
+
value: 4.0
|
| 20 |
+
name: Test CER
|
| 21 |
+
- type: wer
|
| 22 |
+
value: 12.9
|
| 23 |
+
name: Test WER
|
| 24 |
+
- task:
|
| 25 |
+
type: image-to-text
|
| 26 |
+
name: Handwritten Text Recognition
|
| 27 |
+
dataset:
|
| 28 |
+
name: LAM
|
| 29 |
+
type: lam
|
| 30 |
+
split: test
|
| 31 |
+
metrics:
|
| 32 |
+
- type: cer
|
| 33 |
+
value: 2.7
|
| 34 |
+
name: Test CER
|
| 35 |
+
- type: wer
|
| 36 |
+
value: 7.0
|
| 37 |
+
name: Test WER
|
| 38 |
+
- task:
|
| 39 |
+
type: image-to-text
|
| 40 |
+
name: Handwritten Text Recognition
|
| 41 |
+
dataset:
|
| 42 |
+
name: READ2016
|
| 43 |
+
type: read2016
|
| 44 |
+
split: test
|
| 45 |
+
metrics:
|
| 46 |
+
- type: cer
|
| 47 |
+
value: 3.6
|
| 48 |
+
name: Test CER
|
| 49 |
+
- type: wer
|
| 50 |
+
value: 15.7
|
| 51 |
+
name: Test WER
|
| 52 |
+
- task:
|
| 53 |
+
type: image-to-text
|
| 54 |
+
name: Handwritten Text Recognition
|
| 55 |
+
dataset:
|
| 56 |
+
name: HANDS-VNOnDB
|
| 57 |
+
type: hands-vnondb
|
| 58 |
+
split: test
|
| 59 |
+
metrics:
|
| 60 |
+
- type: cer
|
| 61 |
+
value: 3.45
|
| 62 |
+
name: Test CER
|
| 63 |
+
- type: wer
|
| 64 |
+
value: 8.9
|
| 65 |
+
name: Test WER
|
| 66 |
---
|
| 67 |
+
---
|
| 68 |
+
# HTR-ConvText: Leveraging Convolution and Textual Information for Handwritten Text Recognition
|
| 69 |
+
|
| 70 |
+
<div align="center"> <img src="image/architecture.png" alt="HTR-ConvText Architecture" width="800"/> </div>
|
| 71 |
+
|
| 72 |
+
<p align="center">
|
| 73 |
+
<a href="https://huggingface.co/DAIR-Group/HTR-ConvText">
|
| 74 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue">
|
| 75 |
+
</a>
|
| 76 |
+
<a href="https://github.com/DAIR-Group/HTR-ConvText">
|
| 77 |
+
<img alt="GitHub" src="https://img.shields.io/badge/GitHub-Repo-181717.svg?logo=github&logoColor=white">
|
| 78 |
+
</a>
|
| 79 |
+
<a href="https://github.com/DAIR-Group/HTR-ConvText/blob/main/LICENSE">
|
| 80 |
+
<img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green">
|
| 81 |
+
</a>
|
| 82 |
+
<a href="https://arxiv.org/abs/2512.05021">
|
| 83 |
+
<img alt="arXiv" src="https://img.shields.io/badge/arXiv-2512.05021-b31b1b.svg">
|
| 84 |
+
</a>
|
| 85 |
+
</p>
|
| 86 |
+
|
| 87 |
+
## Highlights
|
| 88 |
+
|
| 89 |
+
HTR-ConvText is a novel hybrid architecture for Handwritten Text Recognition (HTR) that effectively balances local feature extraction with global contextual modeling. Designed to overcome the limitations of standard CTC-based decoding and data-hungry Transformers, HTR-ConvText delivers state-of-the-art performance with the following key features:
|
| 90 |
+
|
| 91 |
+
- **Hybrid CNN-ViT Architecture**: Seamlessly integrates a ResNet backbone with MobileViT blocks (MVP) and Conditional Positional Encoding, enabling the model to capture fine-grained stroke details while maintaining global spatial awareness.
|
| 92 |
+
- **Hierarchical ConvText Encoder**: A U-Net-like encoder structure that interleaves Multi-Head Self-Attention with Depthwise Convolutions. This design efficiently models both long-range dependencies and local structural patterns.
|
| 93 |
+
- **Textual Context Module (TCM)**: An innovative training-only auxiliary module that injects bidirectional linguistic priors into the visual encoder. This mitigates the conditional independence weakness of CTC decoding without adding any latency during inference.
|
| 94 |
+
- **State-of-the-Art Performance**: Outperforms existing methods on major benchmarks including IAM (English), READ2016 (German), LAM (Italian), and HANDS-VNOnDB (Vietnamese), specifically excelling in low-resource scenarios and complex diacritics.
|
| 95 |
+
|
| 96 |
+
## Model Overview
|
| 97 |
+
|
| 98 |
+
HTR-ConvText configurations and specifications:
|
| 99 |
+
|
| 100 |
+
| Feature | Specification |
|
| 101 |
+
| ------------------- | --------------------------------------------------- |
|
| 102 |
+
| Architecture Type | Hybrid CNN + Vision Transformer (Encoder-Only) |
|
| 103 |
+
| Parameters | ~65.9M |
|
| 104 |
+
| Backbone | ResNet-18 + MobileViT w/ Positional Encoding (MVP) |
|
| 105 |
+
| Encoder Layers | 8 ConvText Blocks (Hierarchical) |
|
| 106 |
+
| Attention Heads | 8 |
|
| 107 |
+
| Embedding Dimension | 512 |
|
| 108 |
+
| Image Input Size | 512×64 |
|
| 109 |
+
| Inference Strategy | Standard CTC Decoding (TCM is removed at inference) |
|
| 110 |
+
|
| 111 |
+
For more details, including ablation studies and theoretical proofs, please refer to our [Technical Report](https://arxiv.org/pdf/2512.05021).
|
| 112 |
+
|
| 113 |
+
## Performance
|
| 114 |
+
|
| 115 |
+
We evaluated HTR-ConvText across four diverse datasets. The model achieves new SOTA results with the lowest Character Error Rate (CER) and Word Error Rate (WER) without requiring massive synthetic pre-training.
|
| 116 |
+
|
| 117 |
+
| Dataset | Language | Ours CER (%) | HTR-VT | OrigamiNet | TrOCR | CRNN |
|
| 118 |
+
|-----------|-------------|--------------|--------|------------|-------|-------|
|
| 119 |
+
| IAM | English | 4.0 | 4.7 | 4.8 | 7.3 | 7.8 |
|
| 120 |
+
| LAM | Italian | 2.7 | 2.8 | 3.0 | 3.6 | 3.8 |
|
| 121 |
+
| READ2016 | German | 3.6 | 3.9 | - | - | 4.7 |
|
| 122 |
+
| VNOnDB | Vietnamese | 3.45 | 4.26 | 7.6 | - | 10.53 |
|
| 123 |
+
|
| 124 |
+
## Quickstart
|
| 125 |
+
|
| 126 |
+
### Installation
|
| 127 |
+
|
| 128 |
+
1. **Clone the repository**
|
| 129 |
+
```cmd
|
| 130 |
+
git clone https://github.com/0xk0ry/HTR-ConvText.git
|
| 131 |
+
cd HTR-ConvText
|
| 132 |
+
```
|
| 133 |
+
2. **Create and activate a Python 3.9+ Conda environment**
|
| 134 |
+
```cmd
|
| 135 |
+
conda create -n htr-convtext python=3.9 -y
|
| 136 |
+
conda activate htr-convtext
|
| 137 |
+
```
|
| 138 |
+
3. **Install PyTorch** using the wheel that matches your CUDA driver (swap the index for CPU-only builds):
|
| 139 |
+
```cmd
|
| 140 |
+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
|
| 141 |
+
```
|
| 142 |
+
4. **Install the remaining project requirements** (everything except PyTorch, which you already picked in step 3).
|
| 143 |
+
```cmd
|
| 144 |
+
pip install -r requirements.txt
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
The code was tested on Python 3.9 and PyTorch 2.9.1.
|
| 148 |
+
|
| 149 |
+
### Data Preparation
|
| 150 |
+
|
| 151 |
+
We provide split files (train.ln, val.ln, test.ln) for IAM, READ2016, LAM, and VNOnDB under data/. Organize your data as follows:
|
| 152 |
+
|
| 153 |
+
```
|
| 154 |
+
./data/iam/
|
| 155 |
+
├── train.ln
|
| 156 |
+
├── val.ln
|
| 157 |
+
├── test.ln
|
| 158 |
+
└── lines
|
| 159 |
+
├── a01-000u-00.png
|
| 160 |
+
├── a01-000u-00.txt
|
| 161 |
+
└── ...
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
### Training
|
| 165 |
+
|
| 166 |
+
We provide comprehensive scripts in the ./run/ directory. To train on the IAM dataset with the Textual Context Module (TCM) enabled:
|
| 167 |
+
|
| 168 |
+
```
|
| 169 |
+
# Using the provided script
|
| 170 |
+
bash run/iam.sh
|
| 171 |
+
|
| 172 |
+
# OR running directly via Python
|
| 173 |
+
python train.py \
|
| 174 |
+
--use-wandb \
|
| 175 |
+
--dataset iam \
|
| 176 |
+
--tcm-enable \
|
| 177 |
+
--exp-name "htr-convtext-iam" \
|
| 178 |
+
--img-size 512 64 \
|
| 179 |
+
--train-bs 32 \
|
| 180 |
+
--val-bs 8 \
|
| 181 |
+
--data-path /path/to/iam/lines/ \
|
| 182 |
+
--train-data-list data/iam/train.ln \
|
| 183 |
+
--val-data-list data/iam/val.ln \
|
| 184 |
+
--test-data-list data/iam/test.ln \
|
| 185 |
+
--nb-cls 80
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
### Inference / Evaluation
|
| 189 |
+
|
| 190 |
+
To evaluate a pre-trained checkpoint on the test set:
|
| 191 |
+
|
| 192 |
+
```
|
| 193 |
+
python test.py \
|
| 194 |
+
--resume ./checkpoints/best_CER.pth \
|
| 195 |
+
--dataset iam \
|
| 196 |
+
--img-size 512 64 \
|
| 197 |
+
--data-path /path/to/iam/lines/ \
|
| 198 |
+
--test-data-list data/iam/test.ln \
|
| 199 |
+
--nb-cls 80
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
## Citation
|
| 203 |
+
|
| 204 |
+
If you find our work helpful, please cite our paper:
|
| 205 |
+
|
| 206 |
+
```
|
| 207 |
+
@misc{truc2025htrconvtex,
|
| 208 |
+
title={HTR-ConvText: Leveraging Convolution and Textual Information for Handwritten Text Recognition},
|
| 209 |
+
author={Pham Thach Thanh Truc and Dang Hoai Nam and Huynh Tong Dang Khoa and Vo Nguyen Le Duy},
|
| 210 |
+
year={2025},
|
| 211 |
+
eprint={2512.05021},
|
| 212 |
+
archivePrefix={arXiv},
|
| 213 |
+
primaryClass={cs.CV},
|
| 214 |
+
url={https://arxiv.org/abs/2512.05021},
|
| 215 |
+
}
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
## Acknowledgement
|
| 219 |
+
|
| 220 |
+
This project is inspired by and adapted from [HTR-VT](https://github.com/Intellindust-AI-Lab/HTR-VT). We gratefully acknowledge the authors for their open-source contributions.
|
data/dataset.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.transforms import ColorJitter
|
| 2 |
+
from data import transform as transform
|
| 3 |
+
from utils import utils
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import itertools
|
| 7 |
+
import os
|
| 8 |
+
import skimage
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def SameTrCollate(batch, args):
    """Collate a batch of (image, label) pairs with random augmentation.

    Each augmentation group — projective distortion, dilation/erosion,
    color jitter — is applied independently with probability 0.5, using
    the same random parameters for every image in the batch.
    (The original comment claimed 90%; the code uses 0.5 throughout.)

    Args:
        batch: iterable of (image, label) pairs; image is a float array
            in [0, 1] with shape (1, H, W).
        args: namespace carrying the augmentation hyper-parameters
            (proj, dila_ero_max_kernel, dila_ero_iter, jitter_*).

    Returns:
        Tuple of (float tensor of shape (B, 1, H, W) in [0, 1], labels).
    """
    images, labels = zip(*batch)
    # Back to uint8 PIL images so the PIL/cv2-based transforms can run.
    images = [Image.fromarray(np.uint8(images[i][0] * 255))
              for i in range(len(images))]

    # Random projective distortion (probability 0.5).
    if np.random.rand() < 0.5:
        images = [transform.RandomTransform(
            args.proj)(image) for image in images]

    # Random stroke-width change — erosion or dilation (probability 0.5).
    if np.random.rand() < 0.5:
        kernel_h = utils.randint(1, args.dila_ero_max_kernel + 1)
        kernel_w = utils.randint(1, args.dila_ero_max_kernel + 1)
        if utils.randint(0, 2) == 0:
            images = [transform.Erosion((kernel_w, kernel_h), args.dila_ero_iter)(
                image) for image in images]
        else:
            images = [transform.Dilation((kernel_w, kernel_h), args.dila_ero_iter)(
                image) for image in images]

    # Random photometric jitter (probability 0.5).
    if np.random.rand() < 0.5:
        images = [ColorJitter(args.jitter_brightness, args.jitter_contrast, args.jitter_saturation,
                              args.jitter_hue)(image) for image in images]

    # Stack PIL images into one float tensor rescaled to [0, 1].
    image_tensors = [torch.from_numpy(
        np.array(image, copy=True)) for image in images]
    image_tensors = torch.stack(image_tensors, 0)
    image_tensors = image_tensors.unsqueeze(1).float()
    image_tensors = image_tensors / 255.
    return image_tensors, labels
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class myLoadDS(Dataset):
    """Line-level handwritten-text dataset.

    Loads image paths from a split file (``flist``) rooted at ``dpath``
    and reads each image's transcription from its sibling ``.txt`` file.
    """

    def __init__(self, flist, dpath, img_size=[512, 32], ralph=None, fmin=True, mln=None, dataset=None):
        # Resolve image paths and their ground-truth transcriptions.
        self.fns = get_files(flist, dpath)
        self.tlbls = get_labels(self.fns)
        self.img_size = img_size
        # Build the index -> character map ("reverse alphabet"):
        # an explicit `ralph` wins, then a per-dataset hard-coded charset,
        # otherwise the alphabet observed in the labels themselves.
        if ralph is not None:
            self.ralph = ralph
        elif dataset is not None:
            if dataset == 'iam':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
                    )
                }
            elif dataset == 'lam':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"#%&\'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXZabcdefghijlmnopqrstuvwxyz|°·ÈÉàèéìòù–'
                    )
                }
            elif dataset == 'read2016':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' ()+,-./0123456789:<>ABCDEFGHIJKLMNOPQRSTUVWYZ[]abcdefghijklmnopqrstuvwxyz¾Ößäöüÿāēōūȳ̄̈—'
                    )
                }
            elif dataset == 'vnondb':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"%&()*,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvxyzÀÁÂÔÚÝàáâãèéêìíòóôõùúýĂăĐđĩũƠơƯưạẢảẤấẦầẩẫậắằẳẵặẹẻẽếỀềỂểễỆệỉịọỏỐốỒồổỗộớờỞởỡợụỦủứừửữựỳỷỹ'
                    )
                }
            else:
                # Unknown dataset name: fall back to the observed alphabet.
                alph = get_alphabet(self.tlbls)
                self.ralph = dict(zip(alph.values(), alph.keys()))
                self.alph = alph
        else:
            # No dataset given: derive the alphabet from the labels.
            alph = get_alphabet(self.tlbls)
            self.ralph = dict(zip(alph.values(), alph.keys()))
            self.alph = alph
        # Optional length filter: keep labels of length <= mln when fmin
        # is True, otherwise keep labels of length >= mln.
        if mln != None:
            filt = [len(x) <= mln if fmin else len(x)
                    >= mln for x in self.tlbls]
            self.tlbls = np.asarray(self.tlbls)[filt].tolist()
            self.fns = np.asarray(self.fns)[filt].tolist()

    def __len__(self):
        # Number of (image, label) pairs.
        return len(self.fns)

    def __getitem__(self, index):
        # Load the image at the configured size and move channels first:
        # (H, W, C) -> (C, H, W).
        timgs = get_images(self.fns[index], self.img_size[0], self.img_size[1])
        timgs = timgs.transpose((2, 0, 1))

        return (timgs, self.tlbls[index])
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _read_text(path):
|
| 105 |
+
"""Read a text file with robust encoding handling.
|
| 106 |
+
Try UTF-8 first, then fall back to common Windows encodings.
|
| 107 |
+
"""
|
| 108 |
+
encodings = ['utf-8', 'utf-8-sig', 'cp1258', 'cp1252', 'latin-1']
|
| 109 |
+
last_err = None
|
| 110 |
+
for enc in encodings:
|
| 111 |
+
try:
|
| 112 |
+
with open(path, 'r', encoding=enc) as f:
|
| 113 |
+
return f.read()
|
| 114 |
+
except UnicodeDecodeError as e:
|
| 115 |
+
last_err = e
|
| 116 |
+
continue
|
| 117 |
+
except FileNotFoundError:
|
| 118 |
+
raise
|
| 119 |
+
# As a last resort, ignore errors to avoid crashing the training loop
|
| 120 |
+
with open(path, 'r', encoding='utf-8', errors='ignore') as f:
|
| 121 |
+
return f.read()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _read_lines(path):
    """Return the file's contents as a list of lines (terminators stripped)."""
    return _read_text(path).splitlines()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_files(nfile, dpath):
    """Read split file `nfile` and prefix every stripped entry with `dpath`.

    Note: `dpath` is concatenated as-is, so it is expected to end with a
    path separator (as in the run scripts).
    """
    return [dpath + entry.strip() for entry in _read_lines(nfile)]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def npThum(img, max_w, max_h):
    """Resize a numpy image to height `max_h`, preserving aspect ratio.

    The width is scaled proportionally but clamped to `max_w`, so very
    wide lines are squeezed rather than overflowing the fixed canvas.
    """
    height, width = np.shape(img)[:2]
    new_width = min(int(width * max_h / height), max_w)
    resized = Image.fromarray(img).resize((new_width, max_h))
    return np.array(resized)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_images(fname, max_w=500, max_h=500, nch=1):  # args.max_w args.max_h args.nch
    """Load a grayscale line image as float32 in [0, 1].

    The image is resized to height ``max_h`` (aspect-ratio preserving,
    width clamped to ``max_w``) and right-padded with white (1.0) up to
    a fixed width of ``max_w``.

    Args:
        fname: path to the image file.
        max_w: fixed output width in pixels.
        max_h: output height in pixels.
        nch: number of output channels (1 or 3).

    Returns:
        np.ndarray of shape (max_h, max_w, nch). If the file cannot be
        read, a blank all-white image of that shape is returned so one
        corrupt file does not abort the training loop.
    """
    try:
        image_data = np.array(Image.open(fname).convert('L'))
        image_data = npThum(image_data, max_w, max_h)
        image_data = skimage.img_as_float32(image_data)

        # Ensure an explicit channel axis: (H, W) -> (H, W, 1).
        if image_data.ndim < 3:
            image_data = np.expand_dims(image_data, axis=-1)

        # Replicate grayscale into 3 channels when RGB is requested.
        if nch == 3 and image_data.shape[2] != 3:
            image_data = np.tile(image_data, 3)

        # Right-pad with white (1.0) so every sample has width max_w.
        image_data = np.pad(image_data, ((0, 0), (0, max_w - np.shape(image_data)[1]), (0, 0)), mode='constant',
                            constant_values=(1.0))

    except IOError as e:
        print('Could not read:', fname, ':', e)
        # BUG FIX: the original fell through to `return image_data` with the
        # variable unbound, raising NameError. Return a blank white page of
        # the expected shape instead.
        image_data = np.ones((max_h, max_w, nch), dtype=np.float32)

    return image_data
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_labels(fnames):
    """Read the transcription for every image path in `fnames`.

    For each image file the sibling ``<stem>.txt`` is read and all
    whitespace (including line breaks) is collapsed to single spaces.

    Returns:
        list[str]: one transcription per input path, in order.
    """
    labels = []
    # The original used `for id, image_file in enumerate(...)`, shadowing
    # the builtin `id` with an unused index — dropped here.
    for image_file in fnames:
        txt_path = os.path.splitext(image_file)[0] + '.txt'
        lbl = _read_text(txt_path)
        lbl = ' '.join(lbl.split())  # remove linebreaks if present
        labels.append(lbl)

    return labels
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_alphabet(labels):
    """Map each distinct character appearing in `labels` to an index.

    Characters are sorted so the mapping is deterministic across runs.
    (The original wrapped the sorted characters in
    ``itertools.product(..., repeat=1)``, which is a no-op; the wrapper
    is removed here with identical results.)

    Returns:
        dict[str, int]: character -> contiguous index starting at 0.
    """
    unique_chars = sorted(set(''.join(labels)))
    return dict(zip(unique_chars, range(len(unique_chars))))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def cycle_dpp(iterable):
    """Endlessly iterate a DataLoader under DDP.

    Before each pass the DistributedSampler's epoch is advanced via
    ``set_epoch`` so that shuffling stays synchronized (and differs per
    epoch) across distributed workers.
    """
    epoch = 0
    while True:
        iterable.sampler.set_epoch(epoch)
        yield from iterable
        epoch += 1
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def cycle_data(iterable):
    """Yield items from `iterable` forever, restarting it when exhausted.

    Unlike ``itertools.cycle`` this re-iterates the source on every pass,
    so a DataLoader re-shuffles each epoch instead of replaying cached
    items.
    """
    while True:
        yield from iterable
|
data/transform.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from skimage import transform as stf
|
| 5 |
+
from numpy import random, floor
|
| 6 |
+
from PIL import Image, ImageOps
|
| 7 |
+
from cv2 import erode, dilate, normalize
|
| 8 |
+
from torchvision.transforms import RandomCrop
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
class Dilation:
    """
    OCR: stroke width increasing
    """

    def __init__(self, kernel, iterations):
        # cv2.dilate wants an explicit uint8 structuring element.
        self.kernel = np.ones(kernel, np.uint8)
        self.iterations = iterations

    def __call__(self, x):
        thickened = dilate(np.array(x), self.kernel, iterations=self.iterations)
        return Image.fromarray(thickened)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Erosion:
    """
    OCR: stroke width decreasing
    """

    def __init__(self, kernel, iterations):
        # cv2.erode wants an explicit uint8 structuring element.
        self.kernel = np.ones(kernel, np.uint8)
        self.iterations = iterations

    def __call__(self, x):
        thinned = erode(np.array(x), self.kernel, iterations=self.iterations)
        return Image.fromarray(thinned)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ElasticDistortion:
    """
    Elastic Distortion adapted from https://github.com/IntuitionMachines/OrigamiNet
    Used in "OrigamiNet: Weakly-Supervised, Segmentation-Free, One-Step, Full Page TextRecognition by learning to unfold",
    Yousef, Mohamed and Bishop, Tom E., The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020
    """

    def __init__(self, grid, magnitude, min_sep):
        # grid: number of mesh tiles (width, height).
        # magnitude: max random shift of a tile corner (x, y), in pixels.
        # min_sep: minimum separation kept from the left/top neighbour (h, v).
        self.grid_width, self.grid_height = grid
        self.xmagnitude, self.ymagnitude = magnitude
        self.min_h_sep, self.min_v_sep = min_sep

    def __call__(self, x):
        w, h = x.size

        horizontal_tiles = self.grid_width
        vertical_tiles = self.grid_height

        # Base tile size; the last row/column absorb the rounding remainder.
        width_of_square = int(floor(w / float(horizontal_tiles)))
        height_of_square = int(floor(h / float(vertical_tiles)))

        width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
        height_of_last_square = h - (height_of_square * (vertical_tiles - 1))

        dimensions = []
        # Per-tile random (dx, dy) corner shift, filled in below.
        shift = [[(0, 0) for x in range(horizontal_tiles)] for y in range(vertical_tiles)]

        for vertical_tile in range(vertical_tiles):
            for horizontal_tile in range(horizontal_tiles):
                # Axis-aligned bounding box of this tile; tiles in the last
                # row/column use the (possibly larger) remainder size.
                if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_last_square + (horizontal_tile * width_of_square),
                                       height_of_last_square + (height_of_square * vertical_tile)])
                elif vertical_tile == (vertical_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_square + (horizontal_tile * width_of_square),
                                       height_of_last_square + (height_of_square * vertical_tile)])
                elif horizontal_tile == (horizontal_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_last_square + (horizontal_tile * width_of_square),
                                       height_of_square + (height_of_square * vertical_tile)])
                else:
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_square + (horizontal_tile * width_of_square),
                                       height_of_square + (height_of_square * vertical_tile)])

                # Cap the negative shift so this corner cannot come closer
                # than min_h_sep / min_v_sep to its left / top neighbour.
                sm_h = min(self.xmagnitude,
                           width_of_square - (self.min_h_sep + shift[vertical_tile][horizontal_tile - 1][
                               0])) if horizontal_tile > 0 else self.xmagnitude
                sm_v = min(self.ymagnitude,
                           height_of_square - (self.min_v_sep + shift[vertical_tile - 1][horizontal_tile][
                               1])) if vertical_tile > 0 else self.ymagnitude

                dx = random.randint(-sm_h, self.xmagnitude)
                dy = random.randint(-sm_v, self.ymagnitude)
                shift[vertical_tile][horizontal_tile] = (dx, dy)

        # Flatten the 2-D shift grid to match the flat tile indexing below.
        shift = list(itertools.chain.from_iterable(shift))

        # Flat indices of the tiles on the right edge / bottom edge — their
        # shared corner has no 4-tile group, so no shift is applied there.
        last_column = []
        for i in range(vertical_tiles):
            last_column.append((horizontal_tiles - 1) + horizontal_tiles * i)

        last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles, horizontal_tiles * vertical_tiles)

        # One polygon (4 corners as 8 flat values) per tile, initially
        # identical to the tile's bounding box.
        polygons = []
        for x1, y1, x2, y2 in dimensions:
            polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])

        # Groups of 4 tiles sharing an interior corner:
        # (tile, right neighbour, tile below, below-right neighbour).
        polygon_indices = []
        for i in range((vertical_tiles * horizontal_tiles) - 1):
            if i not in last_row and i not in last_column:
                polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles])

        # Move the shared corner of each 4-tile group by the group's (dx, dy)
        # in all four polygons, keeping the warped mesh continuous.
        for id, (a, b, c, d) in enumerate(polygon_indices):
            dx = shift[id][0]
            dy = shift[id][1]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
            polygons[a] = [x1, y1,
                           x2, y2,
                           x3 + dx, y3 + dy,
                           x4, y4]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
            polygons[b] = [x1, y1,
                           x2 + dx, y2 + dy,
                           x3, y3,
                           x4, y4]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
            polygons[c] = [x1, y1,
                           x2, y2,
                           x3, y3,
                           x4 + dx, y4 + dy]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
            polygons[d] = [x1 + dx, y1 + dy,
                           x2, y2,
                           x3, y3,
                           x4, y4]

        # PIL MESH expects (target box, source quadrilateral) pairs.
        generated_mesh = []
        for i in range(len(dimensions)):
            generated_mesh.append([dimensions[i], polygons[i]])

        self.generated_mesh = generated_mesh

        return x.transform(x.size, Image.MESH, self.generated_mesh, resample=Image.BICUBIC)
|
| 150 |
+
|
| 151 |
+
class RandomTransform:
    """
    Random Transform adapted from https://github.com/IntuitionMachines/OrigamiNet
    Used in "OrigamiNet: Weakly-Supervised, Segmentation-Free, One-Step, Full Page TextRecognition by learning to unfold",
    Yousef, Mohamed and Bishop, Tom E., The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020

    Applies a random projective (perspective) warp to a PIL image, then resizes
    the warped result back to the original (w, h) so downstream shapes are stable.
    """
    def __init__(self, val):

        # maximum absolute corner displacement, in pixels
        self.val = val

    def __call__(self, x):
        w, h = x.size

        # jitter either horizontally or vertically for this call
        # (randint(0, 2) yields 0/1/2, so horizontal jitter has probability 1/3)
        dw, dh = (self.val, 0) if random.randint(0, 2) == 0 else (0, self.val)

        def rd(d):
            # symmetric jitter in [-d, d]
            return random.uniform(-d, d)

        def fd(d):
            # NOTE(review): lower bound is the outer `dw`, not `d`, exactly as in
            # the original OrigamiNet code — confirm this asymmetry is intended
            return random.uniform(-dw, d)

        # generate a random projective transform
        # adapted from https://navoshta.com/traffic-signs-classification/
        tl_top = rd(dh)
        tl_left = fd(dw)
        bl_bottom = rd(dh)
        bl_left = fd(dw)
        tr_top = rd(dh)
        tr_right = fd(min(w * 3 / 4 - tl_left, dw))
        br_bottom = rd(dh)
        br_right = fd(min(w * 3 / 4 - bl_left, dw))

        tform = stf.ProjectiveTransform()
        tform.estimate(np.array((  # estimate the transform matrix from point correspondences
            (tl_left, tl_top),
            (bl_left, h - bl_bottom),
            (w - br_right, h - br_bottom),
            (w - tr_right, tr_top)
        )), np.array((
            [0, 0],
            [0, h - 1],
            [w - 1, h - 1],
            [w - 1, 0]
        )))

        # determine shape of output image, to preserve size
        # trick take from the implementation of skimage.transform.rotate
        corners = np.array([
            [0, 0],
            [0, h - 1],
            [w - 1, h - 1],
            [w - 1, 0]
        ])

        corners = tform.inverse(corners)
        minc = corners[:, 0].min()
        minr = corners[:, 1].min()
        maxc = corners[:, 0].max()
        maxr = corners[:, 1].max()
        out_rows = maxr - minr + 1
        out_cols = maxc - minc + 1
        output_shape = np.around((out_rows, out_cols))

        # fit output image in new shape
        translation = (minc, minr)
        tform4 = stf.SimilarityTransform(translation=translation)
        tform = tform4 + tform
        # normalize the homography so params[2, 2] == 1
        tform.params /= tform.params[2, 2]

        # warp with white fill (cval=255), then resize back to the original size
        x = stf.warp(np.array(x), tform, output_shape=output_shape, cval=255, preserve_range=True)
        x = stf.resize(x, (h, w), preserve_range=True).astype(np.uint8)

        return Image.fromarray(x)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class SignFlipping:
    """
    Color inversion: return the photographic negative of a PIL image.
    """

    def __init__(self):
        # stateless transform; nothing to configure
        pass

    def __call__(self, x):
        inverted = ImageOps.invert(x)
        return inverted
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class DPIAdjusting:
    """
    Resolution modification: rescale both sides of a PIL image by ``factor``.

    Fix: the original accepted ``preserve_ratio`` but silently discarded it;
    it is now stored (with a backward-compatible default) so callers can
    inspect the configured behavior.
    """

    def __init__(self, factor, preserve_ratio=True):
        # factor < 1 shrinks the image, factor > 1 enlarges it
        self.factor = factor
        self.preserve_ratio = preserve_ratio

    def __call__(self, x):
        w, h = x.size
        # ceil so a non-empty image never collapses to a zero-sized side
        new_w = int(np.ceil(w * self.factor))
        new_h = int(np.ceil(h * self.factor))
        return x.resize((new_w, new_h), Image.BILINEAR)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class GaussianNoise:
    """
    Add Gaussian Noise

    Standardizes the image, adds unit-normal noise scaled by ``std``, undoes
    the standardization, and rescales back into the original intensity range.
    """

    def __init__(self, std):
        # noise amplitude, relative to the per-image standardized intensities
        self.std = std

    def __call__(self, x):
        x_np = np.array(x)
        mean, std = np.mean(x_np), np.std(x_np)
        # guard against division by zero on flat (constant) images
        std = math.copysign(max(abs(std), 0.000001), std)
        min_, max_ = np.min(x_np,), np.max(x_np)
        normal_noise = np.random.randn(*x_np.shape)
        # grayscale images stored as three identical channels get identical
        # noise in every channel so they remain grayscale
        if len(x_np.shape) == 3 and x_np.shape[2] == 3 and np.all(x_np[:, :, 0] == x_np[:, :, 1]) and np.all(x_np[:, :, 0] == x_np[:, :, 2]):
            normal_noise[:, :, 1] = normal_noise[:, :, 2] = normal_noise[:, :, 0]
        # standardize, inject noise, then undo the standardization
        x_np = ((x_np-mean)/std + normal_noise*self.std) * std + mean
        # NOTE(review): `normalize` appears to be cv2.normalize; the (max_, min_)
        # order relies on OpenCV sorting alpha/beta internally — confirm
        x_np = normalize(x_np, x_np, max_, min_, cv2.NORM_MINMAX)

        return Image.fromarray(x_np.astype(np.uint8))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class Sharpen:
    """
    Sharpen the image with a blended 3x3 unsharp-style kernel.

    ``alpha`` blends between the identity kernel (alpha=0) and the sharpening
    kernel (alpha=1); ``strength`` controls the center weight of the effect.
    (The original docstring wrongly said "Add Gaussian Noise".)
    """

    def __init__(self, alpha, strength):
        self.alpha = alpha
        self.strength = strength

    def __call__(self, x):
        x_np = np.array(x)
        id_matrix = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, 0, 0]]
                             )
        effect_matrix = np.array([[1, 1, 1],
                                  [1, -(8+self.strength), 1],
                                  [1, 1, 1]]
                                 )
        kernel = (1 - self.alpha) * id_matrix - self.alpha * effect_matrix
        # cv2.filter2D applies the same 2-D kernel to every channel, so the
        # original per-channel kernel stack (expand_dims + concatenate, of which
        # only [:, :, 0] was used) was dead work and has been removed.
        sharpened = cv2.filter2D(x_np, -1, kernel=kernel)
        return Image.fromarray(sharpened.astype(np.uint8))
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class ZoomRatio:
    """
    Crop by ratio.
    Preserve dimensions if keep_dim = True (= zoom)
    """

    def __init__(self, ratio_h, ratio_w, keep_dim=True):
        self.ratio_w = ratio_w
        self.ratio_h = ratio_h
        self.keep_dim = keep_dim

    def __call__(self, x):
        width, height = x.size
        # random crop to the requested fraction of each side
        crop_size = (int(height * self.ratio_h), int(width * self.ratio_w))
        cropped = RandomCrop(crop_size)(x)
        if not self.keep_dim:
            return cropped
        # scale the crop back up: net effect is a zoom
        return cropped.resize((width, height), Image.BILINEAR)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class Tightening:
    """
    Reduce interline spacing by randomly removing fully blank rows.

    A row is "blank" when every pixel equals the configured background
    ``color``; each blank row is dropped with probability ``remove_proba``.
    """

    def __init__(self, color=255, remove_proba=0.75):
        self.color = color
        self.remove_proba = remove_proba

    def __call__(self, x):
        x_np = np.array(x)
        # Bug fix: the original compared against a hard-coded 255 and ignored
        # the configurable ``self.color`` attribute.
        interline_indices = [np.all(line == self.color) for line in x_np]
        drop = np.random.choice([True, False], size=len(x_np), replace=True,
                                p=[self.remove_proba, 1 - self.remove_proba])
        indices_to_remove = np.logical_and(drop, interline_indices)
        new_x = x_np[np.logical_not(indices_to_remove)]
        return Image.fromarray(new_x.astype(np.uint8))
|
image/architecture.png
ADDED
|
Git LFS Details
|
model/htr_convtext.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from timm.models.vision_transformer import Mlp, DropPath
|
| 5 |
+
from timm.layers import LayerScale
|
| 6 |
+
import numpy as np
|
| 7 |
+
from model import resnet18
|
| 8 |
+
from functools import partial
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RelativePositionBias1D(nn.Module):
    """Learned additive relative-position bias for 1-D self-attention.

    One learnable scalar per (clamped relative offset, head), zero-initialized
    so the bias starts as a no-op.
    """

    def __init__(self, num_heads: int, max_rel_positions: int = 1024):
        super().__init__()
        self.num_heads = num_heads
        self.max_rel_positions = max(1, int(max_rel_positions))
        self.bias = nn.Embedding(2 * self.max_rel_positions - 1, num_heads)
        nn.init.zeros_(self.bias.weight)

    def forward(self, N: int) -> torch.Tensor:
        """Return a (1, num_heads, N, N) bias tensor for a length-N sequence."""
        device = self.bias.weight.device
        positions = torch.arange(N, device=device)
        offsets = positions.unsqueeze(1) - positions.unsqueeze(0)
        limit = self.max_rel_positions - 1
        # clamp out-of-range offsets, then shift into [0, 2*limit] for lookup
        table_idx = offsets.clamp(-limit, limit) + limit
        bias = self.bias(table_idx)  # (N, N, num_heads)
        return bias.permute(2, 0, 1).unsqueeze(0)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Attention(nn.Module):
    """Multi-head self-attention with an additive 1-D relative-position bias."""

    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # bias table sized by the expected sequence length (fallback: 1024)
        if num_patches is not None:
            max_rel_positions = max(1, int(num_patches))
        else:
            max_rel_positions = 1024
        self.rel_pos_bias = RelativePositionBias1D(
            num_heads=num_heads, max_rel_positions=max_rel_positions)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # project once, then split into per-head q, k, v
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        scores = scores + self.rel_pos_bias(N)
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = torch.matmul(weights, v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class FeedForward(nn.Module):
    """Position-wise two-layer MLP with dropout applied to the output."""

    def __init__(self, dim, hidden_dim, dropout=0.1, activation=nn.SiLU):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden_dim)
        self.act = activation()
        self.lin2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.lin1(x))
        projected = self.lin2(hidden)
        return self.dropout(projected)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ConvModule(nn.Module):
    """Conformer-style convolution module for (B, L, C) sequences.

    pointwise expand -> depthwise conv -> GroupNorm -> pointwise project,
    with SiLU activations, dropout, and optional pre-LayerNorm / DropPath.
    """

    def __init__(self, dim, kernel_size=3, dropout=0.1, drop_path=0.0,
                 expansion=1.0, pre_norm=False, activation=nn.SiLU):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim) if pre_norm else None
        hidden = int(round(dim * expansion))

        self.pw1 = nn.Conv1d(dim, hidden, kernel_size=1, bias=True)
        self.act1 = activation()
        # depthwise: groups == channels, "same" padding
        self.dw = nn.Conv1d(hidden, hidden, kernel_size=kernel_size,
                            padding=kernel_size // 2, groups=hidden, bias=True)
        self.gn = nn.GroupNorm(1, hidden, eps=1e-5)
        self.act2 = activation()
        self.pw2 = nn.Conv1d(hidden, dim, kernel_size=1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        # (B, L, C) -> (B, C, L) so Conv1d operates along the sequence axis
        z = x.transpose(1, 2)
        z = self.act1(self.pw1(z))
        z = self.act2(self.gn(self.dw(z)))
        z = self.dropout(self.pw2(z)).transpose(1, 2)
        return self.drop_path(z)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Downsample1D(nn.Module):
    """Strided depthwise + pointwise conv that shortens a (B, L, C) sequence."""

    def __init__(self, dim, kernel_size=3, stride=2, lowpass_init=True):
        super().__init__()
        self.dw = nn.Conv1d(dim, dim, kernel_size=kernel_size,
                            stride=stride, padding=kernel_size//2,
                            groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=True)
        if lowpass_init:
            # start as a moving-average (anti-aliasing) filter; depthwise
            # weights have shape (dim, 1, k), so a uniform fill is equivalent
            # to setting w[:, 0, :] = 1/k
            with torch.no_grad():
                self.dw.weight.fill_(1.0 / kernel_size)

    def forward(self, x):
        z = x.transpose(1, 2)
        z = self.pw(self.dw(z))
        return z.transpose(1, 2)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Upsample1D(nn.Module):
    """Interpolate a (B, L, C) sequence to ``target_len``, then 1x1-project."""

    def __init__(self, dim, mode: str = 'nearest'):
        super().__init__()
        assert mode in (
            'nearest', 'linear'), "Upsample1D mode must be 'nearest' or 'linear'"
        self.mode = mode
        self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=True)

    def forward(self, x, target_len: int):
        z = x.transpose(1, 2)
        # align_corners only applies to 'linear' interpolation
        interp_kwargs = {'mode': self.mode}
        if self.mode == 'linear':
            interp_kwargs['align_corners'] = False
        z = F.interpolate(z, size=target_len, **interp_kwargs)
        return self.proj(z).transpose(1, 2)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ConvTextBlock(nn.Module):
    """
    Macaron-style encoder block: attention -> half-weighted FFN -> conv module
    -> half-weighted FFN, each sub-layer wrapped with a residual connection,
    LayerScale, DropPath (stochastic depth) and a post-LayerNorm.
    """
    def __init__(self,
                 dim,
                 num_heads,
                 num_patches,
                 mlp_ratio=4.0,
                 ff_dropout=0.1,
                 attn_dropout=0.0,
                 conv_dropout=0.0,
                 conv_kernel_size=3,
                 conv_expansion=1.0,
                 norm_layer=nn.LayerNorm,
                 drop_path=0.0,
                 layerscale_init=1e-5):
        super().__init__()

        ff_hidden = int(dim * mlp_ratio)

        self.attn = Attention(dim, num_patches, num_heads=num_heads,
                              qkv_bias=True, attn_drop=attn_dropout, proj_drop=ff_dropout)

        self.ffn1 = FeedForward(
            dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)
        # DropPath is handled at block level, so the conv module's own is off
        self.conv = ConvModule(dim, kernel_size=conv_kernel_size,
                               dropout=conv_dropout, drop_path=0.0,
                               expansion=conv_expansion, pre_norm=False, activation=nn.SiLU)
        self.ffn2 = FeedForward(
            dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)

        # post-norm applied after each residual sum
        self.postln_attn = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn1 = norm_layer(dim, elementwise_affine=True)
        self.postln_conv = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn2 = norm_layer(dim, elementwise_affine=True)

        # per-branch stochastic depth
        self.dp_attn = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn1 = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_conv = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn2 = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()

        # learnable per-channel residual scaling (CaiT-style LayerScale)
        self.ls_attn = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn1 = LayerScale(dim, init_values=layerscale_init)
        self.ls_conv = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn2 = LayerScale(dim, init_values=layerscale_init)

    def forward(self, x):
        # the two FFN branches carry a 0.5 factor (macaron design)
        x = self.postln_attn(x + self.ls_attn(self.dp_attn(self.attn(x))))
        x = self.postln_ffn1(
            x + self.ls_ffn1(0.5 * self.dp_ffn1(self.ffn1(x))))
        x = self.postln_conv(x + self.ls_conv(self.dp_conv(self.conv(x))))
        x = self.postln_ffn2(
            x + self.ls_ffn2(0.5 * self.dp_ffn2(self.ffn2(x))))
        return x
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a (grid_h * grid_w, embed_dim) 2-D sin-cos positional embedding."""
    ys = np.arange(grid_size[0], dtype=np.float32)
    xs = np.arange(grid_size[1], dtype=np.float32)
    # meshgrid(x, y) yields arrays indexed [row, col]; stack -> (2, H, W)
    grid = np.stack(np.meshgrid(xs, ys), axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1-D sin-cos embeddings of the two grid axes (half dim each)."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos embedding of shape (len(pos), embed_dim) for 1-D positions."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # frequencies 10000^(-i/half) for i in [0, half)
    freqs = 1.0 / (10000 ** (np.arange(half, dtype=np.float64) / half))

    angles = np.outer(pos.reshape(-1), freqs)  # (M, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class HTR_ConvText(nn.Module):
    """
    Handwritten-text-recognition encoder.

    A ResNet18 backbone embeds the image into a 1-D token sequence, which is
    processed by a stack of ConvTextBlock layers with one internal downsample
    (after ``down_after`` blocks) / upsample (after ``up_after`` blocks) stage
    joined by a skip connection.  A final LayerNorm and linear head produce
    per-token class logits.

    Optional masked training replaces a subset of tokens with a learned mask
    token; the selection strategy is "random", "block" or "span".

    Fixes vs. the original:
      * ``mask_mode in ("block")`` / ``in ("span")`` were substring checks on a
        plain string (``("block")`` is not a tuple); now proper equality tests.
      * the near-duplicate block/span mask builders share one private helper.
    """

    def __init__(
            self,
            nb_cls=80,
            img_size=[512, 64],
            patch_size=[4, 32],
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4.0,
            norm_layer=nn.LayerNorm,
            conv_kernel_size: int = 3,
            dropout: float = 0.1,
            drop_path: float = 0.1,
            down_after: int = 2,
            up_after: int = 4,
            ds_kernel: int = 3,
            max_seq_len: int = 1024,
            upsample_mode: str = 'nearest',
    ):
        super().__init__()

        self.patch_embed = resnet18.ResNet18(embed_dim)
        self.embed_dim = embed_dim
        self.max_rel_pos = int(max_seq_len)
        # learned embedding substituted at masked token positions
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # stochastic-depth rate grows linearly across the stack
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList([
            ConvTextBlock(embed_dim, num_heads, self.max_rel_pos,
                          mlp_ratio=mlp_ratio,
                          ff_dropout=dropout, attn_dropout=dropout,
                          conv_dropout=dropout, conv_kernel_size=conv_kernel_size,
                          conv_expansion=1.0,
                          norm_layer=norm_layer, drop_path=dpr[i],
                          layerscale_init=1e-5)
            for i in range(depth)
        ])

        self.norm = norm_layer(embed_dim, elementwise_affine=True)
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        self.down_after = down_after
        self.up_after = up_after
        self.down1 = Downsample1D(embed_dim, kernel_size=ds_kernel)
        self.up1 = Upsample1D(embed_dim, mode=upsample_mode)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the mask token and all sub-modules."""
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Xavier for linear layers, identity-style init for LayerNorm
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def mask_random_1d(self, x, ratio):
        """Keep-mask of shape (B, L); ~ratio*L random positions are False."""
        B, L, _ = x.shape
        mask = torch.ones(B, L, dtype=torch.bool).to(x.device)
        if ratio <= 0.0 or ratio > 1.0:
            return mask
        num = int(round(ratio * L))
        if num <= 0:
            return mask
        # argsort of uniform noise gives a random permutation per row
        noise = torch.rand(B, L).to(x.device)
        idx = noise.argsort(dim=1)[:, :num]
        mask.scatter_(1, idx, False)
        return mask

    def _mask_intervals_1d(self, x, ratio: float, max_len: int, fixed_length: bool):
        """Shared builder for block/span keep-masks.

        Places K = max(1, round(ratio*L)//max_len) intervals of length max_len
        (fixed) or uniform in [1, max_len] (variable) at random starts.
        Returns a (B, L, 1) bool keep-mask; False marks masked positions.
        """
        B, L, _ = x.shape
        device = x.device

        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)

        target_mask_tokens = int(round(ratio * L))
        K = max(target_mask_tokens // max_len, 1)
        starts = torch.randint(0, max(1, L - max_len + 1), (B, K), device=device)
        if fixed_length:
            lengths = torch.full((B, K), max_len, device=device)
        else:
            lengths = torch.randint(1, max_len + 1, (B, K), device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        covered = (positions >= starts_exp) & (positions < ends_exp)
        keep_mask = ~covered.any(dim=1)
        return keep_mask.unsqueeze(-1)

    def mask_block_1d(self, x, ratio: float, max_block_length: int):
        """Keep-mask with blocks of random length in [1, max_block_length]."""
        return self._mask_intervals_1d(x, ratio, max_block_length, fixed_length=False)

    def mask_span_1d(self, x, ratio: float, max_span_length: int):
        """Keep-mask with spans of fixed length max_span_length."""
        return self._mask_intervals_1d(x, ratio, max_span_length, fixed_length=True)

    def forward_features(self, x, use_masking=False,
                         mask_mode="span",
                         mask_ratio=0.5, block_span=4, max_span_length=8):
        """Embed the image and run the ConvText stack.

        Returns (features, masked_positions_1d) where masked_positions_1d is a
        float (B, L) map of masked tokens (1 = masked), or None without masking.
        """
        x = self.patch_embed(x)
        # NOTE(review): axes named (B, C, W, H) as in the original; only the
        # flattened length matters below — confirm against ResNet18's output
        B, C, W, H = x.shape
        assert C == self.embed_dim, f"Expected embed_dim {self.embed_dim}, got {C}"
        x = x.view(B, C, -1).permute(0, 2, 1)

        masked_positions_1d = None
        if use_masking:
            # Bug fix: `mask_mode in ("block")` was a substring test on the
            # string "block"; use equality comparisons instead.
            if mask_mode == "random":
                keep_mask_1d = self.mask_random_1d(x, mask_ratio).float()
                mask = keep_mask_1d.unsqueeze(-1)
            elif mask_mode == "block":
                keep_mask = self.mask_block_1d(x, mask_ratio, block_span).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            elif mask_mode == "span":
                keep_mask = self.mask_span_1d(
                    x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            else:
                warnings.warn(
                    f"Unknown mask_mode '{mask_mode}', defaulting to span.")
                keep_mask = self.mask_span_1d(
                    x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            masked_positions_1d = (1.0 - keep_mask_1d).clamp(min=0.0, max=1.0)
            # replace masked positions with the learned mask token
            x = mask * x + (1.0 - mask) * \
                self.mask_token.expand(x.size(0), x.size(1), x.size(2))
        skip_hi = None
        for i, blk in enumerate(self.blocks, 1):
            x = blk(x)
            if i == self.down_after:
                skip_hi = x
                # pad to even length so the stride-2 downsample divides cleanly
                if (x.size(1) % 2) == 1:
                    x = torch.cat([x, x[:, -1:, :]], dim=1)
                x = self.down1(x)
            if i == self.up_after:
                assert skip_hi is not None, "Upsample requires a stored skip."
                x = self.up1(x, target_len=skip_hi.size(1))
                x = x + skip_hi

        x = self.norm(x)
        return x, masked_positions_1d

    def forward(self, x, use_masking=False, return_features=False, return_mask=False,
                mask_mode="span", mask_ratio=None, block_span=None, max_span_length=None):
        """Return per-token logits; optionally also features and/or mask map.

        NOTE(review): when use_masking=True the masking kwargs must be given —
        the None defaults are only valid for the unmasked path.
        """
        feats, masked_positions_1d = self.forward_features(
            x, use_masking=use_masking, mask_mode=mask_mode, mask_ratio=mask_ratio, block_span=block_span, max_span_length=max_span_length)
        logits = self.head(feats)
        if return_features and return_mask:
            return logits, feats, (masked_positions_1d if masked_positions_1d is not None else None)
        if return_features:
            return logits, feats
        if return_mask:
            return logits, (masked_positions_1d if masked_positions_1d is not None else None)
        return logits
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def create_model(nb_cls, img_size, mlp_ratio=4, **kwargs):
    """Factory for the default HTR_ConvText configuration used in this repo."""
    config = dict(
        img_size=img_size,
        patch_size=(4, 64),
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=mlp_ratio,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        conv_kernel_size=7,
        down_after=3,
        up_after=7,
        ds_kernel=3,
        max_seq_len=128,
        upsample_mode='nearest',
    )
    config.update(kwargs)
    return HTR_ConvText(nb_cls, **config)
|
model/layer.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from typing import Optional, Union, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ConvLayer2d(nn.Module):
    """Conv2d followed by an optional normalization layer and activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        use_norm: bool = True,
        use_act: bool = True,
        norm_layer: Optional[nn.Module] = None,
        act_layer: Optional[nn.Module] = None,
    ):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        modules = [conv]
        if use_norm:
            # default norm: BatchNorm2d over the conv's output channels
            modules.append(norm_layer if norm_layer is not None
                           else nn.BatchNorm2d(out_channels))
        if use_act:
            modules.append(act_layer if act_layer is not None
                           else nn.ReLU(inplace=True))

        self.block = nn.Sequential(*modules)

    def forward(self, x):
        return self.block(x)
|
| 49 |
+
|
| 50 |
+
# PEG from https://arxiv.org/abs/2102.10882
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PosCNN(nn.Module):
    """
    Conditional positional encoding generator (PEG) from
    "Conditional Positional Encodings for Vision Transformers"
    (https://arxiv.org/abs/2102.10882).

    Fixes vs. the original:
      * ``embed_dim=None`` previously crashed inside nn.Conv2d; it now defaults
        to ``in_chans`` (the standard PEG configuration) — backward compatible.
      * ``no_weight_decay`` listed indices 0..3 although ``self.proj`` holds a
        single conv; it now reflects the actual layers.
    """

    def __init__(self, in_chans, embed_dim=None, s=1):
        super(PosCNN, self).__init__()
        if embed_dim is None:
            embed_dim = in_chans
        # depthwise 3x3 conv generates position information from content
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 3, s, 1,
                      bias=True, groups=embed_dim),
        )
        self.s = s

    def forward(self, x, H, W):
        """x: (B, N, C) token sequence with N == H * W; returns same shape family."""
        B, N, C = x.shape

        feat_token = x
        # tokens -> feature map so the conv can mix spatial neighbors
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        if self.s == 1:
            # stride 1: residual connection preserves the input content
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x

    def no_weight_decay(self):
        # parameter names to exclude from weight decay (one per proj layer)
        return ["proj.%d.weight" % i for i in range(len(self.proj))]
|
model/resnet18.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
#
|
| 5 |
+
# For licensing see accompanying LICENSE file.
|
| 6 |
+
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
from typing import Dict, Optional, Sequence, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
from .layer import ConvLayer2d, PosCNN
|
| 17 |
+
from timm.models.vision_transformer import Mlp, DropPath
|
| 18 |
+
|
| 19 |
+
from typing import Any
|
| 20 |
+
class BaseModule(nn.Module):
    """Common ancestor for the custom modules in this file.

    Provides a minimal ``__repr__`` and forces subclasses to implement
    :meth:`forward`.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def forward(self, x: Any, *args, **kwargs) -> Any:
        # Subclasses must supply the actual computation.
        raise NotImplementedError

    def __repr__(self) -> str:
        return self.__class__.__name__
|
| 31 |
+
|
| 32 |
+
class Attention(nn.Module):
    """Multi-head self-attention (timm/ViT style).

    Projects the input with a single fused linear layer into queries,
    keys and values, applies scaled dot-product attention per head, and
    projects the concatenated head outputs back to ``dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1/sqrt(head_dim) keeps attention logits in a stable range.
        self.scale = head_dim ** -0.5
        # Fused q/k/v projection: one matmul instead of three.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # x: (B, N, C) token sequence.
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, C): merge heads back.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 57 |
+
|
| 58 |
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling of the input.

    Multiplies the last dimension by a learned vector ``gamma``,
    initialised to a small constant so residual branches start
    near-identity. ``inplace=True`` mutates the input tensor.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Self-attention followed by an MLP, each wrapped with an optional
    LayerScale and stochastic depth (DropPath) and joined to the input
    through a residual connection.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, elementwise_affine=True)

        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # LayerScale only when init_values is given; otherwise a pass-through.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim, elementwise_affine=True)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # x: (B, N, C); both sub-layers are residual.
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
|
| 100 |
+
class MobileViTBlock(BaseModule):
    """
    `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_:
    local representations via convolutions, then global representations by
    unfolding the feature map into non-overlapping patches and running
    transformer blocks over them, then folding back to a feature map.

    Args:
        in_channels (int): :math:`C_{in}` of an expected input of size
            :math:`(N, C_{in}, H, W)`. Default: 128
        transformer_dim (int): Input dimension to the transformer unit. Default: 128
        n_transformer_blocks (int): Number of transformer blocks. Default: 2
        head_dim (int): Head dimension in the multi-head attention. Default: 64
        attn_dropout (float): Dropout in multi-head attention. Default: 0.0
        dropout (float): Dropout rate in the transformer MLP. Default: 0.0
        patch_h (int): Patch height for the unfolding operation. Default: 2
        patch_w (int): Patch width for the unfolding operation. Default: 2
        conv_ksize (int): Kernel size of the local-representation conv. Default: 3
        dilation (int): Dilation rate of the local-representation conv. Default: 1
        no_fusion (bool): If True, do NOT fuse the block input with its
            output via a 3x3 conv. Default: True
    """

    def __init__(
        self,
        in_channels = 128,
        transformer_dim = 128,
        n_transformer_blocks = 2,
        head_dim = 64,
        attn_dropout = 0.0,
        dropout = 0.0,
        patch_h = 2,
        patch_w = 2,
        conv_ksize = 3,
        dilation = 1,
        no_fusion = True,
    ) -> None:
        # Local representation: depth of 3x3 conv (spatial mixing) followed
        # by 1x1 conv projecting to the transformer dimension.
        conv_3x3_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1,
            use_norm=True,
            use_act=True,
            dilation=dilation,
            padding = 1,
        )
        conv_1x1_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
        )

        # Projects transformer output back to the input channel count.
        conv_1x1_out = ConvLayer2d(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=True,
            use_act=True,
        )
        conv_3x3_out = None
        if not no_fusion:
            # Optional fusion conv over concat(input, output) channels.
            conv_3x3_out = ConvLayer2d(
                in_channels=2 * in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1,
                padding = 1,
                use_norm=True,
                use_act=True,
            )
        super().__init__()
        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
        # Conditional positional encoding (PEG) applied after the first
        # transformer layer in forward().
        self.pos_pe = PosCNN(in_chans=transformer_dim, embed_dim=transformer_dim)
        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim
        global_rep = [
            Block(
                dim=transformer_dim,
                num_heads=num_heads,
                mlp_ratio = 4.0,
                qkv_bias = True,
                attn_drop = attn_dropout,
                drop=dropout,
                norm_layer=nn.LayerNorm,
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(nn.LayerNorm(transformer_dim))

        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out

        # None when no_fusion=True (the default here).
        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        # Bookkeeping of the configuration (not used in forward()).
        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.dilation = dilation
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize



    def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
        """Rearrange (B, C, H, W) into per-pixel patch sequences (BP, N, C).

        Resizes the map (bilinear) if H/W are not multiples of the patch
        size; the returned info_dict records what is needed by folding().
        """
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = int(patch_w * patch_h)
        batch_size, in_channels, orig_h, orig_w = feature_map.shape

        # Round spatial dims up to a multiple of the patch size.
        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            feature_map = F.interpolate(
                feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
        reshaped_fm = feature_map.reshape(
            batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
        )
        # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
        transposed_fm = reshaped_fm.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        reshaped_fm = transposed_fm.reshape(
            batch_size, in_channels, num_patches, patch_area
        )
        # [B, C, N, P] --> [B, P, N, C]
        transposed_fm = reshaped_fm.transpose(1, 3)
        # [B, P, N, C] --> [BP, N, C]
        patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return patches, info_dict

    def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
        """Inverse of unfolding(): (BP, N, C) back to (B, C, H, W).

        Resizes back to the original spatial size if unfolding()
        interpolated the input.
        """
        n_dim = patches.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            patches.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        patches = patches.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = patches.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] --> [B, C, N, P]
        patches = patches.transpose(1, 3)

        # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
        feature_map = patches.reshape(
            batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
        )
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
        feature_map = feature_map.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        feature_map = feature_map.reshape(
            batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
        )
        if info_dict["interpolate"]:
            feature_map = F.interpolate(
                feature_map,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return feature_map

    def forward(self, x: Tensor) -> Tensor:
        res = x

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]
        # learn global representations

        for j, transformer_layer in enumerate(self.global_rep):
            patches = transformer_layer(patches)
            if j == 0:
                patches = self.pos_pe(patches, num_patch_h, num_patch_w)  # PEG here
        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        fm = self.folding(patches=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        # Fusion only happens when the block was built with no_fusion=False.
        if self.fusion is not None:
            fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm
|
| 322 |
+
|
| 323 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias (ResNet building block)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class BasicBlock(nn.Module):
    """Standard two-convolution ResNet basic block with identity shortcut.

    When ``downsample`` is given it is applied to the input to match the
    main path's shape before the residual addition.
    """

    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, eps=1e-05)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class ResNet18(nn.Module):
    """ResNet-18-style backbone interleaved with MobileViT blocks.

    Expects a single-channel input (N, 1, H, W). Height is downsampled
    aggressively (stride (2, 1) in the stem, pooling and layer1) while
    width is mostly preserved early on — presumably to keep horizontal
    resolution for sequence recognition of text lines (TODO confirm with
    the downstream head).
    """

    def __init__(self, nb_feat=384):
        # Channel plan across stages: nb_feat//4 -> nb_feat//2 -> nb_feat.
        self.inplanes = nb_feat // 4
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(
            1, nb_feat // 4, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(nb_feat // 4, eps=1e-05)
        self.relu = nn.ReLU(inplace=True)
        # stride (2, 1): halve height only, keep width.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
        self.layer1 = self._make_layer(
            BasicBlock, nb_feat // 4, 2, stride=(2, 1))
        # A MobileViT block after every residual stage adds global context.
        self.mobilevit_block1 = MobileViTBlock(in_channels=nb_feat // 4, transformer_dim=nb_feat // 4, n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0, patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)
        self.layer2 = self._make_layer(BasicBlock, nb_feat // 2, 2, stride=2)
        self.mobilevit_block2 = MobileViTBlock(in_channels=nb_feat // 2, transformer_dim=nb_feat//2, n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0, patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)
        self.layer3 = self._make_layer(BasicBlock, nb_feat, 2, stride=2)
        self.mobilevit_block3 = MobileViTBlock(in_channels=nb_feat, transformer_dim=nb_feat, n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0, patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` BasicBlocks; the first may downsample the shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection shortcut when stride or channel count changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.mobilevit_block1(x)
        x = self.layer2(x)
        x = self.mobilevit_block2(x)
        x = self.layer3(x)
        x = self.mobilevit_block3(x)
        # Reuses the stem's (2, 1) max-pool for one more height halving.
        x = self.maxpool(x)

        return x
|
model/tcm_head.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def build_tcm_vocab(converter, add_tokens=("<pad>",)):
    """Build the TCM vocabulary from a label converter.

    Starts from ``converter.character`` and appends any entry of
    *add_tokens* that is not already present.

    Returns:
        (stoi, itos, pad_id): string->id dict, id->string list, and the
        id of the "<pad>" token.
    """
    stoi = {ch: idx for idx, ch in enumerate(list(converter.character))}
    for tok in add_tokens:
        stoi.setdefault(tok, len(stoi))
    itos = [''] * len(stoi)
    for ch, idx in stoi.items():
        itos[idx] = ch
    return stoi, itos, stoi["<pad>"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def texts_to_ids(texts, stoi):
    """Encode each string in *texts* as a 1-D LongTensor of vocab ids."""
    encoded = []
    for text in texts:
        ids = [stoi[ch] for ch in text]
        encoded.append(torch.tensor(ids, dtype=torch.long))
    return encoded
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_context_batch(texts, stoi, sub_str_len=5, device='cuda'):
    """Build fixed-width left/right character contexts for every position.

    For position ``i`` of each text, gathers up to ``sub_str_len`` ids to
    the left (right-aligned, left-padded with <pad>) and to the right
    (left-aligned, right-padded with <pad>).

    Returns:
        left:  (B, Lmax, S) long tensor of left-context ids
        right: (B, Lmax, S) long tensor of right-context ids
        tgt:   (B, Lmax) long tensor of target ids, <pad>-filled
        mask:  (B, Lmax) float tensor with 1.0 on real positions
    """
    pad = stoi["<pad>"]
    span = sub_str_len
    seqs = [torch.tensor([stoi[ch] for ch in text], dtype=torch.long, device=device)
            for text in texts]
    batch = len(seqs)
    max_len = max(s.size(0) for s in seqs)

    left = torch.full((batch, max_len, span), pad, dtype=torch.long, device=device)
    right = torch.full((batch, max_len, span), pad, dtype=torch.long, device=device)
    tgt = torch.full((batch, max_len), pad, dtype=torch.long, device=device)
    mask = torch.zeros((batch, max_len), dtype=torch.float32, device=device)

    for b, seq in enumerate(seqs):
        n = seq.size(0)
        tgt[b, :n] = seq
        mask[b, :n] = 1.0
        for i in range(n):
            l_ctx = seq[max(0, i - span):i]
            # Left context is right-aligned; missing slots padded on the left.
            if l_ctx.numel() < span:
                fill = torch.full((span - l_ctx.numel(),), pad, device=device)
                l_ctx = torch.cat([fill, l_ctx], dim=0)
            left[b, i] = l_ctx[-span:]

            r_ctx = seq[i + 1:min(n, i + 1 + span)]
            # Right context is left-aligned; missing slots padded on the right.
            if r_ctx.numel() < span:
                fill = torch.full((span - r_ctx.numel(),), pad, device=device)
                r_ctx = torch.cat([r_ctx, fill], dim=0)
            right[b, i] = r_ctx[:span]

    return left, right, tgt, mask
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class TCMHead(nn.Module):
    """Textual Context Module head.

    Given visual tokens and the left/right character contexts of each
    target position, builds direction-aware text queries, cross-attends
    them to the visual features, predicts the character at each position
    from both directions, and returns a masked bidirectional
    cross-entropy loss plus the per-direction logits.

    Args:
        d_vis: channel dimension of the visual tokens.
        vocab_size_tcm: vocabulary size (including the <pad> token).
        pad_id: id of the <pad> token.
        d_txt: internal text embedding dimension.
        sub_str_len: context span length S used by the caller.
        p_drop: dropout probability on the cross-attention output.
    """

    def __init__(self, d_vis, vocab_size_tcm, pad_id, d_txt=256, sub_str_len=5, p_drop=0.1):
        super().__init__()
        self.vocab_size = vocab_size_tcm
        self.sub_str_len = sub_str_len

        # critical: padding_idx zeroes the PAD row and keeps it frozen
        self.emb = nn.Embedding(vocab_size_tcm, d_txt, padding_idx=pad_id)

        # keep direction as learned vectors (not tokens)
        self.dir_left = nn.Parameter(torch.randn(1, 1, d_txt))
        self.dir_right = nn.Parameter(torch.randn(1, 1, d_txt))

        self.ctx_conv = nn.Conv1d(d_txt, d_txt, kernel_size=3, padding=1)
        self.txt_proj = nn.Linear(d_txt, d_vis)
        self.q_norm = nn.LayerNorm(d_vis)
        self.kv_norm = nn.LayerNorm(d_vis)
        self.dropout = nn.Dropout(p_drop)
        self.classifier = nn.Linear(d_vis, vocab_size_tcm)

    def _context_to_query(self, ctx_ids, dir_token):
        """Collapse (B, L, S) context ids into (B, L, d_vis) query vectors."""
        E = self.emb(ctx_ids)                     # (B, L, S, d_txt)
        B, L, S, D = E.shape
        x = E.view(B * L, S, D).transpose(1, 2)   # (B*L, d_txt, S) for Conv1d
        x = self.ctx_conv(x)
        x = x.mean(dim=-1)                        # average-pool over the span
        x = x.view(B, L, D)

        x = x + dir_token                         # tag query with its direction
        x = self.txt_proj(x)
        return self.q_norm(x)

    def _cross_attend(self, queries, feats):
        """Single-head scaled dot-product cross-attention.

        BUGFIX: the second parameter was named ``F``, shadowing the
        module-level ``torch.nn.functional as F`` import inside this
        method; renamed to ``feats``.
        """
        keys = self.kv_norm(feats)
        values = keys
        attn = torch.einsum('bld,bnd->bln', queries, keys) / \
            (keys.size(-1) ** 0.5)
        A = attn.softmax(dim=-1)
        out = torch.einsum('bln,bnd->bld', A, values)
        return self.dropout(out)

    def forward(self,
                vis_tokens,
                left_ctx_ids,
                right_ctx_ids,
                tgt_ids,
                tgt_mask,
                focus_mask=None):
        """Compute the bidirectional TCM loss.

        Args:
            vis_tokens: (B, N, d_vis) visual feature tokens.
            left_ctx_ids / right_ctx_ids: (B, L, S) context ids.
            tgt_ids: (B, L) target character ids (<pad>-filled).
            tgt_mask: (B, L) float mask, 1.0 on real positions.
            focus_mask: optional (B, L) extra per-position weight.

        Returns:
            dict with 'loss_tcm' (scalar) and 'logits_l'/'logits_r'
            of shape (B, L, vocab_size).
        """
        Ql = self._context_to_query(left_ctx_ids, self.dir_left)
        Qr = self._context_to_query(right_ctx_ids, self.dir_right)

        Fl = self._cross_attend(Ql, vis_tokens)
        Fr = self._cross_attend(Qr, vis_tokens)

        logits_l = self.classifier(Fl)
        logits_r = self.classifier(Fr)

        # Per-position losses; padded positions are zeroed via the mask below.
        loss_l = F.cross_entropy(
            logits_l.view(-1, self.vocab_size),
            tgt_ids.view(-1),
            reduction='none'
        ).view_as(tgt_ids)
        loss_r = F.cross_entropy(
            logits_r.view(-1, self.vocab_size),
            tgt_ids.view(-1),
            reduction='none'
        ).view_as(tgt_ids)

        if focus_mask is not None:
            weights = tgt_mask * (1.0 + focus_mask)
        else:
            weights = tgt_mask

        loss_masked = (loss_l + loss_r) * weights
        # clamp avoids division by zero on an all-padding batch
        denom = torch.clamp(weights.sum(), min=1.0)
        loss_tcm = loss_masked.sum() / (2.0 * denom)

        return {'loss_tcm': loss_tcm,
                'logits_l': logits_l,
                'logits_r': logits_r}
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.24
|
| 2 |
+
pillow>=9.0
|
| 3 |
+
opencv-python>=4.8
|
| 4 |
+
scikit-image>=0.21
|
| 5 |
+
tensorboard>=2.13
|
| 6 |
+
wandb>=0.16
|
| 7 |
+
editdistance>=0.6
|
| 8 |
+
timm>=0.9
|
run/iam.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 train.py --use-wandb --dataset iam --tcm-enable --exp-name "htr-convtext" --wandb-project iam --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/iam-vt-lines/lines/ --train-data-list /kaggle/input/iam-vt-lines/train.ln --val-data-list /kaggle/input/iam-vt-lines/val.ln --test-data-list /kaggle/input/iam-vt-lines/test.ln --nb-cls 80
|
run/lam.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 train.py --use-wandb --dataset lam --tcm-enable --exp-name "htr-convtext" --wandb-project lam --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/lam-vt-lines/lines/ --train-data-list /kaggle/input/lam-vt-lines/train.ln --val-data-list /kaggle/input/lam-vt-lines/val.ln --test-data-list /kaggle/input/lam-vt-lines/test.ln --nb-cls 91
|
run/read2016.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 train.py --use-wandb --dataset read2016 --tcm-enable --exp-name "htr-convtext" --wandb-project read2016 --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/read2016-vt-lines/lines/ --train-data-list /kaggle/input/read2016-vt-lines/train.ln --val-data-list /kaggle/input/read2016-vt-lines/val.ln --test-data-list /kaggle/input/read2016-vt-lines/test.ln --nb-cls 90
|
run/vnondb.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 train.py --use-wandb --dataset vnondb --tcm-enable --exp-name "htr-convtext" --wandb-project vnondb --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/vnondb/lines/ --train-data-list /kaggle/input/vnondb/train.ln --val-data-list /kaggle/input/vnondb/val.ln --test-data-list /kaggle/input/vnondb/test.ln --nb-cls 162
|
test.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import json
|
| 6 |
+
import valid
|
| 7 |
+
from utils import utils
|
| 8 |
+
from utils import option
|
| 9 |
+
from data import dataset
|
| 10 |
+
from model import htr_convtext
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main():
    """Evaluate a trained HTR model on the test split and dump predictions.

    Loads the EMA weights from ``args.resume``, runs validation on the
    test list, logs loss/CER/WER, and writes per-sample predictions with
    per-sample CER/WER to ``predictions.json`` in the experiment
    directory. Uses the module-level ``args`` set in ``__main__``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    args.save_dir = os.path.join(args.out_dir, args.exp_name)
    os.makedirs(args.save_dir, exist_ok=True)
    logger = utils.get_logger(args.save_dir)

    model = htr_convtext.create_model(
        nb_cls=args.nb_cls, img_size=args.img_size[::-1])

    pth_path = args.resume
    logger.info('loading HWR checkpoint from {}'.format(pth_path))

    ckpt = torch.load(pth_path, map_location='cpu', weights_only=False)
    model_dict = OrderedDict()
    pattern = re.compile('module.')

    # Strip any DataParallel/DDP "module." prefix from the EMA state dict.
    for k, v in ckpt['state_dict_ema'].items():
        if re.search("module", k):
            model_dict[re.sub(pattern, '', k)] = v
        else:
            model_dict[k] = v

    model.load_state_dict(model_dict, strict=True)
    # BUGFIX: use the computed device instead of a hard-coded .cuda() call,
    # so evaluation also works on CPU-only machines.
    model = model.to(device)

    logger.info('Loading test loader...')
    # The train dataset is loaded only to reuse its alphabet (ralph) so the
    # test set is encoded consistently with training.
    train_dataset = dataset.myLoadDS(
        args.train_data_list, args.data_path, args.img_size, dataset=args.dataset)

    test_dataset = dataset.myLoadDS(
        args.test_data_list, args.data_path, args.img_size, ralph=train_dataset.ralph, dataset=args.dataset)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.val_bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.num_workers)

    converter = utils.CTCLabelConverter(train_dataset.ralph.values())
    criterion = torch.nn.CTCLoss(
        reduction='none', zero_infinity=True).to(device)

    model.eval()
    with torch.no_grad():
        val_loss, val_cer, val_wer, preds, labels = valid.validation(
            model,
            criterion,
            test_loader,
            converter,
        )

    logger.info(
        f'Test. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} ')

    # Save predictions as JSON
    results = {
        "test_metrics": {
            "loss": float(val_loss),
            "cer": float(val_cer),
            "wer": float(val_wer)
        },
        "predictions": []
    }

    def _levenshtein(pred_tokens, gt_tokens):
        """Edit distance between two token sequences (O(len) memory)."""
        if pred_tokens == gt_tokens:
            return 0
        lp, lg = len(pred_tokens), len(gt_tokens)
        if lp == 0:
            return lg
        if lg == 0:
            return lp
        prev = list(range(lg + 1))
        for i in range(1, lp + 1):
            cur = [i]
            pi = pred_tokens[i - 1]
            for j in range(1, lg + 1):
                gj = gt_tokens[j - 1]
                cost = 0 if pi == gj else 1
                cur.append(
                    min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost))
            prev = cur
        return prev[-1]

    def _levenshtein_str(a: str, b: str):
        return _levenshtein(list(a), list(b))

    def _cer(pred: str, gt: str):
        """Character error rate; empty ground truth scores 0.0 or 1.0."""
        if len(gt) == 0:
            return 0.0 if len(pred) == 0 else 1.0
        return _levenshtein_str(pred, gt) / len(gt)

    def _wer(pred: str, gt: str):
        """Word error rate over whitespace-split tokens."""
        gt_words = gt.split()
        pred_words = pred.split()
        if len(gt_words) == 0:
            return 0.0 if len(pred_words) == 0 else 1.0
        return _levenshtein(pred_words, gt_words) / len(gt_words)

    for i, (pred, label) in enumerate(zip(preds, labels)):
        # Map prediction index back to the image file when available.
        if i < len(test_dataset.fns):
            img_path = test_dataset.fns[i]
            img_name = os.path.basename(img_path)
        else:
            img_path = None
            img_name = None
        results["predictions"].append({
            "sample_id": i + 1,
            "image_filename": img_name,
            "image_path": img_path,
            "prediction": pred,
            "ground_truth": label,
            "match": pred == label,
            "cer": round(float(_cer(pred, label)), 6),
            "wer": round(float(_wer(pred, label)), 6)
        })

    pred_file = os.path.join(args.save_dir, 'predictions.json')
    with open(pred_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == '__main__':
|
| 139 |
+
args = option.get_args_parser()
|
| 140 |
+
main()
|
train.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils.data
|
| 3 |
+
import torch.backends.cudnn as cudnn
|
| 4 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import valid
|
| 9 |
+
from utils import utils
|
| 10 |
+
from utils import sam
|
| 11 |
+
from utils import option
|
| 12 |
+
from data import dataset
|
| 13 |
+
from model import htr_convtext
|
| 14 |
+
from functools import partial
|
| 15 |
+
import random
|
| 16 |
+
import numpy as np
|
| 17 |
+
import re
|
| 18 |
+
import importlib
|
| 19 |
+
from model.tcm_head import TCMHead, build_tcm_vocab, make_context_batch
|
| 20 |
+
import wandb
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def compute_losses(
    args,
    model,
    tcm_head,
    image,
    texts,
    batch_size,
    criterion_ctc,
    converter,
    nb_iter,
    ctc_lambda,
    tcm_lambda,
    stoi,
    mask_mode='span',
    mask_ratio=0.30,
    block_span=4,
    max_span_length=8,
    pre_tcm_ctx=None,
    use_masking=True,
):
    """Compute the combined CTC + TCM loss for one training batch.

    Returns:
        (total, loss_ctc, loss_tcm): `total` keeps the autograd graph for
        backprop; the two component losses are detached for logging.
    """
    # During TCM warm-up (or when no TCM head exists at all) run the model
    # in plain mode: logits only, no features — the TCM branch is skipped.
    # NOTE(review): this branch does not forward `block_span` to the model
    # while the branch below does — confirm that is intentional.
    if tcm_head is None or nb_iter < args.tcm_warmup_iters:
        preds = model(image, use_masking=use_masking, mask_mode=mask_mode,
                      mask_ratio=mask_ratio, max_span_length=max_span_length)
        feats = None
    else:
        # Also request intermediate features and the visual mask so the
        # TCM head can attend to the masked positions.
        preds, feats, vis_mask = model(
            image,
            use_masking=use_masking,
            return_features=True,
            return_mask=True,
            mask_mode=mask_mode,
            mask_ratio=mask_ratio,
            block_span=block_span,
            max_span_length=max_span_length
        )
    # Encode labels for CTC and move them to the logits' device.
    text_ctc, length_ctc = converter.encode(texts)
    text_ctc = text_ctc.to(preds.device)
    length_ctc = length_ctc.to(preds.device)
    # Every sample uses the full output length preds.size(1); the permute
    # to (T, B, C) below implies preds is (B, T, C).
    preds_sz = torch.full((batch_size,), preds.size(
        1), dtype=torch.int32, device=preds.device)
    loss_ctc = criterion_ctc(preds.permute(1, 0, 2).log_softmax(2),
                             text_ctc, preds_sz, length_ctc).mean()

    # TCM loss defaults to a scalar zero so the weighted sum below is
    # always well-formed.
    loss_tcm = torch.zeros((), device=preds.device)
    if tcm_head is not None and feats is not None:
        # Reuse a precomputed context batch when the caller provides one
        # (e.g. when several masking plans share the same labels).
        left_ctx, right_ctx, tgt_ids, tgt_mask = pre_tcm_ctx if pre_tcm_ctx is not None else make_context_batch(
            texts, stoi, sub_str_len=args.tcm_sub_len, device=image.device)
        if vis_mask is not None:
            B_v, N_v = vis_mask.shape
            B_t, L_t = tgt_mask.shape
            if N_v != L_t:
                # Length mismatch: resample the visual mask to the target
                # length via evenly spaced nearest-index selection.
                idx = torch.linspace(0, N_v - 1, steps=L_t,
                                     device=vis_mask.device).long()
                focus_mask = vis_mask[:, idx]
            else:
                focus_mask = vis_mask
        else:
            focus_mask = None

        out = tcm_head(
            feats,
            left_ctx, right_ctx,
            tgt_ids, tgt_mask,
            focus_mask=focus_mask
        )
        loss_tcm = out['loss_tcm']

    # Weighted sum of the two objectives; components are detached so
    # callers can accumulate/log them without holding the graph.
    total = ctc_lambda * loss_ctc + tcm_lambda * loss_tcm
    return total, loss_ctc.detach(), loss_tcm.detach()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def tri_masked_loss(args, model, tcm_head, image, labels, batch_size,
                    criterion, converter, nb_iter, ctc_lambda, tcm_lambda, stoi,
                    r_rand=0.6, r_block=0.6, block_span=4, r_span=0.4, max_span=8):
    """Average the CTC/TCM loss over three masking strategies.

    Runs compute_losses() once each with random, block, and span masking
    (ratios r_rand / r_block / r_span respectively) and returns the mean
    of the three losses.

    Returns:
        (total, total_ctc, total_tcm): each averaged over the 3 plans;
        `total` keeps the graph, the components are already detached by
        compute_losses().
    """
    total = 0.0
    total_ctc = 0.0
    total_tcm = 0.0
    plans = [("random", r_rand), ("block", r_block), ("span", r_span)]

    # Build the textual-context batch once and share it across all three
    # masking plans.
    # Fix: initialize to None so the name is always bound — the original
    # only assigned it inside the `if`, raising UnboundLocalError whenever
    # the TCM head is absent or still in warm-up.
    pre_tcm_ctx = None
    if tcm_head is not None and nb_iter >= args.tcm_warmup_iters:
        pre_tcm_ctx = make_context_batch(
            labels, stoi, sub_str_len=args.tcm_sub_len, device=image.device)

    for mode, ratio in plans:
        loss, loss_ctc, loss_tcm = compute_losses(
            args, model, tcm_head, image, labels, batch_size, criterion, converter,
            nb_iter, ctc_lambda, tcm_lambda, stoi,
            mask_mode=mode, mask_ratio=ratio, block_span=block_span, max_span_length=max_span,
            pre_tcm_ctx=pre_tcm_ctx
        )
        total += loss
        total_ctc += loss_ctc
        total_tcm += loss_tcm

    denom = 3.0
    return total / denom, total_ctc / denom, total_tcm / denom
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
    """Training entry point for HTR-ConvText.

    Builds the dataloaders, model (+EMA copy), optional TCM head, and SAM
    optimizer, optionally resumes from a checkpoint, then runs the SAM
    two-pass training loop with periodic logging, validation, and
    checkpointing.

    Fixes vs. the original:
      * `nb_iter % args.eval_iter*5 == 0` parsed as
        `(nb_iter % args.eval_iter) * 5 == 0` (`%` and `*` share
        precedence, left-associative), so the "every 5th evaluation"
        rolling checkpoint was written on EVERY evaluation. Corrected to
        `nb_iter % (args.eval_iter * 5) == 0`.
      * Removed no-op self-assignments (`best_cer, best_wer = best_cer,
        best_wer`, etc.).
      * The three identical checkpoint-dict literals are deduplicated
        into the local `_make_checkpoint` helper.
    """
    args = option.get_args_parser()
    torch.manual_seed(args.seed)

    args.save_dir = os.path.join(args.out_dir, args.exp_name)
    os.makedirs(args.save_dir, exist_ok=True)

    logger = utils.get_logger(args.save_dir)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
    writer = SummaryWriter(args.save_dir)

    # Optional Weights & Biases logging; any failure falls back to
    # TensorBoard-only (wandb stays None and is checked before each log).
    if getattr(args, 'use_wandb', False):
        try:
            wandb = importlib.import_module('wandb')
            wandb.init(project=getattr(args, 'wandb_project', 'None'), name=args.exp_name,
                       config=vars(args), dir=args.save_dir)
            logger.info("Weights & Biases logging enabled")
        except Exception as e:
            logger.warning(
                f"Failed to initialize wandb: {e}. Continuing without wandb.")
            wandb = None
    else:
        wandb = None

    torch.backends.cudnn.benchmark = True

    # Model; img_size is reversed to [H, W] for the backbone.
    model = htr_convtext.create_model(
        nb_cls=args.nb_cls, img_size=args.img_size[::-1])

    total_param = sum(p.numel() for p in model.parameters())
    logger.info('total_param is {}'.format(total_param))

    model.train()
    model = model.cuda()
    ema_decay = args.ema_decay
    logger.info(f"Using EMA decay: {ema_decay}")
    model_ema = utils.ModelEma(model, ema_decay)
    model.zero_grad()

    # Resume: best metrics, start iteration, optimizer state, loss stats.
    resume_path = args.resume
    best_cer, best_wer, start_iter, optimizer_state, train_loss, train_loss_count = utils.load_checkpoint(
        model, model_ema, None, resume_path, logger)

    logger.info('Loading train loader...')
    train_dataset = dataset.myLoadDS(
        args.train_data_list, args.data_path, args.img_size, dataset=args.dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_bs,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.num_workers,
                                               collate_fn=partial(dataset.SameTrCollate, args=args))
    train_iter = dataset.cycle_data(train_loader)

    logger.info('Loading val loader...')
    val_dataset = dataset.myLoadDS(
        args.val_data_list, args.data_path, args.img_size, ralph=train_dataset.ralph, dataset=args.dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.val_bs,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.num_workers)

    # Per-sample CTC loss (reduction handled by the callers' .mean()).
    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True)
    converter = utils.CTCLabelConverter(train_dataset.ralph.values())

    # Optional Textual Context Module head.
    stoi, itos, pad_id = build_tcm_vocab(converter)
    vocab_size_tcm = len(itos)
    d_vis = model.embed_dim

    if args.tcm_enable:
        tcm_head = TCMHead(d_vis=d_vis, vocab_size_tcm=vocab_size_tcm, pad_id=pad_id,
                           sub_str_len=args.tcm_sub_len).cuda()
        tcm_head.train()
    else:
        tcm_head = None

    # One SAM optimizer over model (+ TCM head) parameters; lr is a
    # placeholder overwritten each iteration by update_lr_cos.
    param_groups = list(model.parameters())
    if args.tcm_enable and tcm_head is not None:
        param_groups += list(tcm_head.parameters())
        logger.info(
            f"Optimizing {sum(p.numel() for p in tcm_head.parameters())} tcm params in addition to model params")
    optimizer = sam.SAM(param_groups, torch.optim.AdamW,
                        lr=1e-7, betas=(0.9, 0.99), weight_decay=args.weight_decay)

    # Restore optimizer state: prefer what load_checkpoint extracted,
    # otherwise try reading it from the checkpoint file directly.
    if optimizer_state is not None:
        try:
            optimizer.load_state_dict(optimizer_state)
            logger.info("Successfully loaded optimizer state")
        except Exception as e:
            logger.warning(f"Failed to load optimizer state: {e}")
            logger.info(
                "Continuing training without optimizer state (will restart from initial lr/momentum)")
    elif resume_path and os.path.isfile(resume_path):
        try:
            ckpt = torch.load(resume_path, map_location='cpu',
                              weights_only=False)
            if 'optimizer' in ckpt:
                optimizer.load_state_dict(ckpt['optimizer'])
                logger.info("Loaded optimizer state from checkpoint directly")
        except Exception as e:
            logger.warning(
                f"Could not load optimizer state from checkpoint: {e}")

    # Restore TCM head weights from the checkpoint if present.
    if resume_path and os.path.isfile(resume_path) and tcm_head is not None:
        try:
            ckpt = torch.load(resume_path, map_location='cpu',
                              weights_only=False)
            if 'tcm_head' in ckpt:
                tcm_head.load_state_dict(ckpt['tcm_head'], strict=False)
                logger.info("Restored tcm head state from checkpoint")
            else:
                logger.info(
                    "No tcm head state found in checkpoint; training tcm from scratch")
        except Exception as e:
            logger.warning(f"Failed to restore tcm head from checkpoint: {e}")

    def _make_checkpoint(nb_iter):
        """Snapshot everything needed to resume training.

        Reads the enclosing scope's current best metrics, loss stats and
        RNG states at call time (same values the original inline dict
        literals captured at the same program points).
        """
        ckpt = {
            'model': model.state_dict(),
            'state_dict_ema': model_ema.ema.state_dict(),
            'optimizer': optimizer.state_dict(),
            'nb_iter': nb_iter,
            'best_cer': best_cer,
            'best_wer': best_wer,
            'args': vars(args),
            'random_state': random.getstate(),
            'numpy_state': np.random.get_state(),
            'torch_state': torch.get_rng_state(),
            'torch_cuda_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
            'train_loss': train_loss,
            'train_loss_count': train_loss_count,
        }
        if tcm_head is not None:
            ckpt['tcm_head'] = tcm_head.state_dict()
        return ckpt

    #### ---- train & eval ---- ####
    logger.info('Start training...')
    accum_steps = max(1, int(getattr(args, 'accum_steps', 1)))
    avg_loss_ctc = 0.0
    avg_loss_tcm = 0.0

    for nb_iter in range(start_iter, args.total_iter):
        optimizer, current_lr = utils.update_lr_cos(
            nb_iter, args.warm_up_iter, args.total_iter, args.max_lr, optimizer)
        optimizer.zero_grad()
        total_loss_this_macro = 0.0
        avg_loss_ctc = 0.0
        avg_loss_tcm = 0.0
        cached_batches = []

        # SAM first pass: accumulate gradients at the current weights.
        for micro_step in range(accum_steps):
            batch = next(train_iter)
            cached_batches.append(batch)
            image = batch[0].cuda(non_blocking=True)
            text, length = converter.encode(batch[1])
            batch_size = image.size(0)
            if args.use_masking:
                loss, loss_ctc, loss_tcm = compute_losses(
                    args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
                    nb_iter, args.ctc_lambda, args.tcm_lambda, stoi,
                    mask_mode='span', mask_ratio=0.4, max_span_length=8, use_masking=True
                )
            else:
                loss, loss_ctc, loss_tcm = compute_losses(
                    args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
                    nb_iter, args.ctc_lambda, args.tcm_lambda, stoi, use_masking=False
                )
            (loss / accum_steps).backward()
            total_loss_this_macro += loss.item()
            avg_loss_ctc += loss_ctc.mean().item()
            avg_loss_tcm += loss_tcm.mean().item()

        optimizer.first_step(zero_grad=True)

        # SAM second pass: recompute the loss on the SAME cached batches at
        # the perturbed weights, then take the actual optimizer step.
        for micro_step in range(accum_steps):
            batch = cached_batches[micro_step]
            image = batch[0].cuda(non_blocking=True)
            text, length = converter.encode(batch[1])
            batch_size = image.size(0)
            if args.use_masking:
                loss2, loss_ctc, loss_tcm = compute_losses(
                    args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
                    nb_iter, args.ctc_lambda, args.tcm_lambda, stoi,
                    mask_mode='span', mask_ratio=0.4, max_span_length=8, use_masking=True
                )
            else:
                loss2, loss_ctc, loss_tcm = compute_losses(
                    args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
                    nb_iter, args.ctc_lambda, args.tcm_lambda, stoi, use_masking=False
                )
            (loss2 / accum_steps).backward()
        optimizer.second_step(zero_grad=True)
        model.zero_grad()
        model_ema.update(model, num_updates=nb_iter / 2)

        train_loss += total_loss_this_macro / accum_steps
        train_loss_count += 1

        if nb_iter % args.print_iter == 0:
            train_loss_avg = train_loss / train_loss_count if train_loss_count > 0 else 0.0

            logger.info(
                f'Iter : {nb_iter} \t LR : {current_lr:0.5f} \t total : {train_loss_avg:0.5f} \t CTC : {(avg_loss_ctc/accum_steps):0.5f} \t tcm : {(avg_loss_tcm/accum_steps):0.5f} \t ')

            writer.add_scalar('./Train/lr', current_lr, nb_iter)
            writer.add_scalar('./Train/train_loss', train_loss_avg, nb_iter)
            if wandb is not None:
                wandb.log({
                    'train/lr': current_lr,
                    'train/loss': train_loss_avg,
                    'train/CTC': (avg_loss_ctc/accum_steps),
                    'train/tcm': (avg_loss_tcm/accum_steps),
                    'iter': nb_iter,
                }, step=nb_iter)
            train_loss = 0.0
            train_loss_count = 0

        if nb_iter % args.eval_iter == 0:
            # Validate on the EMA weights.
            model.eval()
            with torch.no_grad():
                val_loss, val_cer, val_wer, preds, labels = valid.validation(model_ema.ema,
                                                                             criterion,
                                                                             val_loader,
                                                                             converter)
            # Rolling checkpoint every 5th evaluation (precedence fixed —
            # see docstring).
            if nb_iter % (args.eval_iter * 5) == 0:
                ckpt_name = f"checkpoint_{best_cer:.4f}_{best_wer:.4f}_{nb_iter}.pth"
                torch.save(_make_checkpoint(nb_iter), os.path.join(
                    args.save_dir, ckpt_name))

            # Best-CER checkpoint (best_cer updated BEFORE snapshotting, as
            # in the original).
            if val_cer < best_cer:
                logger.info(
                    f'CER improved from {best_cer:.4f} to {val_cer:.4f}!!!')
                best_cer = val_cer
                torch.save(_make_checkpoint(nb_iter), os.path.join(
                    args.save_dir, 'best_CER.pth'))

            # Best-WER checkpoint.
            if val_wer < best_wer:
                logger.info(
                    f'WER improved from {best_wer:.4f} to {val_wer:.4f}!!!')
                best_wer = val_wer
                torch.save(_make_checkpoint(nb_iter), os.path.join(
                    args.save_dir, 'best_WER.pth'))

            logger.info(
                f'Val. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} \t ')

            writer.add_scalar('./VAL/CER', val_cer, nb_iter)
            writer.add_scalar('./VAL/WER', val_wer, nb_iter)
            writer.add_scalar('./VAL/bestCER', best_cer, nb_iter)
            writer.add_scalar('./VAL/bestWER', best_wer, nb_iter)
            writer.add_scalar('./VAL/val_loss', val_loss, nb_iter)
            if wandb is not None:
                wandb.log({
                    'val/loss': val_loss,
                    'val/CER': val_cer,
                    'val/WER': val_wer,
                    'val/best_CER': best_cer,
                    'val/best_WER': best_wer,
                    'iter': nb_iter,
                }, step=nb_iter)
            model.train()
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# Standard script entry point; all setup (arg parsing, logging, loaders)
# happens inside main().
if __name__ == '__main__':
    main()
|
utils/option.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_args_parser() -> argparse.Namespace:
|
| 5 |
+
"""Create and parse command-line options for HTR-ConvText.
|
| 6 |
+
|
| 7 |
+
This keeps all option names and defaults intact, but organizes them into
|
| 8 |
+
logical groups with clearer help messages.
|
| 9 |
+
"""
|
| 10 |
+
parser = argparse.ArgumentParser(
|
| 11 |
+
description='HTR-ConvText: Leveraging Convolution and Textual Context with Mixed Masking for Handwritten Text Recognition',
|
| 12 |
+
add_help=True,
|
| 13 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------
|
| 17 |
+
# Experiment & Logging
|
| 18 |
+
# ---------------------------------------------------------------------
|
| 19 |
+
exp = parser.add_argument_group('Experiment & Logging')
|
| 20 |
+
exp.add_argument('--out-dir', type=str, default='./output',
|
| 21 |
+
help='Root directory to save logs, checkpoints, and outputs')
|
| 22 |
+
exp.add_argument('--exp-name', type=str, default='IAM_HTR_ORIGAMI_NET',
|
| 23 |
+
help='Experiment name; results go to <out-dir>/<exp-name>')
|
| 24 |
+
exp.add_argument('--seed', default=123, type=int,
|
| 25 |
+
help='Random seed for reproducibility')
|
| 26 |
+
exp.add_argument('--use-wandb', action='store_true', default=False,
|
| 27 |
+
help='Log to Weights & Biases; otherwise use TensorBoard')
|
| 28 |
+
exp.add_argument('--wandb-project', type=str, default='None',
|
| 29 |
+
help='W&B project name (used only if --use-wandb)')
|
| 30 |
+
exp.add_argument('--print-iter', default=100, type=int,
|
| 31 |
+
help='Iterations between training status prints')
|
| 32 |
+
exp.add_argument('--eval-iter', default=1000, type=int,
|
| 33 |
+
help='Iterations between validation runs')
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------
|
| 36 |
+
# Data & Dataloading
|
| 37 |
+
# ---------------------------------------------------------------------
|
| 38 |
+
data = parser.add_argument_group('Data & Dataloading')
|
| 39 |
+
data.add_argument('--dataset', type=str, choices=['iam', 'read2016', 'lam', 'vnondb'],
|
| 40 |
+
help='Dataset choice')
|
| 41 |
+
data.add_argument('--data-path', type=str, default='./data/iam/lines/',
|
| 42 |
+
help='Root directory containing image/line data')
|
| 43 |
+
data.add_argument('--train-data-list', type=str, default='./data/iam/train.ln',
|
| 44 |
+
help='Path to training list file (e.g., .ln)')
|
| 45 |
+
data.add_argument('--val-data-list', type=str, default='./data/iam/val.ln',
|
| 46 |
+
help='Path to validation list file (e.g., .ln)')
|
| 47 |
+
data.add_argument('--test-data-list', type=str, default='./data/iam/test.ln',
|
| 48 |
+
help='Path to test list file (e.g., .ln)')
|
| 49 |
+
data.add_argument('--nb-cls', default=80, type=int,
|
| 50 |
+
help='Number of classes. IAM=79+1, READ2016=89+1, LAM=90+1, VNOnDB=161+1')
|
| 51 |
+
data.add_argument('--num-workers', default=0, type=int,
|
| 52 |
+
help='Dataloader worker processes')
|
| 53 |
+
data.add_argument('--img-size', default=[512, 64], type=int, nargs='+',
|
| 54 |
+
help='Input image size [W, H]')
|
| 55 |
+
data.add_argument('--patch-size', default=[4, 32], type=int, nargs='+',
|
| 56 |
+
help='Patch size [W, H] for patch embedding')
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------
|
| 59 |
+
# Training Schedule & Optimization
|
| 60 |
+
# ---------------------------------------------------------------------
|
| 61 |
+
train = parser.add_argument_group('Training Schedule & Optimization')
|
| 62 |
+
train.add_argument('--train-bs', default=8, type=int,
|
| 63 |
+
help='Training batch size per iteration')
|
| 64 |
+
train.add_argument('--accum-steps', default=1, type=int,
|
| 65 |
+
help='Gradient accumulation steps; effective batch = train-bs * accum-steps')
|
| 66 |
+
train.add_argument('--val-bs', default=1, type=int,
|
| 67 |
+
help='Validation/test batch size')
|
| 68 |
+
train.add_argument('--total-iter', default=100000, type=int,
|
| 69 |
+
help='Total training iterations')
|
| 70 |
+
train.add_argument('--warm-up-iter', default=1000, type=int,
|
| 71 |
+
help='Warm-up iterations for the optimizer/scheduler')
|
| 72 |
+
train.add_argument('--max-lr', default=1e-3, type=float,
|
| 73 |
+
help='Peak learning rate')
|
| 74 |
+
train.add_argument('--weight-decay', default=5e-1, type=float,
|
| 75 |
+
help='Weight decay (L2) regularization')
|
| 76 |
+
train.add_argument('--ema-decay', default=0.9999, type=float,
|
| 77 |
+
help='Exponential Moving Average (EMA) decay factor for model weights')
|
| 78 |
+
train.add_argument('--alpha', default=0, type=float,
|
| 79 |
+
help='KL-divergence loss ratio (if applicable)')
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------
|
| 82 |
+
# Model & Encoder
|
| 83 |
+
# ---------------------------------------------------------------------
|
| 84 |
+
model = parser.add_argument_group('Model & Encoder')
|
| 85 |
+
model.add_argument('--model-type', default='ctc', type=str, choices=['ctc', 'encoder_decoder'],
|
| 86 |
+
help='Model family to train/use')
|
| 87 |
+
model.add_argument('--cos-temp', default=8, type=int,
|
| 88 |
+
help='Cosine-similarity classifier temperature')
|
| 89 |
+
model.add_argument('--proj', default=8, type=float,
|
| 90 |
+
help='Projection dimension or scaling for classifier head')
|
| 91 |
+
model.add_argument('--attn-mask-ratio', default=0., type=float,
|
| 92 |
+
help='Attention drop-key mask ratio')
|
| 93 |
+
|
| 94 |
+
# ---------------------------------------------------------------------
|
| 95 |
+
# Masking Strategy
|
| 96 |
+
# ---------------------------------------------------------------------
|
| 97 |
+
mask = parser.add_argument_group('Masking Strategy')
|
| 98 |
+
mask.add_argument('--use-masking', action='store_true', default=False,
|
| 99 |
+
help='Enable masking strategy during training')
|
| 100 |
+
mask.add_argument('--mask-ratio', default=0.3, type=float,
|
| 101 |
+
help='Overall proportion of tokens/patches to mask')
|
| 102 |
+
mask.add_argument('--max-span-length', default=4, type=int,
|
| 103 |
+
help='Max length for individual span masks')
|
| 104 |
+
mask.add_argument('--spacing', default=0, type=int,
|
| 105 |
+
help='Minimum spacing between two span masks')
|
| 106 |
+
# Tri-masking schedule ratios
|
| 107 |
+
mask.add_argument('--r-rand', dest='r_rand', default=0.6, type=float,
|
| 108 |
+
help='Ratio for random masking in tri-masking schedule')
|
| 109 |
+
mask.add_argument('--r-block', dest='r_block', default=0.6, type=float,
|
| 110 |
+
help='Ratio for block masking in tri-masking schedule')
|
| 111 |
+
mask.add_argument('--block-span', dest='block_span', default=4, type=int,
|
| 112 |
+
help='Block span length for block masking')
|
| 113 |
+
mask.add_argument('--r-span', dest='r_span', default=0.4, type=float,
|
| 114 |
+
help='Ratio for span masking in tri-masking schedule')
|
| 115 |
+
mask.add_argument('--max-span', dest='max_span', default=8, type=int,
|
| 116 |
+
help='Max span length for span masking')
|
| 117 |
+
|
| 118 |
+
# ---------------------------------------------------------------------
|
| 119 |
+
# Data Augmentations
|
| 120 |
+
# ---------------------------------------------------------------------
|
| 121 |
+
aug = parser.add_argument_group('Data Augmentations')
|
| 122 |
+
aug.add_argument('--dpi-min-factor', default=0.5, type=float,
|
| 123 |
+
help='Minimum scaling factor for DPI-based resize')
|
| 124 |
+
aug.add_argument('--dpi-max-factor', default=1.5, type=float,
|
| 125 |
+
help='Maximum scaling factor for DPI-based resize')
|
| 126 |
+
aug.add_argument('--perspective-low', default=0., type=float,
|
| 127 |
+
help='Lower bound for perspective transform magnitude')
|
| 128 |
+
aug.add_argument('--perspective-high', default=0.4, type=float,
|
| 129 |
+
help='Upper bound for perspective transform magnitude')
|
| 130 |
+
aug.add_argument('--elastic-distortion-min-kernel-size', default=3, type=int,
|
| 131 |
+
help='Minimum kernel size for elastic distortion grid')
|
| 132 |
+
aug.add_argument('--elastic-distortion-max-kernel-size', default=3, type=int,
|
| 133 |
+
help='Maximum kernel size for elastic distortion grid')
|
| 134 |
+
aug.add_argument('--elastic_distortion-max-magnitude', default=20, type=int,
|
| 135 |
+
help='Maximum distortion magnitude for elastic transforms')
|
| 136 |
+
aug.add_argument('--elastic-distortion-min-alpha', default=0.5, type=float,
|
| 137 |
+
help='Minimum alpha for elastic distortion')
|
| 138 |
+
aug.add_argument('--elastic-distortion-max-alpha', default=1, type=float,
|
| 139 |
+
help='Maximum alpha for elastic distortion')
|
| 140 |
+
aug.add_argument('--elastic-distortion-min-sigma', default=1, type=int,
|
| 141 |
+
help='Minimum sigma for Gaussian in elastic distortion')
|
| 142 |
+
aug.add_argument('--elastic-distortion-max-sigma', default=10, type=int,
|
| 143 |
+
help='Maximum sigma for Gaussian in elastic distortion')
|
| 144 |
+
aug.add_argument('--dila-ero-max-kernel', default=3, type=int,
|
| 145 |
+
help='Max kernel size for dilation/erosion ops')
|
| 146 |
+
aug.add_argument('--dila-ero-iter', default=1, type=int,
|
| 147 |
+
help='Iterations for dilation/erosion')
|
| 148 |
+
aug.add_argument('--jitter-contrast', default=0.4, type=float,
|
| 149 |
+
help='ColorJitter: contrast range')
|
| 150 |
+
aug.add_argument('--jitter-brightness', default=0.4, type=float,
|
| 151 |
+
help='ColorJitter: brightness range')
|
| 152 |
+
aug.add_argument('--jitter-saturation', default=0.4, type=float,
|
| 153 |
+
help='ColorJitter: saturation range')
|
| 154 |
+
aug.add_argument('--jitter-hue', default=0.2, type=float,
|
| 155 |
+
help='ColorJitter: hue range')
|
| 156 |
+
aug.add_argument('--blur-min-kernel', default=3, type=int,
|
| 157 |
+
help='Minimum kernel size for Gaussian blur')
|
| 158 |
+
aug.add_argument('--blur-max-kernel', default=5, type=int,
|
| 159 |
+
help='Maximum kernel size for Gaussian blur')
|
| 160 |
+
aug.add_argument('--blur-min-sigma', default=3, type=int,
|
| 161 |
+
help='Minimum sigma for Gaussian blur')
|
| 162 |
+
aug.add_argument('--blur-max-sigma', default=5, type=int,
|
| 163 |
+
help='Maximum sigma for Gaussian blur')
|
| 164 |
+
aug.add_argument('--sharpen-min-alpha', default=0, type=int,
|
| 165 |
+
help='Minimum alpha/mix for sharpening')
|
| 166 |
+
aug.add_argument('--sharpen-max-alpha', default=1, type=int,
|
| 167 |
+
help='Maximum alpha/mix for sharpening')
|
| 168 |
+
aug.add_argument('--sharpen-min-strength', default=0, type=int,
|
| 169 |
+
help='Minimum sharpening strength')
|
| 170 |
+
aug.add_argument('--sharpen-max-strength', default=1, type=int,
|
| 171 |
+
help='Maximum sharpening strength')
|
| 172 |
+
aug.add_argument('--zoom-min-h', default=0.8, type=float,
|
| 173 |
+
help='Minimum vertical zoom factor')
|
| 174 |
+
aug.add_argument('--zoom-max-h', default=1, type=float,
|
| 175 |
+
help='Maximum vertical zoom factor')
|
| 176 |
+
aug.add_argument('--zoom-min-w', default=0.99, type=float,
|
| 177 |
+
help='Minimum horizontal zoom factor')
|
| 178 |
+
aug.add_argument('--zoom-max-w', default=1, type=float,
|
| 179 |
+
help='Maximum horizontal zoom factor')
|
| 180 |
+
aug.add_argument('--proba', default=0.5, type=float,
|
| 181 |
+
help='Default probability for applying stochastic augmentations')
|
| 182 |
+
|
| 183 |
+
# ---------------------------------------------------------------------
|
| 184 |
+
# Decoder & Inference (for encoder-decoder mode)
|
| 185 |
+
# ---------------------------------------------------------------------
|
| 186 |
+
dec = parser.add_argument_group('Decoder & Inference')
|
| 187 |
+
dec.add_argument('--decoder-layers', default=6, type=int,
|
| 188 |
+
help='Number of Transformer decoder layers')
|
| 189 |
+
dec.add_argument('--decoder-heads', default=8, type=int,
|
| 190 |
+
help='Number of attention heads in decoder')
|
| 191 |
+
dec.add_argument('--max-seq-len', default=256, type=int,
|
| 192 |
+
help='Maximum output sequence length')
|
| 193 |
+
dec.add_argument('--label-smoothing', default=0.1, type=float,
|
| 194 |
+
help='Label-smoothing factor for cross-entropy loss')
|
| 195 |
+
dec.add_argument('--beam-size', default=5, type=int,
|
| 196 |
+
help='Beam size for beam-search decoding')
|
| 197 |
+
dec.add_argument('--generation-method', default='nucleus', type=str,
|
| 198 |
+
choices=['greedy', 'nucleus', 'beam_search'],
|
| 199 |
+
help='Token generation method for inference')
|
| 200 |
+
dec.add_argument('--generation-temperature', default=0.7, type=float,
|
| 201 |
+
help='Sampling temperature (used by nucleus/greedy sampling)')
|
| 202 |
+
dec.add_argument('--repetition-penalty', default=1.3, type=float,
|
| 203 |
+
help='Penalty to discourage token repetition during generation')
|
| 204 |
+
dec.add_argument('--top-p', default=0.9, type=float,
|
| 205 |
+
help='Top-p threshold for nucleus sampling')
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------
|
| 208 |
+
# TCM (Textual Context Module)
|
| 209 |
+
# ---------------------------------------------------------------------
|
| 210 |
+
tcm = parser.add_argument_group('TCM (Textual Context Module)')
|
| 211 |
+
tcm.add_argument('--tcm-enable', action='store_true', default=False,
|
| 212 |
+
help='Enable Textual Context Module (TCM)')
|
| 213 |
+
tcm.add_argument('--tcm-lambda', default=1.0, type=float,
|
| 214 |
+
help='TCM loss weight (λ2 in the paper)')
|
| 215 |
+
tcm.add_argument('--ctc-lambda', default=0.1, type=float,
|
| 216 |
+
help='CTC loss weight (λ1 in the paper)')
|
| 217 |
+
tcm.add_argument('--tcm-sub-len', default=5, type=int,
|
| 218 |
+
help='TCM context sub-string length')
|
| 219 |
+
tcm.add_argument('--tcm-warmup-iters', default=0, type=int,
|
| 220 |
+
help='Warm-up iterations before activating TCM (0 = start immediately)')
|
| 221 |
+
|
| 222 |
+
# ---------------------------------------------------------------------
|
| 223 |
+
# Checkpointing & Pretrained Weights
|
| 224 |
+
# ---------------------------------------------------------------------
|
| 225 |
+
ckpt = parser.add_argument_group('Checkpointing & Pretrained Weights')
|
| 226 |
+
ckpt.add_argument('--resume', type=str, default=None,
|
| 227 |
+
help='Resume training from a checkpoint (alias)')
|
| 228 |
+
ckpt.add_argument('--load-model', type=str, default=None,
|
| 229 |
+
help='Load a full pretrained model for fine-tuning')
|
| 230 |
+
ckpt.add_argument('--load-encoder-only', action='store_true', default=False,
|
| 231 |
+
help='Load only encoder weights (transfer learning)')
|
| 232 |
+
ckpt.add_argument('--strict-loading', action='store_true', default=True,
|
| 233 |
+
help='Use strict key matching when loading weights')
|
| 234 |
+
|
| 235 |
+
return parser.parse_args()
|
utils/sam.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) optimizer wrapper.

    Wraps a base optimizer and performs the two-phase SAM update:
    ``first_step`` perturbs the weights toward the local loss maximum within
    a rho-ball, and ``second_step`` restores the weights and applies the base
    optimizer's update using gradients computed at the perturbed point.

    Args:
        params: iterable of parameters or param-group dicts.
        base_optimizer: optimizer *class* (not an instance); it is
            constructed here on top of SAM's param groups with ``**kwargs``.
        rho: radius of the neighborhood used for the perturbation.
        adaptive: if True, use the element-wise |w|-scaled (adaptive) variant.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Build the wrapped optimizer on SAM's param groups, then share the
        # group list so lr schedulers mutating either object stay in sync.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Move each parameter to "w + e(w)" while remembering the original
        # weights in self.state so second_step can restore them.
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Restore the original weights, then take the base optimizer's step
        # using the gradients that were computed at the perturbed point.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # Grad must be re-enabled inside the closure: it has to redo a full
        # forward-backward pass at the perturbed weights.
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # Global L2 norm over all gradients (element-wise |w|-scaled for the
        # adaptive variant), gathered onto a single device first.
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        # Re-link the shared param-group list after restoring state.
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
|
utils/utils.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torch.distributions.uniform import Uniform
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
import math
|
| 9 |
+
import logging
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
import random
|
| 13 |
+
import numpy as np
|
| 14 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def randint(low, high):
    """Return a uniform random integer drawn from [low, high)."""
    sample = torch.randint(low, high, (1,))
    return int(sample)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def rand_uniform(low, high):
    """Return a float sampled uniformly from [low, high)."""
    sampler = Uniform(low, high)
    return float(sampler.sample())
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_logger(out_dir):
    """Create (or reuse) the 'Exp' logger writing to ``<out_dir>/run.log`` and stdout.

    Args:
        out_dir: existing directory in which ``run.log`` is created.

    Returns:
        The configured ``logging.Logger`` instance.

    ``logging.getLogger('Exp')`` returns the same process-global object on
    every call, so handlers are only attached when none exist yet; the
    original unconditional attachment duplicated every log line after a
    second call (e.g. resume paths).
    """
    logger = logging.getLogger('Exp')
    logger.setLevel(logging.INFO)

    if not logger.handlers:  # guard against stacking duplicate handlers
        formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

        file_path = os.path.join(out_dir, "run.log")
        file_hdlr = logging.FileHandler(file_path)
        file_hdlr.setFormatter(formatter)

        strm_hdlr = logging.StreamHandler(sys.stdout)
        strm_hdlr.setFormatter(formatter)

        logger.addHandler(file_hdlr)
        logger.addHandler(strm_hdlr)
    return logger
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def update_lr_cos(nb_iter, warm_up_iter, total_iter, max_lr, optimizer, min_lr=1e-7):
    """Linear warm-up followed by cosine decay; writes lr into all param groups.

    Args:
        nb_iter: current iteration (0-based).
        warm_up_iter: number of warm-up iterations.
        total_iter: total number of training iterations.
        max_lr: peak learning rate reached at the end of warm-up.
        optimizer: optimizer whose ``param_groups`` receive the new lr.
        min_lr: floor learning rate at the end of the schedule.

    Returns:
        (optimizer, current_lr) tuple.
    """
    if nb_iter < warm_up_iter:
        # Linear ramp from ~0 up to max_lr over the warm-up phase.
        current_lr = max_lr * (nb_iter + 1) / (warm_up_iter + 1)
    else:
        # Cosine decay over the post-warm-up span. Using (nb_iter - warm_up_iter)
        # keeps the phase within [0, pi] so the lr decreases monotonically to
        # min_lr at total_iter; the previous formula used nb_iter directly,
        # overshooting pi (lr climbing back up near the end) whenever
        # warm_up_iter > 0.
        progress = (nb_iter - warm_up_iter) / (total_iter - warm_up_iter)
        current_lr = min_lr + (max_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr

    return optimizer, current_lr
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CTCLabelConverter(object):
    """Maps between text strings and CTC label sequences (index 0 = blank)."""

    def __init__(self, character):
        chars = list(character)
        self.dict = {}
        for idx, ch in enumerate(chars):
            # Index 0 is reserved for the CTC blank token, so chars start at 1.
            self.dict[ch] = idx + 1
        if len(self.dict) == 87:  # '[' and ']' are not in the test set but in the training and validation sets.
            self.dict['['], self.dict[']'] = 88, 89
        self.character = ['[blank]'] + chars

    def encode(self, text):
        """Flatten a batch of strings into one index tensor plus per-string lengths."""
        lengths = [len(item) for item in text]
        joined = ''.join(text)
        indices = [self.dict[ch] for ch in joined]
        # Same device selection as the module-level `device` constant.
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return (torch.IntTensor(indices).to(dev), torch.IntTensor(lengths).to(dev))

    def decode(self, text_index, length):
        """Greedy CTC decoding: drop blanks and collapse repeated labels."""
        texts = []
        offset = 0

        for seq_len in length:
            seq = text_index[offset:offset + seq_len]
            chars = []
            for pos in range(seq_len):
                cur = seq[pos]
                repeated = pos > 0 and seq[pos - 1] == cur
                if cur != 0 and not repeated and cur < len(self.character):
                    chars.append(self.character[cur])

            texts.append(''.join(chars))
            offset += seq_len
        return texts
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class Averager(object):
    """Running mean over all elements of tensors fed to ``add``."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the sum and element count of tensor ``v``."""
        num = v.data.numel()
        total = v.data.sum()
        self.n_count += num
        self.sum += total

    def reset(self):
        """Clear the accumulated sum and count."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean of everything added so far (0 when empty)."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class Metric(object):
    """Running metric averaged across distributed ranks.

    ``update`` requires an initialized ``torch.distributed`` process group;
    each update all-reduces the value and averages it over the world size.
    """

    def __init__(self, name=''):
        self.name = name
        self.sum = torch.tensor(0.).double()
        self.n = torch.tensor(0.)

    def update(self, val):
        """All-reduce ``val`` across ranks and fold the mean into the running sum."""
        reduced = val.clone()
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        reduced /= dist.get_world_size()
        self.sum += reduced.detach().cpu().double()
        self.n += 1

    @property
    def avg(self):
        """Mean of all values seen so far (NaN before the first update)."""
        return self.sum / self.n.double()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class ModelEma:
    """Exponential moving average (EMA) of model weights.

    Keeps a frozen deep-copied shadow model whose state dict tracks the
    training model via ``ema = decay * ema + (1 - decay) * model``.

    Args:
        model: model to shadow (deep-copied at construction time).
        decay: EMA decay factor.
        device: if non-empty, device to keep the EMA copy on.
        resume: optional checkpoint path containing 'state_dict_ema'.
    """
    def __init__(self, model, decay=0.9999, device='', resume=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # Whether the EMA copy is wrapped (e.g. DataParallel adds 'module.').
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        # The shadow copy is never trained directly.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path, mapl=None):
        # Load EMA weights from a checkpoint's 'state_dict_ema' entry,
        # aligning the 'module.' key prefix with the EMA model's wrapping.
        checkpoint = torch.load(checkpoint_path,map_location=mapl)
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            print("=> Loaded state_dict_ema")
        else:
            print("=> Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model, num_updates=-1):
        """Blend the live model's weights into the EMA copy.

        When ``num_updates >= 0`` the effective decay is warmed up as
        ``min(decay, (1 + n) / (10 + n))`` so early updates track faster.
        """
        # Add 'module.' when the live model is wrapped but the EMA copy isn't.
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        if num_updates >= 0:
            _cdecay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        else:
            _cdecay = self.decay

        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                # In-place EMA update; iterating state_dict() covers buffers too.
                ema_v.copy_(ema_v * _cdecay + (1. - _cdecay) * model_v)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def format_string_for_wer(str):
    """Normalize a string for WER computation.

    Surrounds each punctuation mark with spaces so it becomes its own token,
    then collapses runs of spaces/newlines into single spaces and strips the
    ends.

    Args:
        str: input string (name kept for backward compatibility; it shadows
            the builtin within this function only).

    Returns:
        The normalized string.
    """
    # Raw strings are required here: in the original non-raw literal, '\\'
    # collapsed to a single backslash that the regex consumed as an escape
    # for '(', so backslash itself was never treated as punctuation (and
    # sequences like '\[' emitted invalid-escape DeprecationWarnings).
    str = re.sub(r'([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', str)
    str = re.sub(r'([ \n])+', " ", str).strip()
    return str
|
| 180 |
+
|
| 181 |
+
def load_checkpoint(model, model_ema, optimizer, checkpoint_path, logger):
    """Restore training state from a checkpoint to resume a run.

    Loads the main model weights (and the EMA copy when available), records
    the optimizer state for deferred loading, and restores best metrics,
    the iteration counter, loss accumulators, and RNG states when present.

    Args:
        model: model to load weights into.
        model_ema: ModelEma wrapper (or None) whose ``.ema`` receives
            'state_dict_ema' weights.
        optimizer: optimizer (or None); its state is only *extracted* here
            and must be applied by the caller after initialization.
        checkpoint_path: path to a .pth checkpoint; missing/None yields the
            default return values.
        logger: logger for progress messages.

    Returns:
        (best_cer, best_wer, start_iter, optimizer_state, train_loss,
        train_loss_count).
    """
    best_cer, best_wer, start_iter = 1e+6, 1e+6, 1
    train_loss, train_loss_count = 0.0, 0
    optimizer_state = None
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        logger.info(f"Resuming from checkpoint: {checkpoint_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path, map_location='cpu', weights_only=False)

        # Load model state dict (handle module prefix like in test.py)
        model_dict = OrderedDict()
        # NOTE(review): the '.' in 'module.' is a regex wildcard, so this
        # pattern also matches 'moduleX'; harmless for typical key names.
        pattern = re.compile('module.')

        # For main model, load from the 'model' state dict
        # (the training checkpoint contains both 'model' and 'state_dict_ema')
        if 'model' in checkpoint:
            source_dict = checkpoint['model']
            logger.info("Loading main model from 'model' state dict")
        elif 'state_dict_ema' in checkpoint:
            source_dict = checkpoint['state_dict_ema']
            logger.info(
                "Loading main model from 'state_dict_ema' (fallback)")
        else:
            raise KeyError(
                "Neither 'model' nor 'state_dict_ema' found in checkpoint")

        # Strip any DataParallel 'module.' prefix from the keys.
        for k, v in source_dict.items():
            if re.search("module", k):
                model_dict[re.sub(pattern, '', k)] = v
            else:
                model_dict[k] = v

        model.load_state_dict(model_dict, strict=True)
        logger.info("Successfully loaded main model state dict")

        # Load EMA state dict if available
        if 'state_dict_ema' in checkpoint and model_ema is not None:
            ema_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                if re.search("module", k):
                    ema_dict[re.sub(pattern, '', k)] = v
                else:
                    ema_dict[k] = v
            model_ema.ema.load_state_dict(ema_dict, strict=True)
            logger.info("Successfully loaded EMA model state dict")

        # Load optimizer state - handle SAM optimizer structure
        if 'optimizer' in checkpoint and optimizer is not None:
            try:
                optimizer_state = checkpoint['optimizer']
                logger.info(
                    "Optimizer state will be loaded after optimizer initialization")
            except Exception as e:
                logger.warning(f"Failed to prepare optimizer state: {e}")
                optimizer_state = None

        # Load metrics from checkpoint if available
        if 'best_cer' in checkpoint:
            best_cer = checkpoint['best_cer']
        if 'best_wer' in checkpoint:
            best_wer = checkpoint['best_wer']
        if 'nb_iter' in checkpoint:
            start_iter = checkpoint['nb_iter'] + 1

        # Parse CER, WER, iter from filename as fallback
        m = re.search(
            r'checkpoint_(?P<cer>[\d\.]+)_(?P<wer>[\d\.]+)_(?P<iter>\d+)\.pth', checkpoint_path)
        if m and 'best_cer' not in checkpoint:
            best_cer = float(m.group('cer'))
            best_wer = float(m.group('wer'))
            start_iter = int(m.group('iter')) + 1

        # Restore loss accumulators and RNG streams for exact resumption.
        if 'train_loss' in checkpoint:
            train_loss = checkpoint['train_loss']
        if 'train_loss_count' in checkpoint:
            train_loss_count = checkpoint['train_loss_count']
        if 'random_state' in checkpoint:
            random.setstate(checkpoint['random_state'])
            logger.info("Restored random state")
        if 'numpy_state' in checkpoint:
            np.random.set_state(checkpoint['numpy_state'])
            logger.info("Restored numpy random state")
        if 'torch_state' in checkpoint:
            torch.set_rng_state(checkpoint['torch_state'])
            logger.info("Restored torch random state")
        if 'torch_cuda_state' in checkpoint and torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])
            logger.info("Restored torch cuda random state")

    # Validate that the model was loaded correctly by checking a few parameters
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded with {total_params} total parameters")

    logger.info(
        f"Resumed best_cer={best_cer}, best_wer={best_wer}, start_iter={start_iter}")
    return best_cer, best_wer, start_iter, optimizer_state, train_loss, train_loss_count
|
valid.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils.data
|
| 3 |
+
import torch.backends.cudnn as cudnn
|
| 4 |
+
|
| 5 |
+
from utils import utils
|
| 6 |
+
import editdistance
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def validation(model, criterion, evaluation_loader, converter):
    """Validation or evaluation pass over ``evaluation_loader``.

    Runs CTC inference on every batch, accumulates the CTC loss and the
    edit-distance statistics needed for CER (character level) and WER
    (word level, after punctuation normalization).

    Args:
        model: network returning per-timestep class logits for an image batch.
        criterion: CTC loss (called with log-probs, targets, input lengths,
            target lengths).
        evaluation_loader: yields (image_tensors, labels) batches.
        converter: CTCLabelConverter used to encode labels / decode predictions.

    Returns:
        (val_loss, CER, WER, all_preds_str, all_labels).

    NOTE(review): the forward pass runs without ``torch.no_grad()`` and the
    caller is presumably responsible for ``model.eval()`` — confirm at call
    sites; wrapping this loop in no_grad would reduce memory.
    """

    norm_ED = 0
    norm_ED_wer = 0

    tot_ED = 0
    tot_ED_wer = 0

    valid_loss = 0.0
    length_of_gt = 0
    length_of_gt_wer = 0
    count = 0
    all_preds_str = []
    all_labels = []

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        batch_size = image_tensors.size(0)
        image = image_tensors.cuda()

        text_for_loss, length_for_loss = converter.encode(labels)

        preds = model(image)
        preds = preds.float()
        # Every sample in the batch has the same number of output timesteps.
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        # CTC loss expects (T, N, C) log-probabilities.
        preds = preds.permute(1, 0, 2).log_softmax(2)

        # cudnn is toggled off around the CTC loss — presumably to avoid
        # known cudnn CTC issues; confirm whether still needed.
        torch.backends.cudnn.enabled = False
        cost = criterion(preds, text_for_loss, preds_size, length_for_loss).mean()
        torch.backends.cudnn.enabled = True

        # Greedy (best-path) decoding: argmax per timestep, then CTC collapse.
        _, preds_index = preds.max(2)
        preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
        preds_str = converter.decode(preds_index.data, preds_size.data)
        valid_loss += cost.item()
        count += 1

        all_preds_str.extend(preds_str)
        all_labels.extend(labels)

        # Character-level edit distance (CER numerator/denominator).
        for pred_cer, gt_cer in zip(preds_str, labels):
            tmp_ED = editdistance.eval(pred_cer, gt_cer)
            if len(gt_cer) == 0:
                norm_ED += 1
            else:
                norm_ED += tmp_ED / float(len(gt_cer))
            tot_ED += tmp_ED
            length_of_gt += len(gt_cer)

        # Word-level edit distance on punctuation-normalized token lists.
        for pred_wer, gt_wer in zip(preds_str, labels):
            pred_wer = utils.format_string_for_wer(pred_wer)
            gt_wer = utils.format_string_for_wer(gt_wer)
            pred_wer = pred_wer.split(" ")
            gt_wer = gt_wer.split(" ")
            tmp_ED_wer = editdistance.eval(pred_wer, gt_wer)

            if len(gt_wer) == 0:
                norm_ED_wer += 1
            else:
                norm_ED_wer += tmp_ED_wer / float(len(gt_wer))

            tot_ED_wer += tmp_ED_wer
            length_of_gt_wer += len(gt_wer)

    # Corpus-level rates: total edits over total ground-truth length.
    val_loss = valid_loss / count
    CER = tot_ED / float(length_of_gt)
    WER = tot_ED_wer / float(length_of_gt_wer)

    return val_loss, CER, WER, all_preds_str, all_labels
|