k0ry committed
Commit 646f45c · verified · Parent: cc8a699

Upload 20 files
.gitattributes CHANGED
@@ -1,35 +1,36 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+image/architecture.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,220 @@
---
license: apache-2.0
language:
- en
- vi
pipeline_tag: image-to-text
model-index:
- name: HTR-ConvText
  results:
  - task:
      type: image-to-text
      name: Handwritten Text Recognition
    dataset:
      name: IAM
      type: iam
      split: test
    metrics:
    - type: cer
      value: 4.0
      name: Test CER
    - type: wer
      value: 12.9
      name: Test WER
  - task:
      type: image-to-text
      name: Handwritten Text Recognition
    dataset:
      name: LAM
      type: lam
      split: test
    metrics:
    - type: cer
      value: 2.7
      name: Test CER
    - type: wer
      value: 7.0
      name: Test WER
  - task:
      type: image-to-text
      name: Handwritten Text Recognition
    dataset:
      name: READ2016
      type: read2016
      split: test
    metrics:
    - type: cer
      value: 3.6
      name: Test CER
    - type: wer
      value: 15.7
      name: Test WER
  - task:
      type: image-to-text
      name: Handwritten Text Recognition
    dataset:
      name: HANDS-VNOnDB
      type: hands-vnondb
      split: test
    metrics:
    - type: cer
      value: 3.45
      name: Test CER
    - type: wer
      value: 8.9
      name: Test WER
---
# HTR-ConvText: Leveraging Convolution and Textual Information for Handwritten Text Recognition

<div align="center"> <img src="image/architecture.png" alt="HTR-ConvText Architecture" width="800"/> </div>

<p align="center">
  <a href="https://huggingface.co/DAIR-Group/HTR-ConvText">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue">
  </a>
  <a href="https://github.com/DAIR-Group/HTR-ConvText">
    <img alt="GitHub" src="https://img.shields.io/badge/GitHub-Repo-181717.svg?logo=github&logoColor=white">
  </a>
  <a href="https://github.com/DAIR-Group/HTR-ConvText/blob/main/LICENSE">
    <img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green">
  </a>
  <a href="https://arxiv.org/abs/2512.05021">
    <img alt="arXiv" src="https://img.shields.io/badge/arXiv-2512.05021-b31b1b.svg">
  </a>
</p>

## Highlights

HTR-ConvText is a novel hybrid architecture for Handwritten Text Recognition (HTR) that effectively balances local feature extraction with global contextual modeling. Designed to overcome the limitations of standard CTC-based decoding and data-hungry Transformers, HTR-ConvText delivers state-of-the-art performance with the following key features:

- **Hybrid CNN-ViT Architecture**: Integrates a ResNet backbone with MobileViT blocks (MVP) and Conditional Positional Encoding, enabling the model to capture fine-grained stroke details while maintaining global spatial awareness.
- **Hierarchical ConvText Encoder**: A U-Net-like encoder that interleaves Multi-Head Self-Attention with depthwise convolutions, efficiently modeling both long-range dependencies and local structural patterns (see the block sketch below).
- **Textual Context Module (TCM)**: A training-only auxiliary module that injects bidirectional linguistic priors into the visual encoder, mitigating the conditional-independence weakness of CTC decoding without adding any latency at inference.
- **State-of-the-Art Performance**: Outperforms existing methods on major benchmarks including IAM (English), READ2016 (German), LAM (Italian), and HANDS-VNOnDB (Vietnamese), excelling in low-resource scenarios and on scripts with complex diacritics.

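Concretely, each encoder block chains four post-norm residual sub-layers — self-attention, a half-scaled feed-forward, a depthwise-convolution module, and a second half-scaled feed-forward (Macaron-style) — as implemented in `model/htr_convtext.py`:

```python
# ConvTextBlock.forward from model/htr_convtext.py (excerpt):
# attention -> FFN -> depthwise conv -> FFN, each wrapped with
# LayerScale, DropPath and a post-LayerNorm residual connection.
def forward(self, x):
    x = self.postln_attn(x + self.ls_attn(self.dp_attn(self.attn(x))))
    x = self.postln_ffn1(x + self.ls_ffn1(0.5 * self.dp_ffn1(self.ffn1(x))))
    x = self.postln_conv(x + self.ls_conv(self.dp_conv(self.conv(x))))
    x = self.postln_ffn2(x + self.ls_ffn2(0.5 * self.dp_ffn2(self.ffn2(x))))
    return x
```
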
## Model Overview

HTR-ConvText configurations and specifications:

| Feature             | Specification                                        |
| ------------------- | ---------------------------------------------------- |
| Architecture Type   | Hybrid CNN + Vision Transformer (encoder-only)       |
| Parameters          | ~65.9M                                               |
| Backbone            | ResNet-18 + MobileViT w/ Positional Encoding (MVP)   |
| Encoder Layers      | 8 ConvText blocks (hierarchical)                     |
| Attention Heads     | 8                                                    |
| Embedding Dimension | 512                                                  |
| Image Input Size    | 512×64 (width × height)                              |
| Inference Strategy  | Standard CTC decoding (TCM is removed at inference)  |

For more details, including ablation studies and theoretical proofs, please refer to our [Technical Report](https://arxiv.org/pdf/2512.05021).

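For programmatic use, the architecture above can be built with `create_model` from `model/htr_convtext.py` (shipped in this repo). A minimal sketch — the single-channel 64×512 input layout follows the data pipeline in `data/dataset.py`, and the sequence length `T` of the output depends on the backbone's downsampling:

```python
# Minimal sketch: build HTR-ConvText as configured in model/htr_convtext.py.
import torch
from model.htr_convtext import create_model

model = create_model(nb_cls=80, img_size=[512, 64])  # 8 blocks, dim 512, 8 heads
model.eval()

x = torch.randn(1, 1, 64, 512)   # (batch, channel, height, width) grayscale line
with torch.no_grad():
    logits = model(x)            # frame-wise class logits for CTC decoding
print(logits.shape)              # (1, T, 80)
```
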
## Performance

We evaluated HTR-ConvText on four diverse datasets. The model achieves new state-of-the-art results, with the lowest Character Error Rate (CER) and Word Error Rate (WER), without requiring massive synthetic pre-training. The table reports test CER (%) for all methods; lower is better (WER figures appear in the model-index metadata above).

| Dataset  | Language   | Ours | HTR-VT | OrigamiNet | TrOCR | CRNN  |
|----------|------------|------|--------|------------|-------|-------|
| IAM      | English    | 4.0  | 4.7    | 4.8        | 7.3   | 7.8   |
| LAM      | Italian    | 2.7  | 2.8    | 3.0        | 3.6   | 3.8   |
| READ2016 | German     | 3.6  | 3.9    | -          | -     | 4.7   |
| VNOnDB   | Vietnamese | 3.45 | 4.26   | 7.6        | -     | 10.53 |

## Quickstart

### Installation

1. **Clone the repository**
   ```cmd
   git clone https://github.com/0xk0ry/HTR-ConvText.git
   cd HTR-ConvText
   ```
2. **Create and activate a Python 3.9+ Conda environment**
   ```cmd
   conda create -n htr-convtext python=3.9 -y
   conda activate htr-convtext
   ```
3. **Install PyTorch** using the wheel that matches your CUDA driver (swap the index URL for CPU-only builds):
   ```cmd
   pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
   ```
4. **Install the remaining project requirements** (everything except PyTorch, which you already installed in step 3):
   ```cmd
   pip install -r requirements.txt
   ```

The code was tested on Python 3.9 and PyTorch 2.9.1.

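To confirm the installed wheel matches your hardware, a quick check (plain PyTorch calls, nothing repo-specific):

```python
# Verify the PyTorch build and CUDA availability after installation.
import torch

print(torch.__version__)           # e.g. 2.9.1+cu126 for the CUDA wheel above
print(torch.cuda.is_available())   # True if the wheel and the driver line up
```
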
### Data Preparation

We provide split files (`train.ln`, `val.ln`, `test.ln`) for IAM, READ2016, LAM, and VNOnDB under `data/`. Each line image has a sibling `.txt` transcription. Organize your data as follows:

```
./data/iam/
├── train.ln
├── val.ln
├── test.ln
└── lines
    ├── a01-000u-00.png
    ├── a01-000u-00.txt
    └── ...
```

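These splits are consumed by `myLoadDS` in `data/dataset.py` (included in this upload); a minimal sketch, with placeholder paths:

```python
# Minimal sketch: load one IAM sample with the dataset class from data/dataset.py.
from data.dataset import myLoadDS

train_set = myLoadDS('data/iam/train.ln', '/path/to/iam/lines/',
                     img_size=[512, 64], dataset='iam')  # 'iam' selects the built-in alphabet
img, text = train_set[0]   # img: (1, 64, 512) float32 array; text: transcription string
print(len(train_set), text)
```
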
### Training

We provide training scripts in the `./run/` directory. To train on the IAM dataset with the Textual Context Module (TCM) enabled (`--nb-cls 80` covers the 79-character IAM alphabet defined in `data/dataset.py` plus one CTC blank):

```bash
# Using the provided script
bash run/iam.sh

# OR running directly via Python
python train.py \
    --use-wandb \
    --dataset iam \
    --tcm-enable \
    --exp-name "htr-convtext-iam" \
    --img-size 512 64 \
    --train-bs 32 \
    --val-bs 8 \
    --data-path /path/to/iam/lines/ \
    --train-data-list data/iam/train.ln \
    --val-data-list data/iam/val.ln \
    --test-data-list data/iam/test.ln \
    --nb-cls 80
```

### Inference / Evaluation

To evaluate a pre-trained checkpoint on the test set:

```bash
python test.py \
    --resume ./checkpoints/best_CER.pth \
    --dataset iam \
    --img-size 512 64 \
    --data-path /path/to/iam/lines/ \
    --test-data-list data/iam/test.ln \
    --nb-cls 80
```

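`test.py` reports CER/WER using standard CTC decoding over the model's frame logits. For reference, a hedged sketch of CTC greedy decoding — the blank index and the index-to-character mapping (`ralph` in `data/dataset.py`) are assumptions here; check the repo's label converter for the exact convention:

```python
# Hedged sketch of CTC greedy decoding (collapse repeats, drop blanks).
# Assumes blank = class 0 and characters occupy the remaining indices;
# the repo's converter defines the actual mapping.
import torch

def ctc_greedy_decode(logits: torch.Tensor, idx2char: dict, blank: int = 0) -> list:
    best = logits.argmax(dim=-1)              # (B, T): best class per frame
    texts = []
    for seq in best.tolist():
        chars, prev = [], blank
        for cls in seq:
            if cls != blank and cls != prev:  # collapse repeats, then drop blanks
                chars.append(idx2char.get(cls, ''))
            prev = cls
        texts.append(''.join(chars))
    return texts
```
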
## Citation

If you find our work helpful, please cite our paper:

```bibtex
@misc{truc2025htrconvtex,
      title={HTR-ConvText: Leveraging Convolution and Textual Information for Handwritten Text Recognition},
      author={Pham Thach Thanh Truc and Dang Hoai Nam and Huynh Tong Dang Khoa and Vo Nguyen Le Duy},
      year={2025},
      eprint={2512.05021},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2512.05021},
}
```

## Acknowledgement

This project is inspired by and adapted from [HTR-VT](https://github.com/Intellindust-AI-Lab/HTR-VT). We gratefully acknowledge the authors for their open-source contributions.
data/dataset.py ADDED
@@ -0,0 +1,202 @@
from torchvision.transforms import ColorJitter
from data import transform as transform
from utils import utils
from torch.utils.data import Dataset
from PIL import Image
import itertools
import os
import skimage
import torch
import numpy as np


def SameTrCollate(batch, args):

    images, labels = zip(*batch)
    images = [Image.fromarray(np.uint8(images[i][0] * 255))
              for i in range(len(images))]

    # Apply each data augmentation independently with 50% probability
    if np.random.rand() < 0.5:
        images = [transform.RandomTransform(
            args.proj)(image) for image in images]

    if np.random.rand() < 0.5:
        kernel_h = utils.randint(1, args.dila_ero_max_kernel + 1)
        kernel_w = utils.randint(1, args.dila_ero_max_kernel + 1)
        if utils.randint(0, 2) == 0:
            images = [transform.Erosion((kernel_w, kernel_h), args.dila_ero_iter)(
                image) for image in images]
        else:
            images = [transform.Dilation((kernel_w, kernel_h), args.dila_ero_iter)(
                image) for image in images]

    if np.random.rand() < 0.5:
        images = [ColorJitter(args.jitter_brightness, args.jitter_contrast,
                              args.jitter_saturation, args.jitter_hue)(image) for image in images]

    # Convert images to tensors
    image_tensors = [torch.from_numpy(
        np.array(image, copy=True)) for image in images]
    image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
    image_tensors = image_tensors.unsqueeze(1).float()
    image_tensors = image_tensors / 255.
    return image_tensors, labels


class myLoadDS(Dataset):
    def __init__(self, flist, dpath, img_size=[512, 32], ralph=None, fmin=True, mln=None, dataset=None):
        self.fns = get_files(flist, dpath)
        self.tlbls = get_labels(self.fns)
        self.img_size = img_size
        if ralph is not None:
            self.ralph = ralph
        elif dataset is not None:
            # Fixed per-dataset alphabets (index -> character)
            if dataset == 'iam':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"#&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
                    )
                }
            elif dataset == 'lam':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"#%&\'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXZabcdefghijlmnopqrstuvwxyz|°·ÈÉàèéìòù–'
                    )
                }
            elif dataset == 'read2016':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' ()+,-./0123456789:<>ABCDEFGHIJKLMNOPQRSTUVWYZ[]abcdefghijklmnopqrstuvwxyz¾Ößäöüÿāēōūȳ̄̈—'
                    )
                }
            elif dataset == 'vnondb':
                self.ralph = {
                    idx: char for idx, char in enumerate(
                        ' !"%&()*,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvxyzÀÁÂÔÚÝàáâãèéêìíòóôõùúýĂăĐđĩũƠơƯưạẢảẤấẦầẩẫậắằẳẵặẹẻẽếỀềỂểễỆệỉịọỏỐốỒồổỗộớờỞởỡợụỦủứừửữựỳỷỹ'
                    )
                }
            else:
                # Unknown dataset name: fall back to the alphabet of the labels
                alph = get_alphabet(self.tlbls)
                self.ralph = dict(zip(alph.values(), alph.keys()))
                self.alph = alph
        else:
            alph = get_alphabet(self.tlbls)
            self.ralph = dict(zip(alph.values(), alph.keys()))
            self.alph = alph
        if mln is not None:
            # Keep only labels shorter (fmin=True) or longer than mln characters
            filt = [len(x) <= mln if fmin else len(x) >= mln for x in self.tlbls]
            self.tlbls = np.asarray(self.tlbls)[filt].tolist()
            self.fns = np.asarray(self.fns)[filt].tolist()

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, index):
        timgs = get_images(self.fns[index], self.img_size[0], self.img_size[1])
        timgs = timgs.transpose((2, 0, 1))

        return (timgs, self.tlbls[index])


def _read_text(path):
    """Read a text file with robust encoding handling.
    Try UTF-8 first, then fall back to common Windows encodings.
    """
    encodings = ['utf-8', 'utf-8-sig', 'cp1258', 'cp1252', 'latin-1']
    for enc in encodings:
        try:
            with open(path, 'r', encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
        except FileNotFoundError:
            raise
    # As a last resort, ignore errors to avoid crashing the training loop
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()


def _read_lines(path):
    txt = _read_text(path)
    return txt.splitlines()


def get_files(nfile, dpath):
    fnames = _read_lines(nfile)
    fnames = [dpath + x.strip() for x in fnames]
    return fnames


def npThum(img, max_w, max_h):
    # Resize so the height equals max_h, capping the width at max_w
    x, y = np.shape(img)[:2]

    y = min(int(y * max_h / x), max_w)
    x = max_h

    img = np.array(Image.fromarray(img).resize((y, x)))
    return img


def get_images(fname, max_w=500, max_h=500, nch=1):  # args.max_w args.max_h args.nch
    try:
        image_data = np.array(Image.open(fname).convert('L'))
        image_data = npThum(image_data, max_w, max_h)
        image_data = skimage.img_as_float32(image_data)

        h, w = np.shape(image_data)[:2]
        if image_data.ndim < 3:
            image_data = np.expand_dims(image_data, axis=-1)

        if nch == 3 and image_data.shape[2] != 3:
            image_data = np.tile(image_data, 3)

        # Pad the width to max_w with white (1.0) pixels
        image_data = np.pad(image_data, ((0, 0), (0, max_w - np.shape(image_data)[1]), (0, 0)),
                            mode='constant', constant_values=(1.0))

    except IOError as e:
        print('Could not read:', fname, ':', e)
        raise  # re-raise: image_data would otherwise be unbound below

    return image_data


def get_labels(fnames):
    labels = []
    for image_file in fnames:
        fn = os.path.splitext(image_file)[0] + '.txt'
        lbl = _read_text(fn)
        lbl = ' '.join(lbl.split())  # remove linebreaks if present
        labels.append(lbl)

    return labels


def get_alphabet(labels):
    coll = ''.join(labels)
    unq = sorted(list(set(coll)))
    unq = [''.join(i) for i in itertools.product(unq, repeat=1)]
    alph = dict(zip(unq, range(len(unq))))

    return alph


def cycle_dpp(iterable):
    # Cycle a DistributedSampler-backed loader, bumping the epoch each pass
    epoch = 0
    iterable.sampler.set_epoch(epoch)
    while True:
        for x in iterable:
            yield x
        epoch += 1
        iterable.sampler.set_epoch(epoch)


def cycle_data(iterable):
    while True:
        for x in iterable:
            yield x
data/transform.py ADDED
@@ -0,0 +1,334 @@
import itertools
import math

import cv2
import numpy as np
from skimage import transform as stf
from numpy import random, floor
from PIL import Image, ImageOps
from cv2 import erode, dilate, normalize
from torchvision.transforms import RandomCrop


class Dilation:
    """
    OCR: stroke width increasing
    """

    def __init__(self, kernel, iterations):
        self.kernel = np.ones(kernel, np.uint8)
        self.iterations = iterations

    def __call__(self, x):
        return Image.fromarray(dilate(np.array(x), self.kernel, iterations=self.iterations))


class Erosion:
    """
    OCR: stroke width decreasing
    """

    def __init__(self, kernel, iterations):
        self.kernel = np.ones(kernel, np.uint8)
        self.iterations = iterations

    def __call__(self, x):
        return Image.fromarray(erode(np.array(x), self.kernel, iterations=self.iterations))


class ElasticDistortion:
    """
    Elastic Distortion adapted from https://github.com/IntuitionMachines/OrigamiNet
    Used in "OrigamiNet: Weakly-Supervised, Segmentation-Free, One-Step, Full Page Text Recognition by learning to unfold",
    Yousef, Mohamed and Bishop, Tom E., The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020
    """

    def __init__(self, grid, magnitude, min_sep):
        self.grid_width, self.grid_height = grid
        self.xmagnitude, self.ymagnitude = magnitude
        self.min_h_sep, self.min_v_sep = min_sep

    def __call__(self, x):
        w, h = x.size

        horizontal_tiles = self.grid_width
        vertical_tiles = self.grid_height

        width_of_square = int(floor(w / float(horizontal_tiles)))
        height_of_square = int(floor(h / float(vertical_tiles)))

        width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
        height_of_last_square = h - (height_of_square * (vertical_tiles - 1))

        dimensions = []
        shift = [[(0, 0) for x in range(horizontal_tiles)] for y in range(vertical_tiles)]

        for vertical_tile in range(vertical_tiles):
            for horizontal_tile in range(horizontal_tiles):
                if vertical_tile == (vertical_tiles - 1) and horizontal_tile == (horizontal_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_last_square + (horizontal_tile * width_of_square),
                                       height_of_last_square + (height_of_square * vertical_tile)])
                elif vertical_tile == (vertical_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_square + (horizontal_tile * width_of_square),
                                       height_of_last_square + (height_of_square * vertical_tile)])
                elif horizontal_tile == (horizontal_tiles - 1):
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_last_square + (horizontal_tile * width_of_square),
                                       height_of_square + (height_of_square * vertical_tile)])
                else:
                    dimensions.append([horizontal_tile * width_of_square,
                                       vertical_tile * height_of_square,
                                       width_of_square + (horizontal_tile * width_of_square),
                                       height_of_square + (height_of_square * vertical_tile)])

                sm_h = min(self.xmagnitude,
                           width_of_square - (self.min_h_sep + shift[vertical_tile][horizontal_tile - 1][0])) \
                    if horizontal_tile > 0 else self.xmagnitude
                sm_v = min(self.ymagnitude,
                           height_of_square - (self.min_v_sep + shift[vertical_tile - 1][horizontal_tile][1])) \
                    if vertical_tile > 0 else self.ymagnitude

                dx = random.randint(-sm_h, self.xmagnitude)
                dy = random.randint(-sm_v, self.ymagnitude)
                shift[vertical_tile][horizontal_tile] = (dx, dy)

        shift = list(itertools.chain.from_iterable(shift))

        last_column = []
        for i in range(vertical_tiles):
            last_column.append((horizontal_tiles - 1) + horizontal_tiles * i)

        last_row = range((horizontal_tiles * vertical_tiles) - horizontal_tiles,
                         horizontal_tiles * vertical_tiles)

        polygons = []
        for x1, y1, x2, y2 in dimensions:
            polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])

        polygon_indices = []
        for i in range((vertical_tiles * horizontal_tiles) - 1):
            if i not in last_row and i not in last_column:
                polygon_indices.append([i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles])

        for id, (a, b, c, d) in enumerate(polygon_indices):
            dx = shift[id][0]
            dy = shift[id][1]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[a]
            polygons[a] = [x1, y1,
                           x2, y2,
                           x3 + dx, y3 + dy,
                           x4, y4]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[b]
            polygons[b] = [x1, y1,
                           x2 + dx, y2 + dy,
                           x3, y3,
                           x4, y4]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[c]
            polygons[c] = [x1, y1,
                           x2, y2,
                           x3, y3,
                           x4 + dx, y4 + dy]

            x1, y1, x2, y2, x3, y3, x4, y4 = polygons[d]
            polygons[d] = [x1 + dx, y1 + dy,
                           x2, y2,
                           x3, y3,
                           x4, y4]

        generated_mesh = []
        for i in range(len(dimensions)):
            generated_mesh.append([dimensions[i], polygons[i]])

        self.generated_mesh = generated_mesh

        return x.transform(x.size, Image.MESH, self.generated_mesh, resample=Image.BICUBIC)


class RandomTransform:
    """
    Random Transform adapted from https://github.com/IntuitionMachines/OrigamiNet
    Used in "OrigamiNet: Weakly-Supervised, Segmentation-Free, One-Step, Full Page Text Recognition by learning to unfold",
    Yousef, Mohamed and Bishop, Tom E., The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020
    """

    def __init__(self, val):
        self.val = val

    def __call__(self, x):
        w, h = x.size

        dw, dh = (self.val, 0) if random.randint(0, 2) == 0 else (0, self.val)

        def rd(d):
            return random.uniform(-d, d)

        def fd(d):
            return random.uniform(-dw, d)

        # generate a random projective transform
        # adapted from https://navoshta.com/traffic-signs-classification/
        tl_top = rd(dh)
        tl_left = fd(dw)
        bl_bottom = rd(dh)
        bl_left = fd(dw)
        tr_top = rd(dh)
        tr_right = fd(min(w * 3 / 4 - tl_left, dw))
        br_bottom = rd(dh)
        br_right = fd(min(w * 3 / 4 - bl_left, dw))

        tform = stf.ProjectiveTransform()
        tform.estimate(np.array((  # estimate the transform matrix from corresponding points
            (tl_left, tl_top),
            (bl_left, h - bl_bottom),
            (w - br_right, h - br_bottom),
            (w - tr_right, tr_top)
        )), np.array((
            [0, 0],
            [0, h - 1],
            [w - 1, h - 1],
            [w - 1, 0]
        )))

        # determine shape of output image, to preserve size
        # trick taken from the implementation of skimage.transform.rotate
        corners = np.array([
            [0, 0],
            [0, h - 1],
            [w - 1, h - 1],
            [w - 1, 0]
        ])

        corners = tform.inverse(corners)
        minc = corners[:, 0].min()
        minr = corners[:, 1].min()
        maxc = corners[:, 0].max()
        maxr = corners[:, 1].max()
        out_rows = maxr - minr + 1
        out_cols = maxc - minc + 1
        output_shape = np.around((out_rows, out_cols))

        # fit output image in new shape
        translation = (minc, minr)
        tform4 = stf.SimilarityTransform(translation=translation)
        tform = tform4 + tform
        # normalize
        tform.params /= tform.params[2, 2]

        x = stf.warp(np.array(x), tform, output_shape=output_shape, cval=255, preserve_range=True)
        x = stf.resize(x, (h, w), preserve_range=True).astype(np.uint8)

        return Image.fromarray(x)


class SignFlipping:
    """
    Color inversion
    """

    def __init__(self):
        pass

    def __call__(self, x):
        return ImageOps.invert(x)


class DPIAdjusting:
    """
    Resolution modification
    """

    def __init__(self, factor, preserve_ratio):
        self.factor = factor

    def __call__(self, x):
        w, h = x.size
        return x.resize((int(np.ceil(w * self.factor)), int(np.ceil(h * self.factor))), Image.BILINEAR)


class GaussianNoise:
    """
    Add Gaussian noise
    """

    def __init__(self, std):
        self.std = std

    def __call__(self, x):
        x_np = np.array(x)
        mean, std = np.mean(x_np), np.std(x_np)
        std = math.copysign(max(abs(std), 0.000001), std)
        min_, max_ = np.min(x_np), np.max(x_np)
        normal_noise = np.random.randn(*x_np.shape)
        # For grayscale images stored as 3 identical channels, apply the same noise to each channel
        if len(x_np.shape) == 3 and x_np.shape[2] == 3 and np.all(x_np[:, :, 0] == x_np[:, :, 1]) and np.all(x_np[:, :, 0] == x_np[:, :, 2]):
            normal_noise[:, :, 1] = normal_noise[:, :, 2] = normal_noise[:, :, 0]
        x_np = ((x_np - mean) / std + normal_noise * self.std) * std + mean
        x_np = normalize(x_np, x_np, max_, min_, cv2.NORM_MINMAX)

        return Image.fromarray(x_np.astype(np.uint8))


class Sharpen:
    """
    Sharpen the image with an unsharp-masking style kernel
    """

    def __init__(self, alpha, strength):
        self.alpha = alpha
        self.strength = strength

    def __call__(self, x):
        x_np = np.array(x)
        id_matrix = np.array([[0, 0, 0],
                              [0, 1, 0],
                              [0, 0, 0]])
        effect_matrix = np.array([[1, 1, 1],
                                  [1, -(8 + self.strength), 1],
                                  [1, 1, 1]])
        kernel = (1 - self.alpha) * id_matrix - self.alpha * effect_matrix
        kernel = np.expand_dims(kernel, axis=2)
        kernel = np.concatenate([kernel, kernel, kernel], axis=2)
        sharpened = cv2.filter2D(x_np, -1, kernel=kernel[:, :, 0])
        return Image.fromarray(sharpened.astype(np.uint8))


class ZoomRatio:
    """
    Crop by ratio
    Preserve dimensions if keep_dim = True (= zoom)
    """

    def __init__(self, ratio_h, ratio_w, keep_dim=True):
        self.ratio_w = ratio_w
        self.ratio_h = ratio_h
        self.keep_dim = keep_dim

    def __call__(self, x):
        w, h = x.size
        x = RandomCrop((int(h * self.ratio_h), int(w * self.ratio_w)))(x)
        if self.keep_dim:
            x = x.resize((w, h), Image.BILINEAR)
        return x


class Tightening:
    """
    Reduce interline spacing by randomly dropping background-colored rows
    """

    def __init__(self, color=255, remove_proba=0.75):
        self.color = color
        self.remove_proba = remove_proba

    def __call__(self, x):
        x_np = np.array(x)
        interline_indices = [np.all(line == self.color) for line in x_np]
        indices_to_removed = np.logical_and(
            np.random.choice([True, False], size=len(x_np), replace=True,
                             p=[self.remove_proba, 1 - self.remove_proba]),
            interline_indices)
        new_x = x_np[np.logical_not(indices_to_removed)]
        return Image.fromarray(new_x.astype(np.uint8))
image/architecture.png ADDED

Git LFS Details

  • SHA256: f4e7e266e92b47867035820e9aa2529470278d11d99838574c23e6d901b77bc2
  • Pointer size: 131 Bytes
  • Size of remote file: 797 kB
model/htr_convtext.py ADDED
@@ -0,0 +1,446 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import DropPath
from timm.layers import LayerScale
import numpy as np
from model import resnet18
from functools import partial
import warnings


class RelativePositionBias1D(nn.Module):
    def __init__(self, num_heads: int, max_rel_positions: int = 1024):
        super().__init__()
        self.num_heads = num_heads
        self.max_rel_positions = max(1, int(max_rel_positions))
        self.bias = nn.Embedding(2 * self.max_rel_positions - 1, num_heads)
        nn.init.zeros_(self.bias.weight)

    def forward(self, N: int) -> torch.Tensor:
        # Build a (1, heads, N, N) additive attention bias from clipped relative distances
        device = self.bias.weight.device
        coords = torch.arange(N, device=device)
        rel = coords[:, None] - coords[None, :]
        rel = rel.clamp(-self.max_rel_positions + 1,
                        self.max_rel_positions - 1)
        rel = rel + (self.max_rel_positions - 1)
        bias = self.bias(rel)
        return bias.permute(2, 0, 1).unsqueeze(0)


class Attention(nn.Module):
    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        max_rel_positions = max(1, int(num_patches)) if num_patches is not None else 1024
        self.rel_pos_bias = RelativePositionBias1D(
            num_heads=num_heads, max_rel_positions=max_rel_positions)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + self.rel_pos_bias(N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1, activation=nn.SiLU):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden_dim)
        self.act = activation()
        self.lin2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.lin2(self.act(self.lin1(x))))


class ConvModule(nn.Module):
    def __init__(self, dim, kernel_size=3, dropout=0.1, drop_path=0.0,
                 expansion=1.0, pre_norm=False, activation=nn.SiLU):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim) if pre_norm else None
        hidden = int(round(dim * expansion))

        self.pw1 = nn.Conv1d(dim, hidden, kernel_size=1, bias=True)
        self.act1 = activation()

        self.dw = nn.Conv1d(hidden, hidden, kernel_size=kernel_size,
                            padding=kernel_size // 2, groups=hidden, bias=True)
        self.gn = nn.GroupNorm(1, hidden, eps=1e-5)
        self.act2 = activation()

        self.pw2 = nn.Conv1d(hidden, dim, kernel_size=1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        if self.pre_norm is not None:
            x = self.pre_norm(x)
        z = x.transpose(1, 2)  # (B, N, C) -> (B, C, N) for Conv1d
        z = self.pw1(z)
        z = self.act1(z)
        z = self.dw(z)
        z = self.gn(z)
        z = self.act2(z)
        z = self.pw2(z)
        z = self.dropout(z).transpose(1, 2)
        return self.drop_path(z)


class Downsample1D(nn.Module):
    def __init__(self, dim, kernel_size=3, stride=2, lowpass_init=True):
        super().__init__()
        self.dw = nn.Conv1d(dim, dim, kernel_size=kernel_size,
                            stride=stride, padding=kernel_size // 2,
                            groups=dim, bias=False)
        self.pw = nn.Conv1d(dim, dim, kernel_size=1, bias=True)
        if lowpass_init:
            # Initialize the depthwise conv as a box (low-pass) filter
            with torch.no_grad():
                w = torch.zeros_like(self.dw.weight)
                w[:, 0, :] = 1.0 / kernel_size
                self.dw.weight.copy_(w)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.pw(self.dw(x))
        return x.transpose(1, 2)


class Upsample1D(nn.Module):
    def __init__(self, dim, mode: str = 'nearest'):
        super().__init__()
        assert mode in ('nearest', 'linear'), "Upsample1D mode must be 'nearest' or 'linear'"
        self.mode = mode
        self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=True)

    def forward(self, x, target_len: int):
        x = x.transpose(1, 2)
        if self.mode == 'nearest':
            x = F.interpolate(x, size=target_len, mode='nearest')
        else:
            x = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
        x = self.proj(x)
        return x.transpose(1, 2)


class ConvTextBlock(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 num_patches,
                 mlp_ratio=4.0,
                 ff_dropout=0.1,
                 attn_dropout=0.0,
                 conv_dropout=0.0,
                 conv_kernel_size=3,
                 conv_expansion=1.0,
                 norm_layer=nn.LayerNorm,
                 drop_path=0.0,
                 layerscale_init=1e-5):
        super().__init__()

        ff_hidden = int(dim * mlp_ratio)

        self.attn = Attention(dim, num_patches, num_heads=num_heads,
                              qkv_bias=True, attn_drop=attn_dropout, proj_drop=ff_dropout)

        self.ffn1 = FeedForward(dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)
        self.conv = ConvModule(dim, kernel_size=conv_kernel_size,
                               dropout=conv_dropout, drop_path=0.0,
                               expansion=conv_expansion, pre_norm=False, activation=nn.SiLU)
        self.ffn2 = FeedForward(dim, ff_hidden, dropout=ff_dropout, activation=nn.SiLU)

        self.postln_attn = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn1 = norm_layer(dim, elementwise_affine=True)
        self.postln_conv = norm_layer(dim, elementwise_affine=True)
        self.postln_ffn2 = norm_layer(dim, elementwise_affine=True)

        self.dp_attn = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_conv = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.dp_ffn2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.ls_attn = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn1 = LayerScale(dim, init_values=layerscale_init)
        self.ls_conv = LayerScale(dim, init_values=layerscale_init)
        self.ls_ffn2 = LayerScale(dim, init_values=layerscale_init)

    def forward(self, x):
        # attention -> FFN -> depthwise conv -> FFN, each with LayerScale,
        # DropPath and a post-LN residual; the two FFNs are half-scaled (Macaron style)
        x = self.postln_attn(x + self.ls_attn(self.dp_attn(self.attn(x))))
        x = self.postln_ffn1(x + self.ls_ffn1(0.5 * self.dp_ffn1(self.ffn1(x))))
        x = self.postln_conv(x + self.ls_conv(self.dp_conv(self.conv(x))))
        x = self.postln_ffn2(x + self.ls_ffn2(0.5 * self.dp_ffn2(self.ffn2(x))))
        return x


def get_2d_sincos_pos_embed(embed_dim, grid_size):
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])

    emb = np.concatenate([emb_h, emb_w], axis=1)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega

    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb


class HTR_ConvText(nn.Module):
    def __init__(
        self,
        nb_cls=80,
        img_size=[512, 64],
        patch_size=[4, 32],
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        conv_kernel_size: int = 3,
        dropout: float = 0.1,
        drop_path: float = 0.1,
        down_after: int = 2,
        up_after: int = 4,
        ds_kernel: int = 3,
        max_seq_len: int = 1024,
        upsample_mode: str = 'nearest',
    ):
        super().__init__()

        self.patch_embed = resnet18.ResNet18(embed_dim)
        self.embed_dim = embed_dim
        self.max_rel_pos = int(max_seq_len)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList([
            ConvTextBlock(embed_dim, num_heads, self.max_rel_pos,
                          mlp_ratio=mlp_ratio,
                          ff_dropout=dropout, attn_dropout=dropout,
                          conv_dropout=dropout, conv_kernel_size=conv_kernel_size,
                          conv_expansion=1.0,
                          norm_layer=norm_layer, drop_path=dpr[i],
                          layerscale_init=1e-5)
            for i in range(depth)
        ])

        self.norm = norm_layer(embed_dim, elementwise_affine=True)
        self.head = torch.nn.Linear(embed_dim, nb_cls)
        self.down_after = down_after
        self.up_after = up_after
        self.down1 = Downsample1D(embed_dim, kernel_size=ds_kernel)
        self.up1 = Upsample1D(embed_dim, mode=upsample_mode)
        self.initialize_weights()

    def initialize_weights(self):
        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def mask_random_1d(self, x, ratio):
        # Keep-mask with `ratio` of the positions set to False (masked)
        B, L, _ = x.shape
        mask = torch.ones(B, L, dtype=torch.bool).to(x.device)
        if ratio <= 0.0 or ratio > 1.0:
            return mask
        num = int(round(ratio * L))
        if num <= 0:
            return mask
        noise = torch.rand(B, L).to(x.device)
        idx = noise.argsort(dim=1)[:, :num]
        mask.scatter_(1, idx, False)
        return mask

    def mask_block_1d(self, x, ratio: float, max_block_length: int):
        # Mask K contiguous blocks of random length in [1, max_block_length]
        B, L, _ = x.shape
        device = x.device

        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)

        target_mask_tokens = int(round(ratio * L))
        K = target_mask_tokens // max_block_length
        K = max(K, 1)
        starts = torch.randint(0, max(1, L - max_block_length + 1), (B, K), device=device)
        lengths = torch.randint(1, max_block_length + 1, (B, K), device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        blocks_mask = (positions >= starts_exp) & (positions < ends_exp)
        masked_any = blocks_mask.any(dim=1)
        keep_mask = ~masked_any
        return keep_mask.unsqueeze(-1)

    def mask_span_1d(self, x, ratio: float, max_span_length: int):
        # Mask K contiguous spans of fixed length max_span_length
        B, L, _ = x.shape
        device = x.device

        if ratio <= 0.0:
            return torch.ones(B, L, 1, dtype=torch.bool, device=device)
        if ratio >= 1.0:
            return torch.zeros(B, L, 1, dtype=torch.bool, device=device)

        target_mask_tokens = int(round(ratio * L))
        K = target_mask_tokens // max_span_length
        K = max(K, 1)
        starts = torch.randint(0, max(1, L - max_span_length + 1), (B, K), device=device)
        lengths = torch.full((B, K), max_span_length, device=device)
        positions = torch.arange(L, device=device).view(1, 1, L)
        starts_exp = starts.unsqueeze(-1)
        ends_exp = (starts + lengths).unsqueeze(-1).clamp(max=L)
        spans_mask = (positions >= starts_exp) & (positions < ends_exp)
        masked_any = spans_mask.any(dim=1)
        keep_mask = ~masked_any
        return keep_mask.unsqueeze(-1)

    def forward_features(self, x, use_masking=False,
                         mask_mode="span",
                         mask_ratio=0.5, block_span=4, max_span_length=8):
        x = self.patch_embed(x)
        B, C, W, H = x.shape
        assert C == self.embed_dim, f"Expected embed_dim {self.embed_dim}, got {C}"
        x = x.view(B, C, -1).permute(0, 2, 1)  # flatten the spatial grid to a 1D sequence

        masked_positions_1d = None
        if use_masking:
            if mask_mode == "random":
                keep_mask_1d = self.mask_random_1d(x, mask_ratio).float()
                mask = keep_mask_1d.unsqueeze(-1)
            elif mask_mode == "block":
                keep_mask = self.mask_block_1d(x, mask_ratio, block_span).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            elif mask_mode == "span":
                keep_mask = self.mask_span_1d(x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            else:
                warnings.warn(f"Unknown mask_mode '{mask_mode}', defaulting to span.")
                keep_mask = self.mask_span_1d(x, mask_ratio, max_span_length).float()
                keep_mask_1d = keep_mask.squeeze(-1)
                mask = keep_mask
            masked_positions_1d = (1.0 - keep_mask_1d).clamp(min=0.0, max=1.0)
            x = mask * x + (1.0 - mask) * \
                self.mask_token.expand(x.size(0), x.size(1), x.size(2))
        skip_hi = None
        for i, blk in enumerate(self.blocks, 1):
            x = blk(x)
            if i == self.down_after:
                skip_hi = x
                if (x.size(1) % 2) == 1:
                    # Pad to an even length so the stride-2 downsample halves it exactly
                    x = torch.cat([x, x[:, -1:, :]], dim=1)
                x = self.down1(x)
            if i == self.up_after:
                assert skip_hi is not None, "Upsample requires a stored skip."
                x = self.up1(x, target_len=skip_hi.size(1))
                x = x + skip_hi

        x = self.norm(x)
        return x, masked_positions_1d

    def forward(self, x, use_masking=False, return_features=False, return_mask=False,
                mask_mode="span", mask_ratio=None, block_span=None, max_span_length=None):
        feats, masked_positions_1d = self.forward_features(
            x, use_masking=use_masking, mask_mode=mask_mode,
            mask_ratio=mask_ratio, block_span=block_span, max_span_length=max_span_length)
        logits = self.head(feats)
        if return_features and return_mask:
            return logits, feats, masked_positions_1d
        if return_features:
            return logits, feats
        if return_mask:
            return logits, masked_positions_1d
        return logits


def create_model(nb_cls, img_size, mlp_ratio=4, **kwargs):
    model = HTR_ConvText(
        nb_cls,
        img_size=img_size,
        patch_size=(4, 64),
        embed_dim=512,
        depth=8,
        num_heads=8,
        mlp_ratio=mlp_ratio,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        conv_kernel_size=7,
        down_after=3,
        up_after=7,
        ds_kernel=3,
        max_seq_len=128,
        upsample_mode='nearest',
        **kwargs,
    )
    return model
model/layer.py ADDED
@@ -0,0 +1,75 @@
import torch
from torch import nn
from typing import Optional, Union, Tuple


class ConvLayer2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
        use_norm: bool = True,
        use_act: bool = True,
        norm_layer: Optional[nn.Module] = None,
        act_layer: Optional[nn.Module] = None,
    ):
        super().__init__()
        layers = []
        layers.append(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias
            )
        )
        if use_norm:
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d(out_channels)
            layers.append(norm_layer)
        if use_act:
            if act_layer is None:
                act_layer = nn.ReLU(inplace=True)
            layers.append(act_layer)

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


# PEG from https://arxiv.org/abs/2102.10882
class PosCNN(nn.Module):
    def __init__(self, in_chans, embed_dim=None, s=1):
        super(PosCNN, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 3, s, 1,
                      bias=True, groups=embed_dim),
        )
        self.s = s

    def forward(self, x, H, W):
        B, N, C = x.shape

        feat_token = x
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        if self.s == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x

    def no_weight_decay(self):
        return ["proj.%d.weight" % i for i in range(4)]
model/resnet18.py ADDED
@@ -0,0 +1,411 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import math
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .layer import ConvLayer2d, PosCNN
from timm.models.vision_transformer import Mlp, DropPath


class BaseModule(nn.Module):
    """Base class for all modules"""

    def __init__(self, *args, **kwargs):
        super(BaseModule, self).__init__()

    def forward(self, x: Any, *args, **kwargs) -> Any:
        raise NotImplementedError

    def __repr__(self):
        return "{}".format(self.__class__.__name__)


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, elementwise_affine=True)

        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim, elementwise_affine=True)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class MobileViTBlock(BaseModule):
    """
    This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_

    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        n_transformer_blocks (Optional[int]): Number of transformer blocks. Default: 2
        head_dim (Optional[int]): Head dimension in the multi-head attention. Default: 64
        attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0
        dropout (Optional[float]): Dropout rate. Default: 0.0
        patch_h (Optional[int]): Patch height for unfolding operation. Default: 2
        patch_w (Optional[int]): Patch width for unfolding operation. Default: 2
        conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3
        dilation (Optional[int]): Dilation rate in convolutions. Default: 1
        no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: True
    """

    def __init__(
        self,
        in_channels=128,
        transformer_dim=128,
        n_transformer_blocks=2,
        head_dim=64,
        attn_dropout=0.0,
        dropout=0.0,
        patch_h=2,
        patch_w=2,
        conv_ksize=3,
        dilation=1,
        no_fusion=True,
    ) -> None:
        conv_3x3_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1,
            use_norm=True,
            use_act=True,
            dilation=dilation,
            padding=1,
        )
        conv_1x1_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
        )

        conv_1x1_out = ConvLayer2d(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=True,
            use_act=True,
        )
        conv_3x3_out = None
        if not no_fusion:
            conv_3x3_out = ConvLayer2d(
                in_channels=2 * in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1,
                padding=1,
                use_norm=True,
                use_act=True,
            )
        super().__init__()
        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
        self.pos_pe = PosCNN(in_chans=transformer_dim, embed_dim=transformer_dim)
        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim
        global_rep = [
            Block(
                dim=transformer_dim,
                num_heads=num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                attn_drop=attn_dropout,
                drop=dropout,
                norm_layer=nn.LayerNorm,
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(nn.LayerNorm(transformer_dim))

        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out

        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.dilation = dilation
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = int(patch_w * patch_h)
        batch_size, in_channels, orig_h, orig_w = feature_map.shape

        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            feature_map = F.interpolate(
                feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
        reshaped_fm = feature_map.reshape(
            batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
        )
        # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
        transposed_fm = reshaped_fm.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        reshaped_fm = transposed_fm.reshape(
            batch_size, in_channels, num_patches, patch_area
        )
        # [B, C, N, P] --> [B, P, N, C]
        transposed_fm = reshaped_fm.transpose(1, 3)
        # [B, P, N, C] --> [BP, N, C]
        patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return patches, info_dict

    def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
        n_dim = patches.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            patches.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        patches = patches.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = patches.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] --> [B, C, N, P]
        patches = patches.transpose(1, 3)

        # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
        feature_map = patches.reshape(
            batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
        )
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
        feature_map = feature_map.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        feature_map = feature_map.reshape(
            batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
        )
        if info_dict["interpolate"]:
            feature_map = F.interpolate(
                feature_map,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
296
+ )
297
+ return feature_map
298
+
299
+ def forward(self, x: Tensor) -> Tensor:
300
+ res = x
301
+
302
+ fm = self.local_rep(x)
303
+
304
+ # convert feature map to patches
305
+ patches, info_dict = self.unfolding(fm)
306
+ num_patch_h = info_dict["num_patches_h"]
307
+ num_patch_w = info_dict["num_patches_w"]
308
+ # learn global representations
309
+
310
+ for j, transformer_layer in enumerate(self.global_rep):
311
+ patches = transformer_layer(patches)
312
+ if j == 0:
313
+ patches = self.pos_pe(patches, num_patch_h, num_patch_w) # PEG here
314
+ # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
315
+ fm = self.folding(patches=patches, info_dict=info_dict)
316
+
317
+ fm = self.conv_proj(fm)
318
+
319
+ if self.fusion is not None:
320
+ fm = self.fusion(torch.cat((res, fm), dim=1))
321
+ return fm
322
+
323
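The unfolding/folding pair above is a pure reshape pipeline, so when H and W are already multiples of the patch size, folding the unfolded patches reproduces the input exactly. A minimal round-trip check (editorial sketch, not part of the commit; it re-implements the same reshapes standalone so it runs without the module's dependencies):

# Editorial round-trip sketch of unfolding/folding (not part of the commit).
import torch

def unfold(fm, ph, pw):                       # [B, C, H, W] -> [B*P, N, C]
    B, C, H, W = fm.shape
    nh, nw = H // ph, W // pw
    x = fm.reshape(B * C * nh, ph, nw, pw).transpose(1, 2)
    x = x.reshape(B, C, nh * nw, ph * pw).transpose(1, 3)
    return x.reshape(B * ph * pw, nh * nw, C)

def fold(patches, B, C, nh, nw, ph, pw):      # exact inverse of unfold
    x = patches.reshape(B, ph * pw, nh * nw, C).transpose(1, 3)
    x = x.reshape(B * C * nh, nw, ph, pw).transpose(1, 2)
    return x.reshape(B, C, nh * ph, nw * pw)

fm = torch.randn(2, 16, 8, 64)                # H, W divisible by the 2x2 patches
p = unfold(fm, 2, 2)                          # -> [2*4, 4*32, 16] = [8, 128, 16]
assert torch.equal(fold(p, 2, 16, 4, 32, 2, 2), fm)  # exact round trip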
+ def conv3x3(in_planes, out_planes, stride=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                      stride=stride, padding=1, bias=False)
+ 
+ 
+ class BasicBlock(nn.Module):
+     expansion = 1
+ 
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes, eps=1e-05)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
+         self.downsample = downsample
+         self.stride = stride
+ 
+     def forward(self, x):
+         residual = x
+ 
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+ 
+         out = self.conv2(out)
+         out = self.bn2(out)
+ 
+         if self.downsample is not None:
+             residual = self.downsample(x)
+ 
+         out += residual
+         out = self.relu(out)
+ 
+         return out
+ 
+ 
+ class ResNet18(nn.Module):
+ 
+     def __init__(self, nb_feat=384):
+         self.inplanes = nb_feat // 4
+         super(ResNet18, self).__init__()
+         self.conv1 = nn.Conv2d(
+             1, nb_feat // 4, kernel_size=3, stride=(2, 1), padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(nb_feat // 4, eps=1e-05)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=(2, 1), padding=1)
+         self.layer1 = self._make_layer(
+             BasicBlock, nb_feat // 4, 2, stride=(2, 1))
+         self.mobilevit_block1 = MobileViTBlock(
+             in_channels=nb_feat // 4, transformer_dim=nb_feat // 4,
+             n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0,
+             patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)
+         self.layer2 = self._make_layer(BasicBlock, nb_feat // 2, 2, stride=2)
+         self.mobilevit_block2 = MobileViTBlock(
+             in_channels=nb_feat // 2, transformer_dim=nb_feat // 2,
+             n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0,
+             patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)
+         self.layer3 = self._make_layer(BasicBlock, nb_feat, 2, stride=2)
+         self.mobilevit_block3 = MobileViTBlock(
+             in_channels=nb_feat, transformer_dim=nb_feat,
+             n_transformer_blocks=1, head_dim=64, attn_dropout=0.0, dropout=0.0,
+             patch_h=2, patch_w=2, conv_ksize=3, dilation=1, no_fusion=True)
+ 
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
+             )
+ 
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, 1, None))
+ 
+         return nn.Sequential(*layers)
+ 
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+ 
+         x = self.layer1(x)
+         x = self.mobilevit_block1(x)
+         x = self.layer2(x)
+         x = self.mobilevit_block2(x)
+         x = self.layer3(x)
+         x = self.mobilevit_block3(x)
+         x = self.maxpool(x)
+ 
+         return x
model/tcm_head.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ def build_tcm_vocab(converter, add_tokens=("<pad>",)):
+     base = list(converter.character)
+     stoi = {ch: i for i, ch in enumerate(base)}
+     for t in add_tokens:
+         if t not in stoi:
+             stoi[t] = len(stoi)
+     itos = [''] * len(stoi)
+     for k, v in stoi.items():
+         itos[v] = k
+     pad_id = stoi["<pad>"]
+     return stoi, itos, pad_id
+ 
+ 
+ def texts_to_ids(texts, stoi):
+     return [torch.tensor([stoi[ch] for ch in t], dtype=torch.long) for t in texts]
+ 
+ 
+ def make_context_batch(texts, stoi, sub_str_len=5, device='cuda'):
+     ids = [torch.tensor([stoi[ch] for ch in t], dtype=torch.long, device=device) for t in texts]
+     B = len(ids); Lmax = max(t.size(0) for t in ids); S = sub_str_len
+     PAD = stoi["<pad>"]
+ 
+     left = torch.full((B, Lmax, S), PAD, dtype=torch.long, device=device)
+     right = torch.full((B, Lmax, S), PAD, dtype=torch.long, device=device)
+     tgt = torch.full((B, Lmax), PAD, dtype=torch.long, device=device)
+     mask = torch.zeros((B, Lmax), dtype=torch.float32, device=device)
+ 
+     for b, seq in enumerate(ids):
+         L = seq.size(0)
+         tgt[b, :L] = seq
+         mask[b, :L] = 1.0
+         for i in range(L):
+             l_ctx = seq[max(0, i - S):i]
+             # left-pad with PAD
+             if l_ctx.numel() < S:
+                 l_ctx = torch.cat([torch.full((S - l_ctx.numel(),), PAD, device=device), l_ctx], dim=0)
+             left[b, i] = l_ctx[-S:]
+ 
+             r_ctx = seq[i + 1:min(L, i + 1 + S)]
+             # right-pad with PAD
+             if r_ctx.numel() < S:
+                 r_ctx = torch.cat([r_ctx, torch.full((S - r_ctx.numel(),), PAD, device=device)], dim=0)
+             right[b, i] = r_ctx[:S]
+ 
+     return left, right, tgt, mask
+ 
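For each target character, make_context_batch collects the S characters to its left and right, padding with <pad> wherever a window runs off the string. A small CPU walk-through of the expected output (editorial sketch, not part of the commit; the two-character vocabulary is hypothetical):

# Editorial sketch of make_context_batch (not part of the commit).
stoi = {'a': 0, 'b': 1, '<pad>': 2}
left, right, tgt, mask = make_context_batch(['ab'], stoi, sub_str_len=2, device='cpu')
# Position 0 ('a'): no left context, right context is 'b' then PAD.
# Position 1 ('b'): left context is PAD then 'a', no right context.
print(left[0].tolist())   # [[2, 2], [2, 0]]
print(right[0].tolist())  # [[1, 2], [2, 2]]
print(tgt[0].tolist())    # [0, 1]
print(mask[0].tolist())   # [1.0, 1.0]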
+ 
+ class TCMHead(nn.Module):
+     def __init__(self, d_vis, vocab_size_tcm, pad_id, d_txt=256, sub_str_len=5, p_drop=0.1):
+         super().__init__()
+         self.vocab_size = vocab_size_tcm
+         self.sub_str_len = sub_str_len
+ 
+         # critical: padding_idx zeroes the PAD row and keeps it frozen
+         self.emb = nn.Embedding(vocab_size_tcm, d_txt, padding_idx=pad_id)
+ 
+         # keep direction as learned vectors (not tokens)
+         self.dir_left = nn.Parameter(torch.randn(1, 1, d_txt))
+         self.dir_right = nn.Parameter(torch.randn(1, 1, d_txt))
+ 
+         self.ctx_conv = nn.Conv1d(d_txt, d_txt, kernel_size=3, padding=1)
+         self.txt_proj = nn.Linear(d_txt, d_vis)
+         self.q_norm = nn.LayerNorm(d_vis)
+         self.kv_norm = nn.LayerNorm(d_vis)
+         self.dropout = nn.Dropout(p_drop)
+         self.classifier = nn.Linear(d_vis, vocab_size_tcm)
+ 
+     def _context_to_query(self, ctx_ids, dir_token):
+         E = self.emb(ctx_ids)                    # [B, L, S, D]
+         B, L, S, D = E.shape
+         x = E.view(B * L, S, D).transpose(1, 2)  # [B*L, D, S] for Conv1d
+         x = self.ctx_conv(x)
+         x = x.mean(dim=-1)                       # pool over the context window
+         x = x.view(B, L, D)
+ 
+         x = x + dir_token
+         x = self.txt_proj(x)
+         return self.q_norm(x)
+ 
+     def _cross_attend(self, Q, feats):
+         # parameter renamed from `F` so it cannot shadow torch.nn.functional
+         K = self.kv_norm(feats)
+         V = K
+         attn = torch.einsum('bld,bnd->bln', Q, K) / (K.size(-1) ** 0.5)
+         A = attn.softmax(dim=-1)
+         out = torch.einsum('bln,bnd->bld', A, V)
+         return self.dropout(out)
+ 
+     def forward(self,
+                 vis_tokens,
+                 left_ctx_ids,
+                 right_ctx_ids,
+                 tgt_ids,
+                 tgt_mask,
+                 focus_mask=None):
+         Ql = self._context_to_query(left_ctx_ids, self.dir_left)
+         Qr = self._context_to_query(right_ctx_ids, self.dir_right)
+ 
+         Fl = self._cross_attend(Ql, vis_tokens)
+         Fr = self._cross_attend(Qr, vis_tokens)
+ 
+         logits_l = self.classifier(Fl)
+         logits_r = self.classifier(Fr)
+ 
+         loss_l = F.cross_entropy(
+             logits_l.view(-1, self.vocab_size),
+             tgt_ids.view(-1),
+             reduction='none'
+         ).view_as(tgt_ids)
+         loss_r = F.cross_entropy(
+             logits_r.view(-1, self.vocab_size),
+             tgt_ids.view(-1),
+             reduction='none'
+         ).view_as(tgt_ids)
+ 
+         if focus_mask is not None:
+             weights = tgt_mask * (1.0 + focus_mask)
+         else:
+             weights = tgt_mask
+ 
+         loss_masked = (loss_l + loss_r) * weights
+         denom = torch.clamp(weights.sum(), min=1.0)
+         loss_tcm = loss_masked.sum() / (2.0 * denom)
+ 
+         return {'loss_tcm': loss_tcm,
+                 'logits_l': logits_l,
+                 'logits_r': logits_r}
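A shape-level smoke test for the head (editorial sketch, not part of the commit; every dimension below is an arbitrary toy value):

# Editorial smoke test for TCMHead (not part of the commit).
import torch

stoi = {'a': 0, 'b': 1, '<pad>': 2}
head = TCMHead(d_vis=32, vocab_size_tcm=len(stoi), pad_id=stoi['<pad>'],
               d_txt=16, sub_str_len=2)
left, right, tgt, mask = make_context_batch(['ab', 'b'], stoi,
                                            sub_str_len=2, device='cpu')
vis = torch.randn(2, 7, 32)             # [B, N visual tokens, d_vis]
out = head(vis, left, right, tgt, mask)
print(out['loss_tcm'].shape)            # torch.Size([]) -- scalar loss
print(out['logits_l'].shape)            # torch.Size([2, 2, 3]) -- [B, L, vocab]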
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy>=1.24
+ pillow>=9.0
+ opencv-python>=4.8
+ scikit-image>=0.21
+ tensorboard>=2.13
+ wandb>=0.16
+ editdistance>=0.6
+ timm>=0.9
run/iam.sh ADDED
@@ -0,0 +1 @@
+ python3 train.py --use-wandb --dataset iam --tcm-enable --exp-name "htr-convtext" --wandb-project iam --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/iam-vt-lines/lines/ --train-data-list /kaggle/input/iam-vt-lines/train.ln --val-data-list /kaggle/input/iam-vt-lines/val.ln --test-data-list /kaggle/input/iam-vt-lines/test.ln --nb-cls 80
run/lam.sh ADDED
@@ -0,0 +1 @@
+ python3 train.py --use-wandb --dataset lam --tcm-enable --exp-name "htr-convtext" --wandb-project lam --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/lam-vt-lines/lines/ --train-data-list /kaggle/input/lam-vt-lines/train.ln --val-data-list /kaggle/input/lam-vt-lines/val.ln --test-data-list /kaggle/input/lam-vt-lines/test.ln --nb-cls 91
run/read2016.sh ADDED
@@ -0,0 +1 @@
+ python3 train.py --use-wandb --dataset read2016 --tcm-enable --exp-name "htr-convtext" --wandb-project read2016 --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/read2016-vt-lines/lines/ --train-data-list /kaggle/input/read2016-vt-lines/train.ln --val-data-list /kaggle/input/read2016-vt-lines/val.ln --test-data-list /kaggle/input/read2016-vt-lines/test.ln --nb-cls 90
run/vnondb.sh ADDED
@@ -0,0 +1 @@
+ python3 train.py --use-wandb --dataset vnondb --tcm-enable --exp-name "htr-convtext" --wandb-project vnondb --num-workers 4 --max-lr 1e-3 --warm-up-iter 1000 --weight-decay 0.05 --train-bs 32 --val-bs 8 --max-span-length 8 --mask-ratio 0.4 --attn-mask-ratio 0.1 --img-size 512 64 --proj 8 --dila-ero-max-kernel 2 --dila-ero-iter 1 --proba 0.5 --alpha 1 --total-iter 100001 --data-path /kaggle/input/vnondb/lines/ --train-data-list /kaggle/input/vnondb/train.ln --val-data-list /kaggle/input/vnondb/val.ln --test-data-list /kaggle/input/vnondb/test.ln --nb-cls 162
test.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ 
+ import os
+ import re
+ import json
+ import valid
+ from utils import utils
+ from utils import option
+ from data import dataset
+ from model import htr_convtext
+ from collections import OrderedDict
+ 
+ 
+ def main():
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     torch.manual_seed(args.seed)
+ 
+     args.save_dir = os.path.join(args.out_dir, args.exp_name)
+     os.makedirs(args.save_dir, exist_ok=True)
+     logger = utils.get_logger(args.save_dir)
+ 
+     model = htr_convtext.create_model(
+         nb_cls=args.nb_cls, img_size=args.img_size[::-1])
+ 
+     pth_path = args.resume
+     logger.info('loading HWR checkpoint from {}'.format(pth_path))
+ 
+     ckpt = torch.load(pth_path, map_location='cpu', weights_only=False)
+     model_dict = OrderedDict()
+     pattern = re.compile('module.')
+ 
+     for k, v in ckpt['state_dict_ema'].items():
+         if re.search("module", k):
+             model_dict[re.sub(pattern, '', k)] = v
+         else:
+             model_dict[k] = v
+ 
+     model.load_state_dict(model_dict, strict=True)
+     model = model.to(device)  # was .cuda(); .to(device) also works on CPU-only hosts
+ 
+     logger.info('Loading test loader...')
+     train_dataset = dataset.myLoadDS(
+         args.train_data_list, args.data_path, args.img_size, dataset=args.dataset)
+ 
+     test_dataset = dataset.myLoadDS(
+         args.test_data_list, args.data_path, args.img_size, ralph=train_dataset.ralph, dataset=args.dataset)
+     test_loader = torch.utils.data.DataLoader(test_dataset,
+                                               batch_size=args.val_bs,
+                                               shuffle=False,
+                                               pin_memory=True,
+                                               num_workers=args.num_workers)
+ 
+     converter = utils.CTCLabelConverter(train_dataset.ralph.values())
+     criterion = torch.nn.CTCLoss(
+         reduction='none', zero_infinity=True).to(device)
+ 
+     model.eval()
+     with torch.no_grad():
+         val_loss, val_cer, val_wer, preds, labels = valid.validation(
+             model,
+             criterion,
+             test_loader,
+             converter,
+         )
+ 
+     logger.info(
+         f'Test. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} ')
+ 
+     # Save predictions as JSON
+     results = {
+         "test_metrics": {
+             "loss": float(val_loss),
+             "cer": float(val_cer),
+             "wer": float(val_wer)
+         },
+         "predictions": []
+     }
+ 
+     def _levenshtein(pred_tokens, gt_tokens):
+         if pred_tokens == gt_tokens:
+             return 0
+         lp, lg = len(pred_tokens), len(gt_tokens)
+         if lp == 0:
+             return lg
+         if lg == 0:
+             return lp
+         prev = list(range(lg + 1))
+         for i in range(1, lp + 1):
+             cur = [i]
+             pi = pred_tokens[i - 1]
+             for j in range(1, lg + 1):
+                 gj = gt_tokens[j - 1]
+                 cost = 0 if pi == gj else 1
+                 cur.append(
+                     min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost))
+             prev = cur
+         return prev[-1]
+ 
+     def _levenshtein_str(a: str, b: str):
+         return _levenshtein(list(a), list(b))
+ 
+     def _cer(pred: str, gt: str):
+         if len(gt) == 0:
+             return 0.0 if len(pred) == 0 else 1.0
+         return _levenshtein_str(pred, gt) / len(gt)
+ 
+     def _wer(pred: str, gt: str):
+         gt_words = gt.split()
+         pred_words = pred.split()
+         if len(gt_words) == 0:
+             return 0.0 if len(pred_words) == 0 else 1.0
+         return _levenshtein(pred_words, gt_words) / len(gt_words)
+ 
+     for i, (pred, label) in enumerate(zip(preds, labels)):
+         if i < len(test_dataset.fns):
+             img_path = test_dataset.fns[i]
+             img_name = os.path.basename(img_path)
+         else:
+             img_path = None
+             img_name = None
+         results["predictions"].append({
+             "sample_id": i + 1,
+             "image_filename": img_name,
+             "image_path": img_path,
+             "prediction": pred,
+             "ground_truth": label,
+             "match": pred == label,
+             "cer": round(float(_cer(pred, label)), 6),
+             "wer": round(float(_wer(pred, label)), 6)
+         })
+ 
+     pred_file = os.path.join(args.save_dir, 'predictions.json')
+     with open(pred_file, 'w', encoding='utf-8') as f:
+         json.dump(results, f, indent=2, ensure_ascii=False)
+ 
+ 
+ if __name__ == '__main__':
+     args = option.get_args_parser()
+     main()
train.py ADDED
@@ -0,0 +1,441 @@
+ import torch
+ import torch.utils.data
+ import torch.backends.cudnn as cudnn
+ from torch.utils.tensorboard import SummaryWriter
+ 
+ import os
+ import json
+ import valid
+ from utils import utils
+ from utils import sam
+ from utils import option
+ from data import dataset
+ from model import htr_convtext
+ from functools import partial
+ import random
+ import numpy as np
+ import re
+ import importlib
+ from model.tcm_head import TCMHead, build_tcm_vocab, make_context_batch
+ # wandb is imported lazily inside main() via importlib so the dependency stays
+ # optional; an unconditional top-level import would defeat that guard.
+ 
+ 
+ def compute_losses(
+     args,
+     model,
+     tcm_head,
+     image,
+     texts,
+     batch_size,
+     criterion_ctc,
+     converter,
+     nb_iter,
+     ctc_lambda,
+     tcm_lambda,
+     stoi,
+     mask_mode='span',
+     mask_ratio=0.30,
+     block_span=4,
+     max_span_length=8,
+     pre_tcm_ctx=None,
+     use_masking=True,
+ ):
+     if tcm_head is None or nb_iter < args.tcm_warmup_iters:
+         preds = model(image, use_masking=use_masking, mask_mode=mask_mode,
+                       mask_ratio=mask_ratio, max_span_length=max_span_length)
+         feats = None
+     else:
+         preds, feats, vis_mask = model(
+             image,
+             use_masking=use_masking,
+             return_features=True,
+             return_mask=True,
+             mask_mode=mask_mode,
+             mask_ratio=mask_ratio,
+             block_span=block_span,
+             max_span_length=max_span_length
+         )
+     text_ctc, length_ctc = converter.encode(texts)
+     text_ctc = text_ctc.to(preds.device)
+     length_ctc = length_ctc.to(preds.device)
+     preds_sz = torch.full((batch_size,), preds.size(1),
+                           dtype=torch.int32, device=preds.device)
+     loss_ctc = criterion_ctc(preds.permute(1, 0, 2).log_softmax(2),
+                              text_ctc, preds_sz, length_ctc).mean()
+ 
+     loss_tcm = torch.zeros((), device=preds.device)
+     if tcm_head is not None and feats is not None:
+         left_ctx, right_ctx, tgt_ids, tgt_mask = pre_tcm_ctx if pre_tcm_ctx is not None else make_context_batch(
+             texts, stoi, sub_str_len=args.tcm_sub_len, device=image.device)
+         if vis_mask is not None:
+             B_v, N_v = vis_mask.shape
+             B_t, L_t = tgt_mask.shape
+             if N_v != L_t:
+                 # resample the visual mask to the label length so it can reweight the TCM loss
+                 idx = torch.linspace(0, N_v - 1, steps=L_t,
+                                      device=vis_mask.device).long()
+                 focus_mask = vis_mask[:, idx]
+             else:
+                 focus_mask = vis_mask
+         else:
+             focus_mask = None
+ 
+         out = tcm_head(
+             feats,
+             left_ctx, right_ctx,
+             tgt_ids, tgt_mask,
+             focus_mask=focus_mask
+         )
+         loss_tcm = out['loss_tcm']
+ 
+     total = ctc_lambda * loss_ctc + tcm_lambda * loss_tcm
+     return total, loss_ctc.detach(), loss_tcm.detach()
+ 
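compute_losses feeds CTC with [T, B, C] log-probabilities, per-sample input lengths all equal to the model's output width, and the converter's concatenated targets. A standalone restatement of that shape contract (editorial sketch, not part of the commit; all sizes are toy values):

# Editorial sketch of the CTC shape contract used above (not part of the commit).
import torch

B, T, C = 2, 10, 5                       # batch, model output width, classes (0 = blank)
preds = torch.randn(B, T, C)             # model output, [B, T, C]
text = torch.tensor([1, 2, 3, 2], dtype=torch.int32)  # targets concatenated: "abc" + "b"
length = torch.tensor([3, 1], dtype=torch.int32)      # per-sample target lengths
preds_sz = torch.full((B,), T, dtype=torch.int32)     # every sample uses the full width

criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True)
loss = criterion(preds.permute(1, 0, 2).log_softmax(2), text, preds_sz, length)
print(loss.shape)  # torch.Size([2]) -- one loss per sample; .mean() reduces it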
+ 
+ def tri_masked_loss(args, model, tcm_head, image, labels, batch_size,
+                     criterion, converter, nb_iter, ctc_lambda, tcm_lambda, stoi,
+                     r_rand=0.6, r_block=0.6, block_span=4, r_span=0.4, max_span=8):
+     total = 0.0
+     total_ctc = 0.0
+     total_tcm = 0.0
+     plans = [("random", r_rand), ("block", r_block), ("span", r_span)]
+ 
+     pre_tcm_ctx = None  # fix: was left unbound whenever the branch below was skipped
+     if tcm_head is not None and nb_iter >= args.tcm_warmup_iters:
+         pre_tcm_ctx = make_context_batch(
+             labels, stoi, sub_str_len=args.tcm_sub_len, device=image.device)
+ 
+     for mode, ratio in plans:
+         loss, loss_ctc, loss_tcm = compute_losses(
+             args, model, tcm_head, image, labels, batch_size, criterion, converter,
+             nb_iter, ctc_lambda, tcm_lambda, stoi,
+             mask_mode=mode, mask_ratio=ratio, block_span=block_span, max_span_length=max_span,
+             pre_tcm_ctx=pre_tcm_ctx
+         )
+         total += loss
+         total_ctc += loss_ctc
+         total_tcm += loss_tcm
+ 
+     denom = 3.0
+     return total / denom, total_ctc / denom, total_tcm / denom
+ 
+ 
+ def main():
+     args = option.get_args_parser()
+     torch.manual_seed(args.seed)
+ 
+     args.save_dir = os.path.join(args.out_dir, args.exp_name)
+     os.makedirs(args.save_dir, exist_ok=True)
+ 
+     logger = utils.get_logger(args.save_dir)
+     logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
+     writer = SummaryWriter(args.save_dir)
+ 
+     if getattr(args, 'use_wandb', False):
+         try:
+             wandb = importlib.import_module('wandb')
+             wandb.init(project=getattr(args, 'wandb_project', 'None'), name=args.exp_name,
+                        config=vars(args), dir=args.save_dir)
+             logger.info("Weights & Biases logging enabled")
+         except Exception as e:
+             logger.warning(
+                 f"Failed to initialize wandb: {e}. Continuing without wandb.")
+             wandb = None
+     else:
+         wandb = None
+ 
+     torch.backends.cudnn.benchmark = True
+ 
+     model = htr_convtext.create_model(
+         nb_cls=args.nb_cls, img_size=args.img_size[::-1])
+ 
+     total_param = sum(p.numel() for p in model.parameters())
+     logger.info('total_param is {}'.format(total_param))
+ 
+     model.train()
+     model = model.cuda()
+     ema_decay = args.ema_decay
+     logger.info(f"Using EMA decay: {ema_decay}")
+     model_ema = utils.ModelEma(model, ema_decay)
+     model.zero_grad()
+ 
+     resume_path = args.resume
+     best_cer, best_wer, start_iter, optimizer_state, train_loss, train_loss_count = utils.load_checkpoint(
+         model, model_ema, None, resume_path, logger)
+ 
+     logger.info('Loading train loader...')
+     train_dataset = dataset.myLoadDS(
+         args.train_data_list, args.data_path, args.img_size, dataset=args.dataset)
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=args.train_bs,
+                                                shuffle=True,
+                                                pin_memory=True,
+                                                num_workers=args.num_workers,
+                                                collate_fn=partial(dataset.SameTrCollate, args=args))
+     train_iter = dataset.cycle_data(train_loader)
+ 
+     logger.info('Loading val loader...')
+     val_dataset = dataset.myLoadDS(
+         args.val_data_list, args.data_path, args.img_size, ralph=train_dataset.ralph, dataset=args.dataset)
+     val_loader = torch.utils.data.DataLoader(val_dataset,
+                                              batch_size=args.val_bs,
+                                              shuffle=False,
+                                              pin_memory=True,
+                                              num_workers=args.num_workers)
+ 
+     criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True)
+     converter = utils.CTCLabelConverter(train_dataset.ralph.values())
+ 
+     stoi, itos, pad_id = build_tcm_vocab(converter)
+     vocab_size_tcm = len(itos)
+     d_vis = model.embed_dim
+ 
+     if args.tcm_enable:
+         tcm_head = TCMHead(d_vis=d_vis, vocab_size_tcm=vocab_size_tcm, pad_id=pad_id,
+                            sub_str_len=args.tcm_sub_len).cuda()
+         tcm_head.train()
+     else:
+         tcm_head = None
+ 
+     param_groups = list(model.parameters())
+     if args.tcm_enable and tcm_head is not None:
+         param_groups += list(tcm_head.parameters())
+         logger.info(
+             f"Optimizing {sum(p.numel() for p in tcm_head.parameters())} tcm params in addition to model params")
+     optimizer = sam.SAM(param_groups, torch.optim.AdamW,
+                         lr=1e-7, betas=(0.9, 0.99), weight_decay=args.weight_decay)
+ 
+     if optimizer_state is not None:
+         try:
+             optimizer.load_state_dict(optimizer_state)
+             logger.info("Successfully loaded optimizer state")
+         except Exception as e:
+             logger.warning(f"Failed to load optimizer state: {e}")
+             logger.info(
+                 "Continuing training without optimizer state (will restart from initial lr/momentum)")
+     elif resume_path and os.path.isfile(resume_path):
+         try:
+             ckpt = torch.load(resume_path, map_location='cpu',
+                               weights_only=False)
+             if 'optimizer' in ckpt:
+                 optimizer.load_state_dict(ckpt['optimizer'])
+                 logger.info("Loaded optimizer state from checkpoint directly")
+         except Exception as e:
+             logger.warning(
+                 f"Could not load optimizer state from checkpoint: {e}")
+ 
+     if resume_path and os.path.isfile(resume_path) and tcm_head is not None:
+         try:
+             ckpt = torch.load(resume_path, map_location='cpu',
+                               weights_only=False)
+             if 'tcm_head' in ckpt:
+                 tcm_head.load_state_dict(ckpt['tcm_head'], strict=False)
+                 logger.info("Restored tcm head state from checkpoint")
+             else:
+                 logger.info(
+                     "No tcm head state found in checkpoint; training tcm from scratch")
+         except Exception as e:
+             logger.warning(f"Failed to restore tcm head from checkpoint: {e}")
+ 
+     def _make_checkpoint():
+         # One payload shared by the periodic / best-CER / best-WER saves below;
+         # the three inline copies in the original were identical.
+         ckpt = {
+             'model': model.state_dict(),
+             'state_dict_ema': model_ema.ema.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'nb_iter': nb_iter,
+             'best_cer': best_cer,
+             'best_wer': best_wer,
+             'args': vars(args),
+             'random_state': random.getstate(),
+             'numpy_state': np.random.get_state(),
+             'torch_state': torch.get_rng_state(),
+             'torch_cuda_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
+             'train_loss': train_loss,
+             'train_loss_count': train_loss_count,
+         }
+         if tcm_head is not None:
+             ckpt['tcm_head'] = tcm_head.state_dict()
+         return ckpt
+ 
+     #### ---- train & eval ---- ####
+     logger.info('Start training...')
+     accum_steps = max(1, int(getattr(args, 'accum_steps', 1)))
+     avg_loss_ctc = 0.0
+     avg_loss_tcm = 0.0
+ 
+     for nb_iter in range(start_iter, args.total_iter):
+         optimizer, current_lr = utils.update_lr_cos(
+             nb_iter, args.warm_up_iter, args.total_iter, args.max_lr, optimizer)
+         optimizer.zero_grad()
+         total_loss_this_macro = 0.0
+         avg_loss_ctc = 0.0
+         avg_loss_tcm = 0.0
+         cached_batches = []
+         for micro_step in range(accum_steps):
+             batch = next(train_iter)
+             cached_batches.append(batch)
+             image = batch[0].cuda(non_blocking=True)
+             text, length = converter.encode(batch[1])
+             batch_size = image.size(0)
+             if args.use_masking:
+                 # Alternative kept for reference: average the three masking plans via
+                 # tri_masked_loss(args, model, tcm_head, image, batch[1], batch_size,
+                 #                 criterion, converter, nb_iter, args.ctc_lambda,
+                 #                 args.tcm_lambda, stoi, r_rand=args.r_rand,
+                 #                 r_block=args.r_block, block_span=args.block_span,
+                 #                 r_span=args.r_span, max_span=args.max_span)
+                 loss, loss_ctc, loss_tcm = compute_losses(
+                     args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
+                     nb_iter, args.ctc_lambda, args.tcm_lambda, stoi,
+                     mask_mode='span', mask_ratio=0.4, max_span_length=8, use_masking=True
+                 )
+             else:
+                 loss, loss_ctc, loss_tcm = compute_losses(
+                     args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
+                     nb_iter, args.ctc_lambda, args.tcm_lambda, stoi, use_masking=False
+                 )
+             (loss / accum_steps).backward()
+             total_loss_this_macro += loss.item()
+             avg_loss_ctc += loss_ctc.mean().item()
+             avg_loss_tcm += loss_tcm.mean().item()
+ 
+         optimizer.first_step(zero_grad=True)
+ 
+         # Recompute with perturbed weights and accumulate again for the second step
+         for micro_step in range(accum_steps):
+             batch = cached_batches[micro_step]
+             image = batch[0].cuda(non_blocking=True)
+             text, length = converter.encode(batch[1])
+             batch_size = image.size(0)
+             if args.use_masking:
+                 # (same tri_masked_loss alternative as in the first pass)
+                 loss2, loss_ctc, loss_tcm = compute_losses(
+                     args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
+                     nb_iter, args.ctc_lambda, args.tcm_lambda, stoi,
+                     mask_mode='span', mask_ratio=0.4, max_span_length=8, use_masking=True
+                 )
+             else:
+                 loss2, loss_ctc, loss_tcm = compute_losses(
+                     args, model, tcm_head, image, batch[1], batch_size, criterion, converter,
+                     nb_iter, args.ctc_lambda, args.tcm_lambda, stoi, use_masking=False
+                 )
+             (loss2 / accum_steps).backward()
+         optimizer.second_step(zero_grad=True)
+         model.zero_grad()
+         model_ema.update(model, num_updates=nb_iter / 2)
+ 
+         train_loss += total_loss_this_macro / accum_steps
+         train_loss_count += 1
+ 
+         if nb_iter % args.print_iter == 0:
+             train_loss_avg = train_loss / train_loss_count if train_loss_count > 0 else 0.0
+ 
+             logger.info(
+                 f'Iter : {nb_iter} \t LR : {current_lr:0.5f} \t total : {train_loss_avg:0.5f} \t CTC : {(avg_loss_ctc/accum_steps):0.5f} \t tcm : {(avg_loss_tcm/accum_steps):0.5f} \t ')
+ 
+             writer.add_scalar('./Train/lr', current_lr, nb_iter)
+             writer.add_scalar('./Train/train_loss', train_loss_avg, nb_iter)
+             if wandb is not None:
+                 wandb.log({
+                     'train/lr': current_lr,
+                     'train/loss': train_loss_avg,
+                     'train/CTC': (avg_loss_ctc / accum_steps),
+                     'train/tcm': (avg_loss_tcm / accum_steps),
+                     'iter': nb_iter,
+                 }, step=nb_iter)
+             train_loss = 0.0
+             train_loss_count = 0
+ 
+         if nb_iter % args.eval_iter == 0:
+             model.eval()
+             with torch.no_grad():
+                 val_loss, val_cer, val_wer, preds, labels = valid.validation(
+                     model_ema.ema, criterion, val_loader, converter)
+ 
+             if nb_iter % (args.eval_iter * 5) == 0:  # fix: `%` binds tighter than `*`, so the old test fired on every eval
+                 ckpt_name = f"checkpoint_{best_cer:.4f}_{best_wer:.4f}_{nb_iter}.pth"
+                 torch.save(_make_checkpoint(), os.path.join(args.save_dir, ckpt_name))
+ 
+             if val_cer < best_cer:
+                 logger.info(
+                     f'CER improved from {best_cer:.4f} to {val_cer:.4f}!!!')
+                 best_cer = val_cer
+                 torch.save(_make_checkpoint(), os.path.join(args.save_dir, 'best_CER.pth'))
+ 
+             if val_wer < best_wer:
+                 logger.info(
+                     f'WER improved from {best_wer:.4f} to {val_wer:.4f}!!!')
+                 best_wer = val_wer
+                 torch.save(_make_checkpoint(), os.path.join(args.save_dir, 'best_WER.pth'))
+ 
+             logger.info(
+                 f'Val. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} \t ')
+ 
+             writer.add_scalar('./VAL/CER', val_cer, nb_iter)
+             writer.add_scalar('./VAL/WER', val_wer, nb_iter)
+             writer.add_scalar('./VAL/bestCER', best_cer, nb_iter)
+             writer.add_scalar('./VAL/bestWER', best_wer, nb_iter)
+             writer.add_scalar('./VAL/val_loss', val_loss, nb_iter)
+             if wandb is not None:
+                 wandb.log({
+                     'val/loss': val_loss,
+                     'val/CER': val_cer,
+                     'val/WER': val_wer,
+                     'val/best_CER': best_cer,
+                     'val/best_WER': best_wer,
+                     'iter': nb_iter,
+                 }, step=nb_iter)
+             model.train()
+ 
+ 
+ if __name__ == '__main__':
+     main()
utils/option.py ADDED
@@ -0,0 +1,235 @@
+ import argparse
+ 
+ 
+ def get_args_parser() -> argparse.Namespace:
+     """Create and parse command-line options for HTR-ConvText.
+ 
+     This keeps all option names and defaults intact, but organizes them into
+     logical groups with clearer help messages.
+     """
+     parser = argparse.ArgumentParser(
+         description='HTR-ConvText: Leveraging Convolution and Textual Context with Mixed Masking for Handwritten Text Recognition',
+         add_help=True,
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+ 
+     # ---------------------------------------------------------------------
+     # Experiment & Logging
+     # ---------------------------------------------------------------------
+     exp = parser.add_argument_group('Experiment & Logging')
+     exp.add_argument('--out-dir', type=str, default='./output',
+                      help='Root directory to save logs, checkpoints, and outputs')
+     exp.add_argument('--exp-name', type=str, default='IAM_HTR_ORIGAMI_NET',
+                      help='Experiment name; results go to <out-dir>/<exp-name>')
+     exp.add_argument('--seed', default=123, type=int,
+                      help='Random seed for reproducibility')
+     exp.add_argument('--use-wandb', action='store_true', default=False,
+                      help='Log to Weights & Biases; otherwise use TensorBoard')
+     exp.add_argument('--wandb-project', type=str, default='None',
+                      help='W&B project name (used only if --use-wandb)')
+     exp.add_argument('--print-iter', default=100, type=int,
+                      help='Iterations between training status prints')
+     exp.add_argument('--eval-iter', default=1000, type=int,
+                      help='Iterations between validation runs')
+ 
+     # ---------------------------------------------------------------------
+     # Data & Dataloading
+     # ---------------------------------------------------------------------
+     data = parser.add_argument_group('Data & Dataloading')
+     data.add_argument('--dataset', type=str, choices=['iam', 'read2016', 'lam', 'vnondb'],
+                       help='Dataset choice')
+     data.add_argument('--data-path', type=str, default='./data/iam/lines/',
+                       help='Root directory containing image/line data')
+     data.add_argument('--train-data-list', type=str, default='./data/iam/train.ln',
+                       help='Path to training list file (e.g., .ln)')
+     data.add_argument('--val-data-list', type=str, default='./data/iam/val.ln',
+                       help='Path to validation list file (e.g., .ln)')
+     data.add_argument('--test-data-list', type=str, default='./data/iam/test.ln',
+                       help='Path to test list file (e.g., .ln)')
+     data.add_argument('--nb-cls', default=80, type=int,
+                       help='Number of classes. IAM=79+1, READ2016=89+1, LAM=90+1, VNOnDB=161+1')
+     data.add_argument('--num-workers', default=0, type=int,
+                       help='Dataloader worker processes')
+     data.add_argument('--img-size', default=[512, 64], type=int, nargs='+',
+                       help='Input image size [W, H]')
+     data.add_argument('--patch-size', default=[4, 32], type=int, nargs='+',
+                       help='Patch size [W, H] for patch embedding')
+ 
+     # ---------------------------------------------------------------------
+     # Training Schedule & Optimization
+     # ---------------------------------------------------------------------
+     train = parser.add_argument_group('Training Schedule & Optimization')
+     train.add_argument('--train-bs', default=8, type=int,
+                        help='Training batch size per iteration')
+     train.add_argument('--accum-steps', default=1, type=int,
+                        help='Gradient accumulation steps; effective batch = train-bs * accum-steps')
+     train.add_argument('--val-bs', default=1, type=int,
+                        help='Validation/test batch size')
+     train.add_argument('--total-iter', default=100000, type=int,
+                        help='Total training iterations')
+     train.add_argument('--warm-up-iter', default=1000, type=int,
+                        help='Warm-up iterations for the optimizer/scheduler')
+     train.add_argument('--max-lr', default=1e-3, type=float,
+                        help='Peak learning rate')
+     train.add_argument('--weight-decay', default=5e-1, type=float,
+                        help='Weight decay (L2) regularization')
+     train.add_argument('--ema-decay', default=0.9999, type=float,
+                        help='Exponential Moving Average (EMA) decay factor for model weights')
+     train.add_argument('--alpha', default=0, type=float,
+                        help='KL-divergence loss ratio (if applicable)')
+ 
+     # ---------------------------------------------------------------------
+     # Model & Encoder
+     # ---------------------------------------------------------------------
+     model = parser.add_argument_group('Model & Encoder')
+     model.add_argument('--model-type', default='ctc', type=str, choices=['ctc', 'encoder_decoder'],
+                        help='Model family to train/use')
+     model.add_argument('--cos-temp', default=8, type=int,
+                        help='Cosine-similarity classifier temperature')
+     model.add_argument('--proj', default=8, type=float,
+                        help='Projection dimension or scaling for classifier head')
+     model.add_argument('--attn-mask-ratio', default=0., type=float,
+                        help='Attention drop-key mask ratio')
+ 
+     # ---------------------------------------------------------------------
+     # Masking Strategy
+     # ---------------------------------------------------------------------
+     mask = parser.add_argument_group('Masking Strategy')
+     mask.add_argument('--use-masking', action='store_true', default=False,
+                       help='Enable masking strategy during training')
+     mask.add_argument('--mask-ratio', default=0.3, type=float,
+                       help='Overall proportion of tokens/patches to mask')
+     mask.add_argument('--max-span-length', default=4, type=int,
+                       help='Max length for individual span masks')
+     mask.add_argument('--spacing', default=0, type=int,
+                       help='Minimum spacing between two span masks')
+     # Tri-masking schedule ratios
+     mask.add_argument('--r-rand', dest='r_rand', default=0.6, type=float,
+                       help='Ratio for random masking in tri-masking schedule')
+     mask.add_argument('--r-block', dest='r_block', default=0.6, type=float,
+                       help='Ratio for block masking in tri-masking schedule')
+     mask.add_argument('--block-span', dest='block_span', default=4, type=int,
+                       help='Block span length for block masking')
+     mask.add_argument('--r-span', dest='r_span', default=0.4, type=float,
+                       help='Ratio for span masking in tri-masking schedule')
+     mask.add_argument('--max-span', dest='max_span', default=8, type=int,
+                       help='Max span length for span masking')
+ 
+     # ---------------------------------------------------------------------
+     # Data Augmentations
+     # ---------------------------------------------------------------------
+     aug = parser.add_argument_group('Data Augmentations')
+     aug.add_argument('--dpi-min-factor', default=0.5, type=float,
+                      help='Minimum scaling factor for DPI-based resize')
+     aug.add_argument('--dpi-max-factor', default=1.5, type=float,
+                      help='Maximum scaling factor for DPI-based resize')
+     aug.add_argument('--perspective-low', default=0., type=float,
+                      help='Lower bound for perspective transform magnitude')
+     aug.add_argument('--perspective-high', default=0.4, type=float,
+                      help='Upper bound for perspective transform magnitude')
+     aug.add_argument('--elastic-distortion-min-kernel-size', default=3, type=int,
+                      help='Minimum kernel size for elastic distortion grid')
+     aug.add_argument('--elastic-distortion-max-kernel-size', default=3, type=int,
+                      help='Maximum kernel size for elastic distortion grid')
+     aug.add_argument('--elastic-distortion-max-magnitude', default=20, type=int,  # flag spelling normalized; dest is unchanged
+                      help='Maximum distortion magnitude for elastic transforms')
+     aug.add_argument('--elastic-distortion-min-alpha', default=0.5, type=float,
+                      help='Minimum alpha for elastic distortion')
+     aug.add_argument('--elastic-distortion-max-alpha', default=1, type=float,
+                      help='Maximum alpha for elastic distortion')
+     aug.add_argument('--elastic-distortion-min-sigma', default=1, type=int,
+                      help='Minimum sigma for Gaussian in elastic distortion')
+     aug.add_argument('--elastic-distortion-max-sigma', default=10, type=int,
+                      help='Maximum sigma for Gaussian in elastic distortion')
+     aug.add_argument('--dila-ero-max-kernel', default=3, type=int,
+                      help='Max kernel size for dilation/erosion ops')
+     aug.add_argument('--dila-ero-iter', default=1, type=int,
+                      help='Iterations for dilation/erosion')
+     aug.add_argument('--jitter-contrast', default=0.4, type=float,
+                      help='ColorJitter: contrast range')
+     aug.add_argument('--jitter-brightness', default=0.4, type=float,
+                      help='ColorJitter: brightness range')
+     aug.add_argument('--jitter-saturation', default=0.4, type=float,
+                      help='ColorJitter: saturation range')
+     aug.add_argument('--jitter-hue', default=0.2, type=float,
+                      help='ColorJitter: hue range')
+     aug.add_argument('--blur-min-kernel', default=3, type=int,
+                      help='Minimum kernel size for Gaussian blur')
+     aug.add_argument('--blur-max-kernel', default=5, type=int,
+                      help='Maximum kernel size for Gaussian blur')
+     aug.add_argument('--blur-min-sigma', default=3, type=int,
+                      help='Minimum sigma for Gaussian blur')
+     aug.add_argument('--blur-max-sigma', default=5, type=int,
+                      help='Maximum sigma for Gaussian blur')
+     aug.add_argument('--sharpen-min-alpha', default=0, type=int,
+                      help='Minimum alpha/mix for sharpening')
+     aug.add_argument('--sharpen-max-alpha', default=1, type=int,
+                      help='Maximum alpha/mix for sharpening')
+     aug.add_argument('--sharpen-min-strength', default=0, type=int,
+                      help='Minimum sharpening strength')
+     aug.add_argument('--sharpen-max-strength', default=1, type=int,
+                      help='Maximum sharpening strength')
+     aug.add_argument('--zoom-min-h', default=0.8, type=float,
+                      help='Minimum vertical zoom factor')
+     aug.add_argument('--zoom-max-h', default=1, type=float,
+                      help='Maximum vertical zoom factor')
+     aug.add_argument('--zoom-min-w', default=0.99, type=float,
+                      help='Minimum horizontal zoom factor')
+     aug.add_argument('--zoom-max-w', default=1, type=float,
+                      help='Maximum horizontal zoom factor')
+     aug.add_argument('--proba', default=0.5, type=float,
+                      help='Default probability for applying stochastic augmentations')
+ 
+     # ---------------------------------------------------------------------
+     # Decoder & Inference (for encoder-decoder mode)
+     # ---------------------------------------------------------------------
+     dec = parser.add_argument_group('Decoder & Inference')
+     dec.add_argument('--decoder-layers', default=6, type=int,
+                      help='Number of Transformer decoder layers')
+     dec.add_argument('--decoder-heads', default=8, type=int,
+                      help='Number of attention heads in decoder')
+     dec.add_argument('--max-seq-len', default=256, type=int,
+                      help='Maximum output sequence length')
+     dec.add_argument('--label-smoothing', default=0.1, type=float,
+                      help='Label-smoothing factor for cross-entropy loss')
+     dec.add_argument('--beam-size', default=5, type=int,
+                      help='Beam size for beam-search decoding')
+     dec.add_argument('--generation-method', default='nucleus', type=str,
+                      choices=['greedy', 'nucleus', 'beam_search'],
+                      help='Token generation method for inference')
+     dec.add_argument('--generation-temperature', default=0.7, type=float,
+                      help='Sampling temperature (used by nucleus/greedy sampling)')
+     dec.add_argument('--repetition-penalty', default=1.3, type=float,
+                      help='Penalty to discourage token repetition during generation')
+     dec.add_argument('--top-p', default=0.9, type=float,
+                      help='Top-p threshold for nucleus sampling')
+ 
+     # ---------------------------------------------------------------------
+     # TCM (Textual Context Module)
+     # ---------------------------------------------------------------------
+     tcm = parser.add_argument_group('TCM (Textual Context Module)')
+     tcm.add_argument('--tcm-enable', action='store_true', default=False,
+                      help='Enable Textual Context Module (TCM)')
+     tcm.add_argument('--tcm-lambda', default=1.0, type=float,
+                      help='TCM loss weight (λ2 in the paper)')
+     tcm.add_argument('--ctc-lambda', default=0.1, type=float,
+                      help='CTC loss weight (λ1 in the paper)')
+     tcm.add_argument('--tcm-sub-len', default=5, type=int,
+                      help='TCM context sub-string length')
+     tcm.add_argument('--tcm-warmup-iters', default=0, type=int,
+                      help='Warm-up iterations before activating TCM (0 = start immediately)')
+ 
+     # ---------------------------------------------------------------------
+     # Checkpointing & Pretrained Weights
+     # ---------------------------------------------------------------------
+     ckpt = parser.add_argument_group('Checkpointing & Pretrained Weights')
+     ckpt.add_argument('--resume', type=str, default=None,
+                       help='Resume training from a checkpoint (alias)')
+     ckpt.add_argument('--load-model', type=str, default=None,
+                       help='Load a full pretrained model for fine-tuning')
+     ckpt.add_argument('--load-encoder-only', action='store_true', default=False,
+                       help='Load only encoder weights (transfer learning)')
+     ckpt.add_argument('--strict-loading', action='store_true', default=True,
+                       help='Use strict key matching when loading weights')
+ 
+     return parser.parse_args()
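Because get_args_parser() calls parse_args() itself, train.py and test.py share one flag surface. For quick programmatic checks the parser can be driven through sys.argv (editorial sketch, not part of the commit; assumes it runs from the repo root):

# Editorial sketch: invoking the parser programmatically (not part of the commit).
import sys
from utils import option

sys.argv = ['train.py', '--dataset', 'iam', '--tcm-enable', '--nb-cls', '80']
args = option.get_args_parser()
print(args.dataset, args.nb_cls, args.tcm_enable)  # iam 80 True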
utils/sam.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ 
+ 
+ class SAM(torch.optim.Optimizer):
+     def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
+         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+ 
+         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
+         super(SAM, self).__init__(params, defaults)
+ 
+         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+         self.param_groups = self.base_optimizer.param_groups
+         self.defaults.update(self.base_optimizer.defaults)
+ 
+     @torch.no_grad()
+     def first_step(self, zero_grad=False):
+         grad_norm = self._grad_norm()
+         for group in self.param_groups:
+             scale = group["rho"] / (grad_norm + 1e-12)
+ 
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 self.state[p]["old_p"] = p.data.clone()
+                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
+                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
+ 
+         if zero_grad:
+             self.zero_grad()
+ 
+     @torch.no_grad()
+     def second_step(self, zero_grad=False):
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
+ 
+         self.base_optimizer.step()  # do the actual "sharpness-aware" update
+ 
+         if zero_grad:
+             self.zero_grad()
+ 
+     @torch.no_grad()
+     def step(self, closure=None):
+         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
+         closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
+ 
+         self.first_step(zero_grad=True)
+         closure()
+         self.second_step()
+ 
+     def _grad_norm(self):
+         shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
+         norm = torch.norm(
+             torch.stack([
+                 ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
+                 for group in self.param_groups for p in group["params"]
+                 if p.grad is not None
+             ]),
+             p=2
+         )
+         return norm
+ 
+     def load_state_dict(self, state_dict):
+         super().load_state_dict(state_dict)
+         self.base_optimizer.param_groups = self.param_groups
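The two-phase API above needs two forward/backward passes per update, which is exactly how train.py drives it. A minimal usage sketch on a toy model (editorial, not part of the commit; train.py's accumulation loop is the real reference):

# Editorial sketch of the SAM two-step protocol (not part of the commit).
import torch
from utils.sam import SAM

model = torch.nn.Linear(4, 2)
opt = SAM(model.parameters(), torch.optim.AdamW, rho=0.05, lr=1e-3)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss_fn = torch.nn.CrossEntropyLoss()

loss_fn(model(x), y).backward()   # pass 1: gradients at w
opt.first_step(zero_grad=True)    # perturb to w + e(w)
loss_fn(model(x), y).backward()   # pass 2: gradients at the perturbed point
opt.second_step(zero_grad=True)   # restore w, then apply the base AdamW step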
utils/utils.py ADDED
@@ -0,0 +1,276 @@
+ import torch
+ import torch.distributed as dist
+ from torch.distributions.uniform import Uniform
+ 
+ import os
+ import re
+ import sys
+ import math
+ import logging
+ from copy import deepcopy
+ from collections import OrderedDict
+ import random
+ import numpy as np
+ 
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ 
+ 
+ def randint(low, high):
+     return int(torch.randint(low, high, (1, )))
+ 
+ 
+ def rand_uniform(low, high):
+     return float(Uniform(low, high).sample())
+ 
+ 
+ def get_logger(out_dir):
+     logger = logging.getLogger('Exp')
+     logger.setLevel(logging.INFO)
+     formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+ 
+     file_path = os.path.join(out_dir, "run.log")
+     file_hdlr = logging.FileHandler(file_path)
+     file_hdlr.setFormatter(formatter)
+ 
+     strm_hdlr = logging.StreamHandler(sys.stdout)
+     strm_hdlr.setFormatter(formatter)
+ 
+     logger.addHandler(file_hdlr)
+     logger.addHandler(strm_hdlr)
+     return logger
+ 
+ 
+ def update_lr_cos(nb_iter, warm_up_iter, total_iter, max_lr, optimizer, min_lr=1e-7):
+     if nb_iter < warm_up_iter:
+         current_lr = max_lr * (nb_iter + 1) / (warm_up_iter + 1)
+     else:
+         current_lr = min_lr + (max_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * nb_iter / (total_iter - warm_up_iter)))
+ 
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = current_lr
+ 
+     return optimizer, current_lr
+ 
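The schedule ramps linearly to max_lr over the warm-up, then decays along a cosine toward min_lr. A quick worked check of its endpoints (editorial sketch of the same formula, not part of the commit):

# Editorial sketch: sampling the update_lr_cos schedule (not part of the commit).
import math

def lr_at(i, warm_up=1000, total=100000, max_lr=1e-3, min_lr=1e-7):
    if i < warm_up:
        return max_lr * (i + 1) / (warm_up + 1)
    return min_lr + (max_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * i / (total - warm_up)))

print(lr_at(0))       # ~1e-6, start of warm-up
print(lr_at(999))     # ~1e-3, end of warm-up
print(lr_at(50_000))  # ~4.9e-4, roughly halfway down the cosine
print(lr_at(98_999))  # ~min_lr as i approaches total - warm_up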
+ 
+ class CTCLabelConverter(object):
+     def __init__(self, character):
+         dict_character = list(character)
+         self.dict = {}
+         for i, char in enumerate(dict_character):
+             self.dict[char] = i + 1  # index 0 is reserved for the CTC blank
+         if len(self.dict) == 87:  # '[' and ']' are not in the test set but in the training and validation sets.
+             self.dict['['], self.dict[']'] = 88, 89
+         self.character = ['[blank]'] + dict_character
+ 
+     def encode(self, text):
+         length = [len(s) for s in text]
+         text = ''.join(text)
+         text = [self.dict[char] for char in text]
+ 
+         return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))
+ 
+     def decode(self, text_index, length):
+         texts = []
+         index = 0
+ 
+         for l in length:
+             t = text_index[index:index + l]
+             char_list = []
+             for i in range(l):
+                 # keep a character only if it is not blank, differs from its
+                 # predecessor, and is a valid index into the vocabulary
+                 if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])) and t[i] < len(self.character):
+                     char_list.append(self.character[t[i]])
+             text = ''.join(char_list)
+ 
+             texts.append(text)
+             index += l
+         return texts
+ 
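decode applies the standard CTC greedy rule: merge repeated indices, then drop blanks (index 0). A worked example (editorial sketch, not part of the commit; assumes a converter built over the characters 'a', 'b', 'c'):

# Editorial sketch of CTC greedy collapsing in decode (not part of the commit).
import torch
conv = CTCLabelConverter('abc')                 # a=1, b=2, c=3; 0 is [blank]
raw = torch.tensor([1, 1, 0, 2, 2, 0, 2, 3])    # per-frame argmax indices
print(conv.decode(raw, [len(raw)]))             # ['abbc'] -- repeats merged, blanks removed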
88
+
89
+ class Averager(object):
90
+ def __init__(self):
91
+ self.reset()
92
+
93
+ def add(self, v):
94
+ count = v.data.numel()
95
+ v = v.data.sum()
96
+ self.n_count += count
97
+ self.sum += v
98
+
99
+ def reset(self):
100
+ self.n_count = 0
101
+ self.sum = 0
102
+
103
+ def val(self):
104
+ res = 0
105
+ if self.n_count != 0:
106
+ res = self.sum / float(self.n_count)
107
+ return res
108
+
109
+
+ class Metric(object):
+     """Distributed running mean: each update() all-reduces the value across ranks."""
+     def __init__(self, name=''):
+         self.name = name
+         self.sum = torch.tensor(0.).double()
+         self.n = torch.tensor(0.)
+
+     def update(self, val):
+         rt = val.clone()
+         dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+         rt /= dist.get_world_size()
+         self.sum += rt.detach().cpu().double()
+         self.n += 1
+
+     @property
+     def avg(self):
+         return self.sum / self.n.double()
+
+
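Metric assumes torch.distributed is already initialized. A single-process sketch using the gloo backend, where the address, port, and values are arbitrary:

dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)
m = Metric('loss')
m.update(torch.tensor(2.0))
m.update(torch.tensor(4.0))
print(m.avg)   # tensor(3., dtype=torch.float64)
dist.destroy_process_group()
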
+ class ModelEma:
+     """Keeps an exponential moving average of a model's weights for evaluation."""
+     def __init__(self, model, decay=0.9999, device='', resume=''):
+         self.ema = deepcopy(model)
+         self.ema.eval()
+         self.decay = decay
+         self.device = device
+         if device:
+             self.ema.to(device=device)
+         self.ema_has_module = hasattr(self.ema, 'module')
+         if resume:
+             self._load_checkpoint(resume)
+         for p in self.ema.parameters():
+             p.requires_grad_(False)
+
+     def _load_checkpoint(self, checkpoint_path, mapl=None):
+         checkpoint = torch.load(checkpoint_path, map_location=mapl)
+         assert isinstance(checkpoint, dict)
+         if 'state_dict_ema' in checkpoint:
+             new_state_dict = OrderedDict()
+             for k, v in checkpoint['state_dict_ema'].items():
+                 # Align key names with the (possibly DDP-wrapped) EMA copy.
+                 if self.ema_has_module:
+                     name = 'module.' + k if not k.startswith('module') else k
+                 else:
+                     name = k
+                 new_state_dict[name] = v
+             self.ema.load_state_dict(new_state_dict)
+             print("=> Loaded state_dict_ema")
+         else:
+             print("=> Failed to find state_dict_ema, starting from loaded model weights")
+
+     def update(self, model, num_updates=-1):
+         needs_module = hasattr(model, 'module') and not self.ema_has_module
+         if num_updates >= 0:
+             # Ramp the decay up early in training so the EMA tracks the model faster at first.
+             _cdecay = min(self.decay, (1 + num_updates) / (10 + num_updates))
+         else:
+             _cdecay = self.decay
+
+         with torch.no_grad():
+             msd = model.state_dict()
+             for k, ema_v in self.ema.state_dict().items():
+                 if needs_module:
+                     k = 'module.' + k
+                 model_v = msd[k].detach()
+                 if self.device:
+                     model_v = model_v.to(device=self.device)
+                 ema_v.copy_(ema_v * _cdecay + (1. - _cdecay) * model_v)
+
+
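A minimal sketch of the usual EMA workflow; the tiny model and step count are placeholders:

model = torch.nn.Linear(8, 8)
model_ema = ModelEma(model, decay=0.999)
for step in range(100):
    # ... optimizer.step() on `model` would happen here ...
    model_ema.update(model, num_updates=step)
eval_model = model_ema.ema    # evaluate with the smoothed weights
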
+ def format_string_for_wer(s):
+     # Put spaces around punctuation so each mark counts as its own "word",
+     # then collapse runs of whitespace. Raw strings avoid invalid-escape warnings.
+     s = re.sub(r'([\[\]{}/\\()"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', s)
+     s = re.sub(r'([ \n])+', " ", s).strip()
+     return s
+
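For example, punctuation becomes its own token:

print(format_string_for_wer("Hello, world!"))   # "Hello , world !"
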
+ def load_checkpoint(model, model_ema, optimizer, checkpoint_path, logger):
+     best_cer, best_wer, start_iter = 1e+6, 1e+6, 1
+     train_loss, train_loss_count = 0.0, 0
+     optimizer_state = None
+     if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+         logger.info(f"Resuming from checkpoint: {checkpoint_path}")
+         checkpoint = torch.load(
+             checkpoint_path, map_location='cpu', weights_only=False)
+
+         # Load model state dict (handle module prefix like in test.py).
+         # Escaped dot so only the literal 'module.' prefix is stripped.
+         model_dict = OrderedDict()
+         pattern = re.compile(r'module\.')
+
+         # For main model, load from the 'model' state dict
+         # (the training checkpoint contains both 'model' and 'state_dict_ema')
+         if 'model' in checkpoint:
+             source_dict = checkpoint['model']
+             logger.info("Loading main model from 'model' state dict")
+         elif 'state_dict_ema' in checkpoint:
+             source_dict = checkpoint['state_dict_ema']
+             logger.info(
+                 "Loading main model from 'state_dict_ema' (fallback)")
+         else:
+             raise KeyError(
+                 "Neither 'model' nor 'state_dict_ema' found in checkpoint")
+
+         for k, v in source_dict.items():
+             if re.search("module", k):
+                 model_dict[re.sub(pattern, '', k)] = v
+             else:
+                 model_dict[k] = v
+
+         model.load_state_dict(model_dict, strict=True)
+         logger.info("Successfully loaded main model state dict")
+
+         # Load EMA state dict if available
+         if 'state_dict_ema' in checkpoint and model_ema is not None:
+             ema_dict = OrderedDict()
+             for k, v in checkpoint['state_dict_ema'].items():
+                 if re.search("module", k):
+                     ema_dict[re.sub(pattern, '', k)] = v
+                 else:
+                     ema_dict[k] = v
+             model_ema.ema.load_state_dict(ema_dict, strict=True)
+             logger.info("Successfully loaded EMA model state dict")
+
+         # Load optimizer state - handle SAM optimizer structure.
+         # The state is only staged here and applied after the optimizer is built.
+         if 'optimizer' in checkpoint and optimizer is not None:
+             try:
+                 optimizer_state = checkpoint['optimizer']
+                 logger.info(
+                     "Optimizer state will be loaded after optimizer initialization")
+             except Exception as e:
+                 logger.warning(f"Failed to prepare optimizer state: {e}")
+                 optimizer_state = None
+
+         # Load metrics from checkpoint if available
+         if 'best_cer' in checkpoint:
+             best_cer = checkpoint['best_cer']
+         if 'best_wer' in checkpoint:
+             best_wer = checkpoint['best_wer']
+         if 'nb_iter' in checkpoint:
+             start_iter = checkpoint['nb_iter'] + 1
+
+         # Parse CER, WER, iter from filename as fallback
+         m = re.search(
+             r'checkpoint_(?P<cer>[\d\.]+)_(?P<wer>[\d\.]+)_(?P<iter>\d+)\.pth', checkpoint_path)
+         if m and 'best_cer' not in checkpoint:
+             best_cer = float(m.group('cer'))
+             best_wer = float(m.group('wer'))
+             start_iter = int(m.group('iter')) + 1
+
+         if 'train_loss' in checkpoint:
+             train_loss = checkpoint['train_loss']
+         if 'train_loss_count' in checkpoint:
+             train_loss_count = checkpoint['train_loss_count']
+         if 'random_state' in checkpoint:
+             random.setstate(checkpoint['random_state'])
+             logger.info("Restored random state")
+         if 'numpy_state' in checkpoint:
+             np.random.set_state(checkpoint['numpy_state'])
+             logger.info("Restored numpy random state")
+         if 'torch_state' in checkpoint:
+             torch.set_rng_state(checkpoint['torch_state'])
+             logger.info("Restored torch random state")
+         if 'torch_cuda_state' in checkpoint and torch.cuda.is_available():
+             torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])
+             logger.info("Restored torch cuda random state")
+
+         # Sanity check: report the parameter count of the restored model
+         total_params = sum(p.numel() for p in model.parameters())
+         logger.info(f"Model loaded with {total_params} total parameters")
+
+         logger.info(
+             f"Resumed best_cer={best_cer}, best_wer={best_wer}, start_iter={start_iter}")
+     return best_cer, best_wer, start_iter, optimizer_state, train_loss, train_loss_count
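For reference, a checkpoint that this loader can fully restore would be saved with matching keys. A hedged sketch only: the file name pattern mirrors the fallback regex above, and all variables are assumed to come from the surrounding training loop, which is not shown in this upload:

ckpt = {
    'model': model.state_dict(),
    'state_dict_ema': model_ema.ema.state_dict(),
    'optimizer': optimizer.state_dict(),
    'best_cer': best_cer, 'best_wer': best_wer, 'nb_iter': nb_iter,
    'train_loss': train_loss, 'train_loss_count': train_loss_count,
    'random_state': random.getstate(),
    'numpy_state': np.random.get_state(),
    'torch_state': torch.get_rng_state(),
}
if torch.cuda.is_available():
    ckpt['torch_cuda_state'] = torch.cuda.get_rng_state()
torch.save(ckpt, f'checkpoint_{best_cer:.4f}_{best_wer:.4f}_{nb_iter}.pth')
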
valid.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ import torch.utils.data
+ import torch.backends.cudnn as cudnn
+
+ from utils import utils
+ import editdistance
+
+
+ def validation(model, criterion, evaluation_loader, converter):
+     """ validation or evaluation """
+
+     norm_ED = 0
+     norm_ED_wer = 0
+
+     tot_ED = 0
+     tot_ED_wer = 0
+
+     valid_loss = 0.0
+     length_of_gt = 0
+     length_of_gt_wer = 0
+     count = 0
+     all_preds_str = []
+     all_labels = []
+
+     for i, (image_tensors, labels) in enumerate(evaluation_loader):
+         batch_size = image_tensors.size(0)
+         image = image_tensors.cuda()
+
+         text_for_loss, length_for_loss = converter.encode(labels)
+
+         preds = model(image)
+         preds = preds.float()
+         preds_size = torch.IntTensor([preds.size(1)] * batch_size)
+         preds = preds.permute(1, 0, 2).log_softmax(2)
+
+         # Temporarily disable cuDNN while computing the CTC loss.
+         torch.backends.cudnn.enabled = False
+         cost = criterion(preds, text_for_loss, preds_size, length_for_loss).mean()
+         torch.backends.cudnn.enabled = True
+
+         # Greedy (best-path) decoding of the CTC output.
+         _, preds_index = preds.max(2)
+         preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
+         preds_str = converter.decode(preds_index.data, preds_size.data)
+         valid_loss += cost.item()
+         count += 1
+
+         all_preds_str.extend(preds_str)
+         all_labels.extend(labels)
+
+         # Character error rate: edit distance over characters.
+         for pred_cer, gt_cer in zip(preds_str, labels):
+             tmp_ED = editdistance.eval(pred_cer, gt_cer)
+             if len(gt_cer) == 0:
+                 norm_ED += 1
+             else:
+                 norm_ED += tmp_ED / float(len(gt_cer))
+             tot_ED += tmp_ED
+             length_of_gt += len(gt_cer)
+
+         # Word error rate: edit distance over whitespace-separated tokens.
+         for pred_wer, gt_wer in zip(preds_str, labels):
+             pred_wer = utils.format_string_for_wer(pred_wer)
+             gt_wer = utils.format_string_for_wer(gt_wer)
+             pred_wer = pred_wer.split(" ")
+             gt_wer = gt_wer.split(" ")
+             tmp_ED_wer = editdistance.eval(pred_wer, gt_wer)
+
+             if len(gt_wer) == 0:
+                 norm_ED_wer += 1
+             else:
+                 norm_ED_wer += tmp_ED_wer / float(len(gt_wer))
+
+             tot_ED_wer += tmp_ED_wer
+             length_of_gt_wer += len(gt_wer)
+
+     val_loss = valid_loss / count
+     CER = tot_ED / float(length_of_gt)
+     WER = tot_ED_wer / float(length_of_gt_wer)
+
+     return val_loss, CER, WER, all_preds_str, all_labels
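A typical call site for validation; the loader, CUDA model, and charset are placeholders, and the CTCLoss reduction and zero_infinity settings are assumptions about the training script, not confirmed by this upload:

criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True).cuda()
converter = utils.CTCLabelConverter(charset)   # charset: string of all symbols
model.eval()
with torch.no_grad():
    val_loss, cer, wer, preds, gts = validation(model, criterion, val_loader, converter)
print(f"loss={val_loss:.4f} CER={cer:.4f} WER={wer:.4f}")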