Ehtesham123 commited on
Commit
a045aa1
·
verified ·
1 Parent(s): 2663dbe

Upload 54 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .gradio/flagged/dataset1.csv +2 -0
  3. __pycache__/parseq_recognize.cpython-311.pyc +0 -0
  4. __pycache__/yolo_detect.cpython-311.pyc +0 -0
  5. app.py +34 -0
  6. parseq_recognize.py +18 -0
  7. pretrained_model/parseq.ckpt +3 -0
  8. pretrained_model/yolo_obb.pt +3 -0
  9. requirements.txt +8 -0
  10. samples/T1.png +3 -0
  11. samples/T2.png +3 -0
  12. samples/image_0004.png +3 -0
  13. samples/image_0082.png +3 -0
  14. strhub/__init__.py +0 -0
  15. strhub/__pycache__/__init__.cpython-311.pyc +0 -0
  16. strhub/__pycache__/__init__.cpython-312.pyc +0 -0
  17. strhub/data/__init__.py +0 -0
  18. strhub/data/__pycache__/__init__.cpython-311.pyc +0 -0
  19. strhub/data/__pycache__/__init__.cpython-312.pyc +0 -0
  20. strhub/data/__pycache__/aa_overrides.cpython-312.pyc +0 -0
  21. strhub/data/__pycache__/augment.cpython-312.pyc +0 -0
  22. strhub/data/__pycache__/dataset.cpython-311.pyc +0 -0
  23. strhub/data/__pycache__/dataset.cpython-312.pyc +0 -0
  24. strhub/data/__pycache__/module.cpython-311.pyc +0 -0
  25. strhub/data/__pycache__/module.cpython-312.pyc +0 -0
  26. strhub/data/__pycache__/utils.cpython-311.pyc +0 -0
  27. strhub/data/__pycache__/utils.cpython-312.pyc +0 -0
  28. strhub/data/aa_overrides.py +46 -0
  29. strhub/data/augment.py +112 -0
  30. strhub/data/dataset.py +148 -0
  31. strhub/data/module.py +158 -0
  32. strhub/data/utils.py +150 -0
  33. strhub/models/__init__.py +0 -0
  34. strhub/models/__pycache__/__init__.cpython-311.pyc +0 -0
  35. strhub/models/__pycache__/__init__.cpython-312.pyc +0 -0
  36. strhub/models/__pycache__/base.cpython-311.pyc +0 -0
  37. strhub/models/__pycache__/base.cpython-312.pyc +0 -0
  38. strhub/models/__pycache__/utils.cpython-311.pyc +0 -0
  39. strhub/models/__pycache__/utils.cpython-312.pyc +0 -0
  40. strhub/models/base.py +221 -0
  41. strhub/models/modules.py +20 -0
  42. strhub/models/parseq/__init__.py +0 -0
  43. strhub/models/parseq/__pycache__/__init__.cpython-311.pyc +0 -0
  44. strhub/models/parseq/__pycache__/__init__.cpython-312.pyc +0 -0
  45. strhub/models/parseq/__pycache__/model.cpython-311.pyc +0 -0
  46. strhub/models/parseq/__pycache__/model.cpython-312.pyc +0 -0
  47. strhub/models/parseq/__pycache__/modules.cpython-311.pyc +0 -0
  48. strhub/models/parseq/__pycache__/modules.cpython-312.pyc +0 -0
  49. strhub/models/parseq/__pycache__/system.cpython-311.pyc +0 -0
  50. strhub/models/parseq/__pycache__/system.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -37,3 +37,7 @@ ocr_demo/samples/image_0004.png filter=lfs diff=lfs merge=lfs -text
37
  ocr_demo/samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
38
  ocr_demo/samples/T1.png filter=lfs diff=lfs merge=lfs -text
39
  ocr_demo/samples/T2.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
37
  ocr_demo/samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
38
  ocr_demo/samples/T1.png filter=lfs diff=lfs merge=lfs -text
39
  ocr_demo/samples/T2.png filter=lfs diff=lfs merge=lfs -text
40
+ samples/image_0004.png filter=lfs diff=lfs merge=lfs -text
41
+ samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
42
+ samples/T1.png filter=lfs diff=lfs merge=lfs -text
43
+ samples/T2.png filter=lfs diff=lfs merge=lfs -text
.gradio/flagged/dataset1.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Choose a sample image,Recognized Text,timestamp
2
+ ,,2025-08-07 15:20:31.167102
__pycache__/parseq_recognize.cpython-311.pyc ADDED
Binary file (2.07 kB). View file
 
__pycache__/yolo_detect.cpython-311.pyc ADDED
Binary file (4 kB). View file
 
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from yolo_detect import OBBPredictor
from parseq_recognize import TextRecognizer
import os

# Initialize models.
# BUGFIX: use os.path.join instead of hard-coded backslashes
# ("pretrained_model\\yolo_obb.pt") — backslash paths only work on Windows
# and break on POSIX hosts (e.g. Hugging Face Spaces run Linux).
yolo_model_path = os.path.join("pretrained_model", "yolo_obb.pt")
parseq_ckpt_path = os.path.join("pretrained_model", "parseq.ckpt")

detector = OBBPredictor(yolo_model_path)
recognizer = TextRecognizer(parseq_ckpt_path, device='cpu')  # or 'cuda' if on GPU


# ==== OCR pipeline function ====
def run_pipeline(image):
    """Detect text regions with the YOLO-OBB model, then recognize each crop with PARSeq.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A newline-separated, numbered list of recognized strings, or a
        fallback message when no text region is detected.
    """
    crops = detector.predict(image)
    recognized_texts = [recognizer.recognize(crop) for crop in crops]
    final_output = "\n".join(f"{i + 1}. {txt}" for i, txt in enumerate(recognized_texts))
    return final_output if recognized_texts else "No text detected."


# ==== Get sample image paths ====
# Sorted so the example gallery has a stable, deterministic order
# (os.listdir order is filesystem-dependent).
example_images = sorted(
    os.path.join("samples", f)
    for f in os.listdir("samples")
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# ==== Gradio app with ONLY sample images ====
demo = gr.Interface(
    fn=run_pipeline,
    inputs=gr.Image(type="pil", label="Choose a sample image"),
    outputs=gr.Textbox(label="Recognized Text"),
    examples=[[img] for img in example_images],  # list of lists required
    title="Two-Stage OCR Network for Aero Engine Blades Serial Number",
    description="Choose only one of the predefined image. The model will detect text regions and recognize their contents."
)

if __name__ == "__main__":
    demo.launch()
parseq_recognize.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from strhub.data.module import SceneTextDataModule
4
+ from strhub.models.utils import load_from_checkpoint
5
+
6
class TextRecognizer:
    """Inference-only wrapper around a PARSeq checkpoint.

    Loads the model once in eval mode and exposes :meth:`recognize`
    for single-crop text recognition.
    """

    def __init__(self, ckpt_path, device='cpu'):
        self.device = device
        # Model stays in eval mode on the requested device; no training here.
        self.parseq = load_from_checkpoint(ckpt_path).eval().to(device)
        # Preprocessing pipeline matching the checkpoint's expected input size.
        self.img_transform = SceneTextDataModule.get_transform(self.parseq.hparams.img_size)

    def recognize(self, image_pil):
        """Return the text string decoded from a single PIL image crop."""
        batch = self.img_transform(image_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.parseq(batch).softmax(-1)
            labels, _ = self.parseq.tokenizer.decode(probs)
        return labels[0]
pretrained_model/parseq.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c95fbe3efe9c59f71e7f75761b7b70b5ed5097e7f502cf138d6eded042f7c073
3
+ size 96584214
pretrained_model/yolo_obb.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:739cd8cd3f49a3f466cbdee965dcc3720331404d0de5787881bb8a95992dd6e1
3
+ size 5715964
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ ultralytics
3
+ torch
4
+ torchvision
5
+ strhub
6
+ Pillow
7
+ opencv-python
8
+ numpy
samples/T1.png ADDED

Git LFS Details

  • SHA256: e342674a1d65afa411806908fc186db15058e5cb76d89b64bf3b56117ce1622f
  • Pointer size: 132 Bytes
  • Size of remote file: 9.28 MB
samples/T2.png ADDED

Git LFS Details

  • SHA256: 8256e3156e5a16973574bb81879269128c0c7728000df15f7136eb65e0b1544a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
samples/image_0004.png ADDED

Git LFS Details

  • SHA256: 2853354cd82fa5ae60c224a44c3d8fc80436c1d806587f0abd47df763460ad3e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.02 MB
samples/image_0082.png ADDED

Git LFS Details

  • SHA256: e476ee9ab4472b41e990a4c2f0d194e2f4152c4234d0897b80bb113d1ea5d50b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
strhub/__init__.py ADDED
File without changes
strhub/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (166 Bytes). View file
 
strhub/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (154 Bytes). View file
 
strhub/data/__init__.py ADDED
File without changes
strhub/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
strhub/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (159 Bytes). View file
 
strhub/data/__pycache__/aa_overrides.cpython-312.pyc ADDED
Binary file (1.67 kB). View file
 
strhub/data/__pycache__/augment.cpython-312.pyc ADDED
Binary file (5.2 kB). View file
 
strhub/data/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (8.05 kB). View file
 
strhub/data/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (7.07 kB). View file
 
strhub/data/__pycache__/module.cpython-311.pyc ADDED
Binary file (7.19 kB). View file
 
strhub/data/__pycache__/module.cpython-312.pyc ADDED
Binary file (6.65 kB). View file
 
strhub/data/__pycache__/utils.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
strhub/data/__pycache__/utils.cpython-312.pyc ADDED
Binary file (8.84 kB). View file
 
strhub/data/aa_overrides.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Extends default ops to accept optional parameters."""
17
+ from functools import partial
18
+
19
+ from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
20
+
21
+
22
def rotate_expand(img, degrees, **kwargs):
    """Rotate operation with ``expand=True`` to avoid cutting off the characters."""
    return rotate(img, degrees, **dict(kwargs, expand=True))
26
+
27
+
28
def _level_to_arg(level, hparams, key, default):
    """Scale *level* by the hparam-configurable magnitude and randomly negate it."""
    magnitude = hparams.get(key, default)
    scaled = _randomly_negate(level / _LEVEL_DENOM * magnitude)
    return (scaled,)
33
+
34
+
35
def apply():
    """Install the overrides into timm's AutoAugment registries.

    Replaces the stock Rotate op with the expand-aware variant and makes
    Rotate/Shear/Translate magnitudes configurable via hparams keys.
    """
    NAME_TO_OP['Rotate'] = rotate_expand
    LEVEL_TO_ARG.update({
        'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.0),
        'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3),
        'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3),
        'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45),
        'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45),
    })
strhub/data/augment.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import imgaug.augmenters as iaa
19
+ import numpy as np
20
+ from PIL import Image, ImageFilter
21
+
22
+ from timm.data import auto_augment
23
+
24
+ from strhub.data import aa_overrides
25
+
26
aa_overrides.apply()  # install the custom Rotate/Shear/Translate overrides into timm

# Cache of instantiated augmentation ops, keyed by op name + parameter value.
_OP_CACHE = {}


def _get_op(key, factory):
    """Return the cached op for *key*, creating it with *factory* on first use."""
    if key not in _OP_CACHE:
        _OP_CACHE[key] = factory()
    return _OP_CACHE[key]
38
+
39
+
40
def _get_param(level, img, max_dim_factor, min_level=1):
    """Clamp *level* to a cap proportional to the image's larger dimension, rounded to int."""
    cap = max(min_level, max_dim_factor * max(img.size))
    return round(min(level, cap))
43
+
44
+
45
def gaussian_blur(img, radius, **__):
    """Apply PIL Gaussian blur with a radius clamped relative to the image size."""
    eff_radius = _get_param(radius, img, 0.02)
    blur = _get_op('gaussian_blur_' + str(eff_radius), lambda: ImageFilter.GaussianBlur(eff_radius))
    return img.filter(blur)
50
+
51
+
52
def motion_blur(img, k, **__):
    """Apply imgaug motion blur with an odd kernel size clamped to the image size."""
    kernel = _get_param(k, img, 0.08, 3) | 1  # bin to odd values
    blur = _get_op('motion_blur_' + str(kernel), lambda: iaa.MotionBlur(kernel))
    return Image.fromarray(blur(image=np.asarray(img)))
57
+
58
+
59
def gaussian_noise(img, scale, **_):
    """Add imgaug additive Gaussian noise with a scale clamped to the image size."""
    # NOTE(review): the odd-binning (`| 1`) mirrors the blur ops; its benefit for a
    # noise scale is unclear — kept as-is to preserve behavior.
    eff_scale = _get_param(scale, img, 0.25) | 1  # bin to odd values
    noise = _get_op('gaussian_noise_' + str(eff_scale), lambda: iaa.AdditiveGaussianNoise(scale=eff_scale))
    return Image.fromarray(noise(image=np.asarray(img)))
64
+
65
+
66
def poisson_noise(img, lam, **_):
    """Add imgaug additive Poisson noise with lambda clamped to the image size."""
    eff_lam = _get_param(lam, img, 0.2) | 1  # bin to odd values
    noise = _get_op('poisson_noise_' + str(eff_lam), lambda: iaa.AdditivePoissonNoise(eff_lam))
    return Image.fromarray(noise(image=np.asarray(img)))
71
+
72
+
73
def _level_to_arg(level, _hparams, max):  # noqa: A002 — `max` is fixed by keyword callers via partial()
    """Linearly map *level* from [0, _LEVEL_DENOM] onto [0, max]."""
    return (max * level / auto_augment._LEVEL_DENOM,)
76
+
77
+
78
# Build the RandAugment op list from timm's "increasing" transform set,
# then register the custom blur/noise ops defined above. This mutates
# timm's module-level registries, so it must run at import time, before
# rand_augment_ops() is called.
_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
_RAND_TRANSFORMS.remove('SharpnessIncreasing')  # remove, interferes with *blur ops
_RAND_TRANSFORMS.extend([
    'GaussianBlur',
    # 'MotionBlur',
    # 'GaussianNoise',
    'PoissonNoise',
])
# Map the generic augmentation level to each op's parameter range
# (e.g. blur radius up to 4, Poisson lambda up to 40).
auto_augment.LEVEL_TO_ARG.update({
    'GaussianBlur': partial(_level_to_arg, max=4),
    'MotionBlur': partial(_level_to_arg, max=20),
    'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
    'PoissonNoise': partial(_level_to_arg, max=40),
})
# Register the op implementations under the names used in _RAND_TRANSFORMS.
auto_augment.NAME_TO_OP.update({
    'GaussianBlur': gaussian_blur,
    'MotionBlur': motion_blur,
    'GaussianNoise': gaussian_noise,
    'PoissonNoise': poisson_noise,
})
98
+
99
+
100
def rand_augment_transform(magnitude=5, num_layers=3):
    """Build a RandAugment policy tuned for scene-text images.

    Args:
        magnitude: overall augmentation strength passed to timm.
        num_layers: number of ops applied per image.

    Returns:
        A timm ``RandAugment`` callable usable as a torchvision transform.
    """
    # These are tuned for magnitude=5, which means that effective magnitudes
    # are half of these values.
    hparams = dict(
        rotate_deg=30,
        shear_x_pct=0.9,
        shear_y_pct=0.2,
        translate_x_pct=0.10,
        translate_y_pct=0.30,
    )
    ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
    # Uniform weights disable replacement in random selection
    # (i.e. avoid applying the same op twice).
    uniform_weights = [1.0 / len(ra_ops)] * len(ra_ops)
    return auto_augment.RandAugment(ra_ops, num_layers, uniform_weights)
strhub/data/dataset.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import glob
16
+ import io
17
+ import logging
18
+ import unicodedata
19
+ from pathlib import Path, PurePath
20
+ from typing import Callable, Optional, Union
21
+
22
+ import lmdb
23
+ from PIL import Image
24
+
25
+ from torch.utils.data import ConcatDataset, Dataset
26
+
27
+ from strhub.data.utils import CharsetAdapter
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
+
32
def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
    """Recursively discover LMDB datasets under *root* and concatenate them.

    Every directory containing a ``data.mdb`` file becomes one LmdbDataset;
    extra positional/keyword arguments are forwarded to LmdbDataset.
    """
    kwargs.pop('root', None)  # prevent 'root' from being passed via kwargs
    root = Path(root).absolute()
    log.info(f'dataset root:\t{root}')
    datasets = []
    for mdb_path in glob.glob(str(root / '**/data.mdb'), recursive=True):
        ds_dir = Path(mdb_path).parent
        dataset = LmdbDataset(str(ds_dir.absolute()), *args, **kwargs)
        log.info(f'\tlmdb:\t{str(ds_dir.relative_to(root))}\tnum samples: {len(dataset)}')
        datasets.append(dataset)
    return ConcatDataset(datasets)
48
+
49
+
50
class LmdbDataset(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
    as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.
    """

    def __init__(
        self,
        root: str,
        charset: str,
        max_label_len: int,
        min_image_dim: int = 0,
        remove_whitespace: bool = True,
        normalize_unicode: bool = True,
        unlabelled: bool = False,
        transform: Optional[Callable] = None,
    ):
        self._env = None  # opened lazily via the `env` property (safe for DataLoader workers)
        self.root = root
        self.unlabelled = unlabelled
        self.transform = transform
        # BUGFIX: `min_image_dim` was never stored on the instance, yet the original
        # `_preprocess_labels` read `self.min_image_dim`, raising AttributeError
        # whenever min_image_dim > 0.
        self.min_image_dim = min_image_dim
        self.labels = []
        self.filtered_index_list = []
        self.num_samples = self._preprocess_labels(
            charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim
        )

    def __del__(self):
        # Release the LMDB environment when the dataset is garbage-collected.
        if self._env is not None:
            self._env.close()
            self._env = None

    def _create_env(self):
        """Open the LMDB environment read-only, tuned for many small random reads."""
        return lmdb.open(
            self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False
        )

    @property
    def env(self):
        """Lazily-opened LMDB environment."""
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
        """Scan all labels once, filter unusable samples, and return the usable sample count."""
        charset_adapter = CharsetAdapter(charset)
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                return num_samples
            for index in range(num_samples):
                index += 1  # lmdb starts with 1
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()
                # Normally, whitespace is removed from the labels.
                if remove_whitespace:
                    label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to compatible ASCII characters
                if normalize_unicode:
                    label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters. The original label might be too long.
                if len(label) > max_label_len:
                    continue
                label = charset_adapter(label)
                # We filter out samples which don't contain any supported characters
                if not label:
                    continue
                # Filter images that are too small.
                if min_image_dim > 0:
                    img_key = f'image-{index:09d}'.encode()
                    buf = io.BytesIO(txn.get(img_key))
                    w, h = Image.open(buf).size
                    # BUGFIX: compare against the local parameter; the original read
                    # `self.min_image_dim`, which did not exist at this point.
                    if w < min_image_dim or h < min_image_dim:
                        continue
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return ``(image, label)``; for unlabelled data the label is the sample index."""
        if self.unlabelled:
            label = index
        else:
            label = self.labels[index]
            index = self.filtered_index_list[index]

        img_key = f'image-{index:09d}'.encode()
        with self.env.begin() as txn:
            imgbuf = txn.get(img_key)
        buf = io.BytesIO(imgbuf)
        img = Image.open(buf).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label
strhub/data/module.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pathlib import PurePath
17
+ from typing import Callable, Optional, Sequence
18
+
19
+ from torch.utils.data import DataLoader
20
+ from torchvision import transforms as T
21
+
22
+ import pytorch_lightning as pl
23
+
24
+ from .dataset import LmdbDataset, build_tree_dataset
25
+
26
+
27
class SceneTextDataModule(pl.LightningDataModule):
    """Lightning data module wiring LMDB scene-text datasets to DataLoaders.

    Owns the image preprocessing pipeline (see :meth:`get_transform`) and
    lazily builds the train/val datasets from directory trees of LMDBs.
    """

    # Standard STR benchmark split names; each maps to a directory under <root_dir>/test.
    TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
    TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
    TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
    TEST_CUSTOM = ("blade",)  # presumably the project's engine-blade split — TODO confirm
    TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(
        self,
        root_dir: str,
        train_dir: str,
        img_size: Sequence[int],
        max_label_length: int,
        charset_train: str,
        charset_test: str,
        batch_size: int,
        num_workers: int,
        augment: bool,
        remove_whitespace: bool = True,
        normalize_unicode: bool = True,
        min_image_dim: int = 0,
        rotation: int = 0,
        collate_fn: Optional[Callable] = None,
    ):
        """Store configuration; datasets are built lazily on first property access."""
        super().__init__()
        self.root_dir = root_dir
        self.train_dir = train_dir
        self.img_size = tuple(img_size)
        self.max_label_length = max_label_length
        self.charset_train = charset_train
        self.charset_test = charset_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augment = augment
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.min_image_dim = min_image_dim
        self.rotation = rotation
        self.collate_fn = collate_fn
        # Lazily-built dataset caches (see train_dataset / val_dataset properties).
        self._train_dataset = None
        self._val_dataset = None

    @staticmethod
    def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0):
        """Build the preprocessing pipeline: optional RandAugment and rotation,
        then bicubic resize to ``img_size``, ToTensor, and normalization
        mapping [0, 1] to [-1, 1]."""
        transforms = []
        if augment:
            # Imported lazily so inference-only use does not pull in the augmentation stack.
            from .augment import rand_augment_transform

            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),  # (x - 0.5) / 0.5 -> [-1, 1]
        ])
        return T.Compose(transforms)

    @property
    def train_dataset(self):
        """Concatenated LMDBs under <root_dir>/train/<train_dir>, built once with augmentation."""
        if self._train_dataset is None:
            transform = self.get_transform(self.img_size, self.augment)
            root = PurePath(self.root_dir, 'train', self.train_dir)
            self._train_dataset = build_tree_dataset(
                root,
                self.charset_train,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
        return self._train_dataset

    @property
    def val_dataset(self):
        """Concatenated LMDBs under <root_dir>/val, built once without augmentation."""
        if self._val_dataset is None:
            transform = self.get_transform(self.img_size)
            root = PurePath(self.root_dir, 'val')
            self._val_dataset = build_tree_dataset(
                root,
                self.charset_test,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
        return self._val_dataset

    def train_dataloader(self):
        """Shuffled training DataLoader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Validation DataLoader (no shuffling)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def test_dataloaders(self, subset):
        """Return ``{split name: DataLoader}`` for each test split in *subset*.

        Each split is an LMDB directory under <root_dir>/test; the eval transform
        applies the configured rotation but no augmentation.
        """
        transform = self.get_transform(self.img_size, rotation=self.rotation)
        root = PurePath(self.root_dir, 'test')
        datasets = {
            s: LmdbDataset(
                str(root / s),
                self.charset_test,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
            for s in subset
        }
        return {
            k: DataLoader(
                v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
            )
            for k, v in datasets.items()
        }
strhub/data/utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from itertools import groupby
19
+ from typing import Optional
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.nn.utils.rnn import pad_sequence
24
+
25
+
26
class CharsetAdapter:
    """Transforms labels according to the target charset.

    A charset that is entirely lower- (or upper-) case implies case folding;
    any character outside the charset is stripped from labels.
    """

    def __init__(self, target_charset) -> None:
        super().__init__()
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        # Matches every character NOT present in the target charset.
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        """Case-fold if implied by the charset, then drop unsupported characters."""
        if self.lowercase_only:
            label = label.lower()
        elif self.uppercase_only:
            label = label.upper()
        # Remove unsupported characters
        return self.unsupported.sub('', label)
43
+
44
+
45
class BaseTokenizer(ABC):
    """Bidirectional mapping between label strings and integer token ids."""

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        # id -> token lookup table; special tokens bracket the charset characters.
        self._itos = specials_first + tuple(charset) + specials_last
        # token -> id reverse mapping.
        self._stoi = {tok: idx for idx, tok in enumerate(self._itos)}

    def __len__(self):
        """Vocabulary size, special tokens included."""
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> list[int]:
        """Map each token of *tokens* to its integer id."""
        return [self._stoi[tok] for tok in tokens]

    def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
        """Map ids back to tokens; joined into a single string unless *join* is False."""
        decoded = [self._itos[idx] for idx in token_ids]
        return ''.join(decoded) if join else decoded

    @abstractmethod
    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens, batch_probs = [], []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # greedy selection per position
            if not raw:
                probs, ids = self._filter(probs, ids)
            batch_tokens.append(self._ids2tok(ids, not raw))
            batch_probs.append(probs)
        return batch_tokens, batch_probs
100
+
101
+
102
class Tokenizer(BaseTokenizer):
    """Autoregressive tokenizer with BOS/EOS/PAD special tokens (EOS at id 0)."""

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        # EOS takes index 0; BOS and PAD are appended after the charset.
        first = (self.EOS,)
        last = (self.BOS, self.PAD)
        super().__init__(charset, first, last)
        self.eos_id, self.bos_id, self.pad_id = (self._stoi[s] for s in first + last)

    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels as ``[BOS] tokens [EOS]``, right-padded with PAD. Shape: N, L."""
        sequences = []
        for label in labels:
            ids = [self.bos_id, *self._tok2ids(label), self.eos_id]
            sequences.append(torch.as_tensor(ids, dtype=torch.long, device=device))
        return pad_sequence(sequences, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Truncate at the first EOS; keep the EOS probability when it exists."""
        ids = ids.tolist()
        try:
            eos_pos = ids.index(self.eos_id)
        except ValueError:
            eos_pos = len(ids)  # no EOS emitted; nothing to truncate
        # Drop EOS and everything after it from the ids, but include the
        # probability assigned to EOS itself in the confidence slice.
        return probs[: eos_pos + 1], ids[:eos_pos]
131
+
132
class CTCTokenizer(BaseTokenizer):
    """CTC tokenizer with a single BLANK special token."""

    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Padded (not concatenated) encoding; avoids CUDNN's CTC implementation."""
        sequences = [
            torch.as_tensor(self._tok2ids(label), dtype=torch.long, device=device) for label in labels
        ]
        return pad_sequence(sequences, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Best-path CTC decoding: collapse repeated ids, then drop BLANKs."""
        collapsed = [key for key, _ in groupby(ids.tolist())]  # remove duplicate tokens
        cleaned = [i for i in collapsed if i != self.blank_id]  # remove BLANKs
        # `probs` is just pass-through since all positions are considered part of the path
        return probs, cleaned
strhub/models/__init__.py ADDED
File without changes
strhub/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes). View file
 
strhub/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes). View file
 
strhub/models/__pycache__/base.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
strhub/models/__pycache__/base.cpython-312.pyc ADDED
Binary file (11.6 kB). View file
 
strhub/models/__pycache__/utils.cpython-311.pyc ADDED
Binary file (9.09 kB). View file
 
strhub/models/__pycache__/utils.cpython-312.pyc ADDED
Binary file (8.09 kB). View file
 
strhub/models/base.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scene Text Recognition Model Hub
2
+ # Copyright 2022 Darwin Bautista
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ from nltk import edit_distance
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+ from torch.optim.lr_scheduler import OneCycleLR
28
+
29
+ import pytorch_lightning as pl
30
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
31
+ from timm.optim import create_optimizer_v2
32
+
33
+ from strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer
34
+
35
+
36
@dataclass
class BatchResult:
    """Aggregated evaluation metrics for one batch."""

    num_samples: int  # samples evaluated in this batch
    correct: int  # exactly-correct predictions
    ned: float  # sum of normalized edit distances (ICDAR 2019 definition)
    confidence: float  # sum of per-sample sequence confidences
    label_length: int  # sum of predicted label lengths
    loss: Tensor  # mean batch loss (None at test time)
    loss_numel: int  # element count the loss was computed from (None at test time)


# One dict per evaluation step, each holding that step's BatchResult under 'output'.
EPOCH_OUTPUT = list[dict[str, BatchResult]]
48
+
49
+
50
class BaseSystem(pl.LightningModule, ABC):
    """Shared Lightning scaffolding for STR systems: optimization setup and evaluation."""

    def __init__(
        self,
        tokenizer: BaseTokenizer,
        charset_test: str,
        batch_size: int,
        lr: float,
        warmup_pct: float,
        weight_decay: float,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        # Restricts predictions to the (possibly narrower) test-time charset.
        self.charset_adapter = CharsetAdapter(charset_test)
        self.batch_size = batch_size
        self.lr = lr
        self.warmup_pct = warmup_pct
        self.weight_decay = weight_decay
        self.outputs: EPOCH_OUTPUT = []

    @abstractmethod
    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Inference.

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            max_length: Max sequence length of the output. If None, will use default.

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes,
            typically len(charset_train) + num specials)
        """
        raise NotImplementedError

    @abstractmethod
    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Like forward(), but also computes the loss (calls forward() internally).

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            labels: Text labels of the images

        Returns:
            logits: N, L, C (see forward())
            loss: mean loss for the batch
            loss_numel: number of elements the loss was calculated from
        """
        raise NotImplementedError

    def configure_optimizers(self):
        accum = self.trainer.accumulate_grad_batches
        # Linear scaling keeps the effective LR constant regardless of the DDP device count.
        scaled_lr = accum * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0 * self.lr
        optimizer = create_optimizer_v2(self, 'adamw', scaled_lr, self.weight_decay)
        scheduler = OneCycleLR(
            optimizer, scaled_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
        optimizer.zero_grad(set_to_none=True)

    def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
        """Run one evaluation step and collect per-batch metrics."""
        images, labels = batch

        if validation:
            logits, loss, loss_numel = self.forward_logits_loss(images, labels)
        else:
            # At test-time, we shouldn't specify a max_label_length because the test-time
            # charset might differ from the train-time charset. max_label_length in
            # forward_logits_loss() is computed from the transformed label, which could be
            # wrong if the gt label contains characters in the train-time charset but not
            # the test-time one. E.g. "aishahaleyes.blogspot.com" is 25 characters, but the
            # 36-char CharsetAdapter shrinks it to 23, truncating the model prediction.
            logits = self.forward(images)
            loss = loss_numel = None  # Only used for validation; not needed at test-time.

        num_correct = 0
        num_samples = 0
        ned_sum = 0
        conf_sum = 0
        label_len_sum = 0
        preds, probs = self.tokenizer.decode(logits.softmax(-1))
        for pred, prob, gt in zip(preds, probs, labels):
            conf_sum += prob.prod().item()
            pred = self.charset_adapter(pred)
            # Follow ICDAR 2019 definition of N.E.D.
            ned_sum += edit_distance(pred, gt) / max(len(pred), len(gt))
            num_correct += int(pred == gt)
            num_samples += 1
            label_len_sum += len(pred)
        return dict(output=BatchResult(num_samples, num_correct, ned_sum, conf_sum, label_len_sum, loss, loss_numel))

    @staticmethod
    def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]:
        """Reduce per-batch results into (accuracy, 1 - mean NED, mean loss)."""
        if not outputs:
            return 0.0, 0.0, 0.0
        loss_sum = 0
        loss_count = 0
        n_correct = 0
        ned_total = 0
        n_samples = 0
        for entry in outputs:
            res = entry['output']
            loss_sum += res.loss_numel * res.loss  # weight each batch loss by its element count
            loss_count += res.loss_numel
            n_correct += res.correct
            ned_total += res.ned
            n_samples += res.num_samples
        accuracy = n_correct / n_samples
        mean_ned = 1 - ned_total / n_samples
        mean_loss = loss_sum / loss_count
        return accuracy, mean_ned, mean_loss

    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        step_result = self._eval_step(batch, True)
        self.outputs.append(step_result)
        return step_result

    def on_validation_epoch_end(self) -> None:
        accuracy, mean_ned, mean_loss = self._aggregate_results(self.outputs)
        self.outputs.clear()
        self.log('val_accuracy', 100 * accuracy, sync_dist=True)
        self.log('val_NED', 100 * mean_ned, sync_dist=True)
        self.log('val_loss', mean_loss, sync_dist=True)
        self.log('hp_metric', accuracy, sync_dist=True)

    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, False)
181
+
182
+
183
class CrossEntropySystem(BaseSystem):
    """Base class for systems trained with token-level cross-entropy."""

    def __init__(
        self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
    ) -> None:
        tokenizer = Tokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        # Cache the special-token ids for convenient access in subclasses.
        self.bos_id = tokenizer.bos_id
        self.eos_id = tokenizer.eos_id
        self.pad_id = tokenizer.pad_id

    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Run the model and compute cross-entropy against the encoded labels."""
        targets = self.tokenizer.encode(labels, self.device)[:, 1:]  # Discard <bos>
        max_label_len = targets.shape[1] - 1  # exclude <eos> from count
        logits = self.forward(images, max_label_len)
        loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
        numel = (targets != self.pad_id).sum()  # PAD positions are ignored by the loss
        return logits, loss, numel
202
+
203
+
204
class CTCSystem(BaseSystem):
    """Base class for systems trained with the CTC loss."""

    def __init__(
        self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
    ) -> None:
        tokenizer = CTCTokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.blank_id = tokenizer.blank_id

    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Run the model and compute the CTC loss over the full output length."""
        targets = self.tokenizer.encode(labels, self.device)
        logits = self.forward(images)
        log_probs = logits.log_softmax(-1).transpose(0, 1)  # swap batch and seq. dims
        T, N, _ = log_probs.shape
        # Every output position is a valid input step for CTC.
        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device)
        target_lengths = torch.as_tensor([len(label) for label in labels], dtype=torch.long, device=self.device)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
        return logits, loss, N
strhub/models/modules.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""Shared modules used by CRNN and TRBA"""
2
+ from torch import nn
3
+
4
+
5
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection.

    Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Both directions are concatenated, hence the 2x input width.
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Map visual features to contextual features.

        input : visual feature [batch_size x T x input_size], T = num_steps.
        output : contextual feature [batch_size x T x output_size]
        """
        contextual, _ = self.rnn(input)  # [N, T, input_size] -> [N, T, 2*hidden_size]
        return self.linear(contextual)  # [N, T, output_size]
strhub/models/parseq/__init__.py ADDED
File without changes
strhub/models/parseq/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (180 Bytes). View file
 
strhub/models/parseq/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (168 Bytes). View file
 
strhub/models/parseq/__pycache__/model.cpython-311.pyc ADDED
Binary file (8.97 kB). View file
 
strhub/models/parseq/__pycache__/model.cpython-312.pyc ADDED
Binary file (8.27 kB). View file
 
strhub/models/parseq/__pycache__/modules.cpython-311.pyc ADDED
Binary file (8.83 kB). View file
 
strhub/models/parseq/__pycache__/modules.cpython-312.pyc ADDED
Binary file (7.73 kB). View file
 
strhub/models/parseq/__pycache__/system.cpython-311.pyc ADDED
Binary file (9.4 kB). View file
 
strhub/models/parseq/__pycache__/system.cpython-312.pyc ADDED
Binary file (8.59 kB). View file