Spaces:
Sleeping
Sleeping
Upload 54 files
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +4 -0
- .gradio/flagged/dataset1.csv +2 -0
- __pycache__/parseq_recognize.cpython-311.pyc +0 -0
- __pycache__/yolo_detect.cpython-311.pyc +0 -0
- app.py +34 -0
- parseq_recognize.py +18 -0
- pretrained_model/parseq.ckpt +3 -0
- pretrained_model/yolo_obb.pt +3 -0
- requirements.txt +8 -0
- samples/T1.png +3 -0
- samples/T2.png +3 -0
- samples/image_0004.png +3 -0
- samples/image_0082.png +3 -0
- strhub/__init__.py +0 -0
- strhub/__pycache__/__init__.cpython-311.pyc +0 -0
- strhub/__pycache__/__init__.cpython-312.pyc +0 -0
- strhub/data/__init__.py +0 -0
- strhub/data/__pycache__/__init__.cpython-311.pyc +0 -0
- strhub/data/__pycache__/__init__.cpython-312.pyc +0 -0
- strhub/data/__pycache__/aa_overrides.cpython-312.pyc +0 -0
- strhub/data/__pycache__/augment.cpython-312.pyc +0 -0
- strhub/data/__pycache__/dataset.cpython-311.pyc +0 -0
- strhub/data/__pycache__/dataset.cpython-312.pyc +0 -0
- strhub/data/__pycache__/module.cpython-311.pyc +0 -0
- strhub/data/__pycache__/module.cpython-312.pyc +0 -0
- strhub/data/__pycache__/utils.cpython-311.pyc +0 -0
- strhub/data/__pycache__/utils.cpython-312.pyc +0 -0
- strhub/data/aa_overrides.py +46 -0
- strhub/data/augment.py +112 -0
- strhub/data/dataset.py +148 -0
- strhub/data/module.py +158 -0
- strhub/data/utils.py +150 -0
- strhub/models/__init__.py +0 -0
- strhub/models/__pycache__/__init__.cpython-311.pyc +0 -0
- strhub/models/__pycache__/__init__.cpython-312.pyc +0 -0
- strhub/models/__pycache__/base.cpython-311.pyc +0 -0
- strhub/models/__pycache__/base.cpython-312.pyc +0 -0
- strhub/models/__pycache__/utils.cpython-311.pyc +0 -0
- strhub/models/__pycache__/utils.cpython-312.pyc +0 -0
- strhub/models/base.py +221 -0
- strhub/models/modules.py +20 -0
- strhub/models/parseq/__init__.py +0 -0
- strhub/models/parseq/__pycache__/__init__.cpython-311.pyc +0 -0
- strhub/models/parseq/__pycache__/__init__.cpython-312.pyc +0 -0
- strhub/models/parseq/__pycache__/model.cpython-311.pyc +0 -0
- strhub/models/parseq/__pycache__/model.cpython-312.pyc +0 -0
- strhub/models/parseq/__pycache__/modules.cpython-311.pyc +0 -0
- strhub/models/parseq/__pycache__/modules.cpython-312.pyc +0 -0
- strhub/models/parseq/__pycache__/system.cpython-311.pyc +0 -0
- strhub/models/parseq/__pycache__/system.cpython-312.pyc +0 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,7 @@ ocr_demo/samples/image_0004.png filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
ocr_demo/samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
ocr_demo/samples/T1.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
ocr_demo/samples/T2.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
ocr_demo/samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
ocr_demo/samples/T1.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
ocr_demo/samples/T2.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
samples/image_0004.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
samples/image_0082.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
samples/T1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
samples/T2.png filter=lfs diff=lfs merge=lfs -text
|
.gradio/flagged/dataset1.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Choose a sample image,Recognized Text,timestamp
|
| 2 |
+
,,2025-08-07 15:20:31.167102
|
__pycache__/parseq_recognize.cpython-311.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
__pycache__/yolo_detect.cpython-311.pyc
ADDED
|
Binary file (4 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
from yolo_detect import OBBPredictor
from parseq_recognize import TextRecognizer
import os

# ==== Initialize models ====
# Build the checkpoint paths with os.path.join: the original hard-coded
# Windows-style backslash paths ("pretrained_model\\yolo_obb.pt"), which do
# not resolve on the Linux hosts that Hugging Face Spaces run on.
yolo_model_path = os.path.join("pretrained_model", "yolo_obb.pt")
parseq_ckpt_path = os.path.join("pretrained_model", "parseq.ckpt")

detector = OBBPredictor(yolo_model_path)
recognizer = TextRecognizer(parseq_ckpt_path, device='cpu')  # or 'cuda' if on GPU

# ==== OCR pipeline function ====
def run_pipeline(image):
    """Detect text regions in `image` and recognize the text in each crop.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A numbered, newline-separated string of recognized texts, or a
        placeholder message when no region is detected.
    """
    crops = detector.predict(image)
    recognized_texts = [recognizer.recognize(crop) for crop in crops]
    final_output = "\n".join(f"{i + 1}. {txt}" for i, txt in enumerate(recognized_texts))
    return final_output if recognized_texts else "No text detected."

# ==== Get sample image paths ====
# Sorted so the example gallery order is deterministic across platforms
# (os.listdir order is filesystem-dependent).
example_images = sorted(
    f"samples/{f}" for f in os.listdir("samples") if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# ==== Gradio app with ONLY sample images ====
demo = gr.Interface(
    fn=run_pipeline,
    inputs=gr.Image(type="pil", label="Choose a sample image"),
    outputs=gr.Textbox(label="Recognized Text"),
    examples=[[img] for img in example_images],  # list of lists required
    title="Two-Stage OCR Network for Aero Engine Blades Serial Number",
    description="Choose only one of the predefined image. The model will detect text regions and recognize their contents."
)

if __name__ == "__main__":
    demo.launch()
|
parseq_recognize.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from strhub.data.module import SceneTextDataModule
|
| 4 |
+
from strhub.models.utils import load_from_checkpoint
|
| 5 |
+
|
| 6 |
+
class TextRecognizer:
    """Wraps a PARSeq checkpoint for single-image scene-text recognition."""

    def __init__(self, ckpt_path, device='cpu'):
        """Load the checkpoint, put it in eval mode, and build the input transform.

        Args:
            ckpt_path: Path to the PARSeq ``.ckpt`` file.
            device: Torch device string, e.g. ``'cpu'`` or ``'cuda'``.
        """
        self.device = device
        model = load_from_checkpoint(ckpt_path)
        self.parseq = model.eval().to(device)
        # Transform matching the image size the checkpoint was trained with.
        self.img_transform = SceneTextDataModule.get_transform(self.parseq.hparams.img_size)

    def recognize(self, image_pil):
        """Return the decoded text for a single PIL image crop."""
        batch = self.img_transform(image_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.parseq(batch)
            probabilities = logits.softmax(-1)
            labels, _ = self.parseq.tokenizer.decode(probabilities)
            return labels[0]
|
pretrained_model/parseq.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c95fbe3efe9c59f71e7f75761b7b70b5ed5097e7f502cf138d6eded042f7c073
|
| 3 |
+
size 96584214
|
pretrained_model/yolo_obb.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:739cd8cd3f49a3f466cbdee965dcc3720331404d0de5787881bb8a95992dd6e1
|
| 3 |
+
size 5715964
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
ultralytics
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
strhub
|
| 6 |
+
Pillow
|
| 7 |
+
opencv-python
|
| 8 |
+
numpy
|
samples/T1.png
ADDED
|
Git LFS Details
|
samples/T2.png
ADDED
|
Git LFS Details
|
samples/image_0004.png
ADDED
|
Git LFS Details
|
samples/image_0082.png
ADDED
|
Git LFS Details
|
strhub/__init__.py
ADDED
|
File without changes
|
strhub/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
strhub/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
strhub/data/__init__.py
ADDED
|
File without changes
|
strhub/data/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
strhub/data/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (159 Bytes). View file
|
|
|
strhub/data/__pycache__/aa_overrides.cpython-312.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
strhub/data/__pycache__/augment.cpython-312.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
strhub/data/__pycache__/dataset.cpython-311.pyc
ADDED
|
Binary file (8.05 kB). View file
|
|
|
strhub/data/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (7.07 kB). View file
|
|
|
strhub/data/__pycache__/module.cpython-311.pyc
ADDED
|
Binary file (7.19 kB). View file
|
|
|
strhub/data/__pycache__/module.cpython-312.pyc
ADDED
|
Binary file (6.65 kB). View file
|
|
|
strhub/data/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
strhub/data/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
strhub/data/aa_overrides.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Extends default ops to accept optional parameters."""
|
| 17 |
+
from functools import partial
|
| 18 |
+
|
| 19 |
+
from timm.data.auto_augment import _LEVEL_DENOM, LEVEL_TO_ARG, NAME_TO_OP, _randomly_negate, rotate
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def rotate_expand(img, degrees, **kwargs):
    """Rotate with ``expand=True`` so rotated characters are not cut off."""
    forced = dict(kwargs, expand=True)
    return rotate(img, degrees, **forced)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _level_to_arg(level, hparams, key, default):
    """Scale `level` by the hparam-configured magnitude and randomly negate it.

    The magnitude is looked up in `hparams` under `key`, falling back to
    `default` when absent.
    """
    magnitude = hparams.get(key, default)
    scaled = _randomly_negate(level / _LEVEL_DENOM * magnitude)
    return (scaled,)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def apply():
    """Install the custom op and level-arg overrides into timm's auto-augment tables."""
    # Replace the stock Rotate with the expand-aware variant defined above.
    NAME_TO_OP.update({'Rotate': rotate_expand})
    # Map each op to a level-arg function parameterized by an hparam key/default.
    overrides = {
        'Rotate': ('rotate_deg', 30.0),
        'ShearX': ('shear_x_pct', 0.3),
        'ShearY': ('shear_y_pct', 0.3),
        'TranslateXRel': ('translate_x_pct', 0.45),
        'TranslateYRel': ('translate_y_pct', 0.45),
    }
    LEVEL_TO_ARG.update({
        name: partial(_level_to_arg, key=key, default=default)
        for name, (key, default) in overrides.items()
    })
|
strhub/data/augment.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
import imgaug.augmenters as iaa
|
| 19 |
+
import numpy as np
|
| 20 |
+
from PIL import Image, ImageFilter
|
| 21 |
+
|
| 22 |
+
from timm.data import auto_augment
|
| 23 |
+
|
| 24 |
+
from strhub.data import aa_overrides
|
| 25 |
+
|
| 26 |
+
aa_overrides.apply()
|
| 27 |
+
|
| 28 |
+
_OP_CACHE = {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _get_op(key, factory):
    """Return the cached op for `key`, creating it with `factory` on first use."""
    if key not in _OP_CACHE:
        _OP_CACHE[key] = factory()
    return _OP_CACHE[key]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _get_param(level, img, max_dim_factor, min_level=1):
|
| 41 |
+
max_level = max(min_level, max_dim_factor * max(img.size))
|
| 42 |
+
return round(min(level, max_level))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def gaussian_blur(img, radius, **__):
    """Apply a cached PIL GaussianBlur with radius clamped to the image size."""
    radius = _get_param(radius, img, 0.02)
    blur = _get_op(f'gaussian_blur_{radius}', lambda: ImageFilter.GaussianBlur(radius))
    return img.filter(blur)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def motion_blur(img, k, **__):
    """Apply a cached imgaug MotionBlur; the kernel size is forced odd."""
    k = _get_param(k, img, 0.08, 3) | 1  # bin to odd values
    blur = _get_op(f'motion_blur_{k}', lambda: iaa.MotionBlur(k))
    return Image.fromarray(blur(image=np.asarray(img)))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def gaussian_noise(img, scale, **_):
    """Apply cached additive Gaussian noise via imgaug."""
    scale = _get_param(scale, img, 0.25) | 1  # bin to odd values
    noise = _get_op(f'gaussian_noise_{scale}', lambda: iaa.AdditiveGaussianNoise(scale=scale))
    return Image.fromarray(noise(image=np.asarray(img)))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def poisson_noise(img, lam, **_):
    """Apply cached additive Poisson noise via imgaug."""
    lam = _get_param(lam, img, 0.2) | 1  # bin to odd values
    noise = _get_op(f'poisson_noise_{lam}', lambda: iaa.AdditivePoissonNoise(lam))
    return Image.fromarray(noise(image=np.asarray(img)))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _level_to_arg(level, _hparams, max):
    """Linearly map `level` from [0, _LEVEL_DENOM] onto [0, max]."""
    scaled = max * level / auto_augment._LEVEL_DENOM
    return (scaled,)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Build the op list for RandAugment from timm's "increasing" transform set,
# minus SharpnessIncreasing (it interferes with the *blur ops), plus the
# custom blur/noise ops defined above.
_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy()
_RAND_TRANSFORMS.remove('SharpnessIncreasing')  # remove, interferes with *blur ops
_RAND_TRANSFORMS += [
    'GaussianBlur',
    # 'MotionBlur',
    # 'GaussianNoise',
    'PoissonNoise',
]
# Register level-arg mappings and op implementations for the custom ops.
auto_augment.LEVEL_TO_ARG.update({
    'GaussianBlur': partial(_level_to_arg, max=4),
    'MotionBlur': partial(_level_to_arg, max=20),
    'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
    'PoissonNoise': partial(_level_to_arg, max=40),
})
auto_augment.NAME_TO_OP.update({
    'GaussianBlur': gaussian_blur,
    'MotionBlur': motion_blur,
    'GaussianNoise': gaussian_noise,
    'PoissonNoise': poisson_noise,
})
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def rand_augment_transform(magnitude=5, num_layers=3):
    """Build a RandAugment transform over the customized op set.

    The hparams below are tuned for magnitude=5, which means that effective
    magnitudes are half of these values.
    """
    hparams = {
        'rotate_deg': 30,
        'shear_x_pct': 0.9,
        'shear_y_pct': 0.2,
        'translate_x_pct': 0.10,
        'translate_y_pct': 0.30,
    }
    ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
    # Supplying uniform weights disables replacement in the random op
    # selection, i.e. the same op is never applied twice in one pass.
    uniform = [1.0 / len(ops)] * len(ops)
    return auto_augment.RandAugment(ops, num_layers, uniform)
|
strhub/data/dataset.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import glob
|
| 16 |
+
import io
|
| 17 |
+
import logging
|
| 18 |
+
import unicodedata
|
| 19 |
+
from pathlib import Path, PurePath
|
| 20 |
+
from typing import Callable, Optional, Union
|
| 21 |
+
|
| 22 |
+
import lmdb
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
from torch.utils.data import ConcatDataset, Dataset
|
| 26 |
+
|
| 27 |
+
from strhub.data.utils import CharsetAdapter
|
| 28 |
+
|
| 29 |
+
log = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
    """Recursively locate LMDB databases under `root` and concatenate them.

    Each directory containing a ``data.mdb`` file becomes one LmdbDataset;
    remaining positional/keyword arguments are forwarded to its constructor.
    """
    kwargs.pop('root', None)  # prevent 'root' from being passed via kwargs
    root = Path(root).absolute()
    log.info(f'dataset root:\t{root}')
    datasets = []
    for mdb_path in glob.glob(str(root / '**/data.mdb'), recursive=True):
        db_dir = Path(mdb_path).parent
        dataset = LmdbDataset(str(db_dir.absolute()), *args, **kwargs)
        log.info(f'\tlmdb:\t{db_dir.relative_to(root)}\tnum samples: {len(dataset)}')
        datasets.append(dataset)
    return ConcatDataset(datasets)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LmdbDataset(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
    as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.
    """

    def __init__(
        self,
        root: str,
        charset: str,
        max_label_len: int,
        min_image_dim: int = 0,
        remove_whitespace: bool = True,
        normalize_unicode: bool = True,
        unlabelled: bool = False,
        transform: Optional[Callable] = None,
    ):
        self._env = None  # LMDB environment; opened lazily via the `env` property
        self.root = root
        self.unlabelled = unlabelled
        self.transform = transform
        self.labels = []
        self.filtered_index_list = []
        # BUG FIX: store the threshold on the instance. The original never set
        # this attribute, yet `_preprocess_labels` read `self.min_image_dim`,
        # so any min_image_dim > 0 raised AttributeError.
        self.min_image_dim = min_image_dim
        self.num_samples = self._preprocess_labels(
            charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim
        )

    def __del__(self):
        # Close the environment eagerly; LMDB handles hold OS resources.
        if self._env is not None:
            self._env.close()
            self._env = None

    def _create_env(self):
        """Open the database read-only with settings safe for DataLoader workers."""
        return lmdb.open(
            self.root, max_readers=1, readonly=True, create=False, readahead=False, meminit=False, lock=False
        )

    @property
    def env(self):
        """Lazily-created LMDB environment (created per process, after any fork)."""
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
        """Scan all labels once, apply the filters, and return the usable sample count."""
        charset_adapter = CharsetAdapter(charset)
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                return num_samples
            for index in range(num_samples):
                index += 1  # lmdb starts with 1
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()
                # Normally, whitespace is removed from the labels.
                if remove_whitespace:
                    label = ''.join(label.split())
                # Normalize unicode composites (if any) and convert to compatible ASCII characters
                if normalize_unicode:
                    label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()
                # Filter by length before removing unsupported characters. The original label might be too long.
                if len(label) > max_label_len:
                    continue
                label = charset_adapter(label)
                # We filter out samples which don't contain any supported characters
                if not label:
                    continue
                # Filter images that are too small.
                if min_image_dim > 0:
                    img_key = f'image-{index:09d}'.encode()
                    buf = io.BytesIO(txn.get(img_key))
                    w, h = Image.open(buf).size
                    # Fixed: compare against the local parameter. The original
                    # referenced `self.min_image_dim`, which was never assigned,
                    # crashing with AttributeError on this code path.
                    if w < min_image_dim or h < min_image_dim:
                        continue
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        if self.unlabelled:
            # Unlabelled data: the dataset index doubles as the label.
            label = index
        else:
            label = self.labels[index]
            # Map the filtered position back to the original LMDB record index.
            index = self.filtered_index_list[index]

        img_key = f'image-{index:09d}'.encode()
        with self.env.begin() as txn:
            imgbuf = txn.get(img_key)
        buf = io.BytesIO(imgbuf)
        img = Image.open(buf).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label
|
strhub/data/module.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from pathlib import PurePath
|
| 17 |
+
from typing import Callable, Optional, Sequence
|
| 18 |
+
|
| 19 |
+
from torch.utils.data import DataLoader
|
| 20 |
+
from torchvision import transforms as T
|
| 21 |
+
|
| 22 |
+
import pytorch_lightning as pl
|
| 23 |
+
|
| 24 |
+
from .dataset import LmdbDataset, build_tree_dataset
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SceneTextDataModule(pl.LightningDataModule):
    """LightningDataModule for scene-text recognition LMDB datasets.

    Expects `root_dir` to contain `train/<train_dir>`, `val`, and `test`
    subdirectories, each holding one or more LMDB databases.
    """

    # Benchmark test-set name tuples used to select LMDB subdirectories.
    TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
    TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
    TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
    TEST_CUSTOM = ("blade",)
    # NOTE(review): TEST_ALL deliberately(?) excludes TEST_CUSTOM — confirm intent.
    TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(
        self,
        root_dir: str,
        train_dir: str,
        img_size: Sequence[int],
        max_label_length: int,
        charset_train: str,
        charset_test: str,
        batch_size: int,
        num_workers: int,
        augment: bool,
        remove_whitespace: bool = True,
        normalize_unicode: bool = True,
        min_image_dim: int = 0,
        rotation: int = 0,
        collate_fn: Optional[Callable] = None,
    ):
        """Store the dataset/transform configuration; datasets are built lazily."""
        super().__init__()
        self.root_dir = root_dir
        self.train_dir = train_dir
        self.img_size = tuple(img_size)
        self.max_label_length = max_label_length
        self.charset_train = charset_train
        self.charset_test = charset_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augment = augment
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.min_image_dim = min_image_dim
        self.rotation = rotation
        self.collate_fn = collate_fn
        # Lazily-built dataset caches (see the train_dataset/val_dataset properties).
        self._train_dataset = None
        self._val_dataset = None

    @staticmethod
    def get_transform(img_size: tuple[int], augment: bool = False, rotation: int = 0):
        """Build the image preprocessing pipeline: optional RandAugment and
        rotation, then resize / tensor conversion / [-1, 1] normalization."""
        transforms = []
        if augment:
            # Imported here to avoid pulling in augmentation deps for inference.
            from .augment import rand_augment_transform

            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        return T.Compose(transforms)

    @property
    def train_dataset(self):
        """Concatenated training dataset under `root_dir/train/<train_dir>` (built once)."""
        if self._train_dataset is None:
            transform = self.get_transform(self.img_size, self.augment)
            root = PurePath(self.root_dir, 'train', self.train_dir)
            self._train_dataset = build_tree_dataset(
                root,
                self.charset_train,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
        return self._train_dataset

    @property
    def val_dataset(self):
        """Concatenated validation dataset under `root_dir/val` (built once, no augmentation)."""
        if self._val_dataset is None:
            transform = self.get_transform(self.img_size)
            root = PurePath(self.root_dir, 'val')
            self._val_dataset = build_tree_dataset(
                root,
                self.charset_test,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
        return self._val_dataset

    def train_dataloader(self):
        """Shuffled training DataLoader using the configured batch size/workers."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        """Validation DataLoader (no shuffling)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def test_dataloaders(self, subset):
        """Return a dict mapping each test-set name in `subset` to its DataLoader.

        Each name must be a subdirectory of `root_dir/test` containing an LMDB
        database; the test-time rotation (if any) is applied via the transform.
        """
        transform = self.get_transform(self.img_size, rotation=self.rotation)
        root = PurePath(self.root_dir, 'test')
        datasets = {
            s: LmdbDataset(
                str(root / s),
                self.charset_test,
                self.max_label_length,
                self.min_image_dim,
                self.remove_whitespace,
                self.normalize_unicode,
                transform=transform,
            )
            for s in subset
        }
        return {
            k: DataLoader(
                v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
            )
            for k, v in datasets.items()
        }
|
strhub/data/utils.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from itertools import groupby
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CharsetAdapter:
    """Normalize labels so they contain only characters from the target charset."""

    def __init__(self, target_charset) -> None:
        super().__init__()
        # Case folding is applied only when the charset itself is single-cased.
        self.lowercase_only = target_charset == target_charset.lower()
        self.uppercase_only = target_charset == target_charset.upper()
        # Matches any character NOT present in the target charset.
        self.unsupported = re.compile(f'[^{re.escape(target_charset)}]')

    def __call__(self, label):
        if self.lowercase_only:
            folded = label.lower()
        elif self.uppercase_only:
            folded = label.upper()
        else:
            folded = label
        # Strip every character that falls outside the target charset.
        return self.unsupported.sub('', folded)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BaseTokenizer(ABC):
    """Common token <-> id bookkeeping shared by all tokenizer implementations."""

    def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
        # Special tokens bracket the charset so their ids stay stable.
        self._itos = specials_first + tuple(charset) + specials_last
        self._stoi = {tok: idx for idx, tok in enumerate(self._itos)}

    def __len__(self):
        return len(self._itos)

    def _tok2ids(self, tokens: str) -> list[int]:
        lookup = self._stoi
        return [lookup[tok] for tok in tokens]

    def _ids2tok(self, token_ids: list[int], join: bool = True) -> str:
        decoded = [self._itos[idx] for idx in token_ids]
        if join:
            return ''.join(decoded)
        return decoded

    @abstractmethod
    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode a batch of labels to a representation suitable for the model.

        Args:
            labels: List of labels. Each can be of arbitrary length.
            device: Create tensor on this device.

        Returns:
            Batched tensor representation padded to the max label length. Shape: N, L
        """
        raise NotImplementedError

    @abstractmethod
    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Internal method which performs the necessary filtering prior to decoding."""
        raise NotImplementedError

    def decode(self, token_dists: Tensor, raw: bool = False) -> tuple[list[str], list[Tensor]]:
        """Decode a batch of token distributions.

        Args:
            token_dists: softmax probabilities over the token distribution. Shape: N, L, C
            raw: return unprocessed labels (will return list of list of strings)

        Returns:
            list of string labels (arbitrary length) and
            their corresponding sequence probabilities as a list of Tensors
        """
        batch_tokens, batch_probs = [], []
        for dist in token_dists:
            # Greedy selection: pick the most probable class at each position.
            probs, ids = dist.max(-1)
            if not raw:
                probs, ids = self._filter(probs, ids)
            batch_tokens.append(self._ids2tok(ids, not raw))
            batch_probs.append(probs)
        return batch_tokens, batch_probs
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Tokenizer(BaseTokenizer):
    """Tokenizer for autoregressive decoding with BOS/EOS/PAD special tokens."""

    BOS = '[B]'
    EOS = '[E]'
    PAD = '[P]'

    def __init__(self, charset: str) -> None:
        # EOS is placed first so that it always maps to id 0.
        specials_first = (self.EOS,)
        specials_last = (self.BOS, self.PAD)
        super().__init__(charset, specials_first, specials_last)
        self.eos_id = self._stoi[self.EOS]
        self.bos_id = self._stoi[self.BOS]
        self.pad_id = self._stoi[self.PAD]

    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Wrap each label with BOS/EOS and pad the batch to a common length."""
        encoded = []
        for label in labels:
            ids = [self.bos_id, *self._tok2ids(label), self.eos_id]
            encoded.append(torch.as_tensor(ids, dtype=torch.long, device=device))
        return pad_sequence(encoded, batch_first=True, padding_value=self.pad_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Truncate the sequence at the first EOS, keeping the EOS probability."""
        ids = ids.tolist()
        if self.eos_id in ids:
            cut = ids.index(self.eos_id)
        else:
            cut = len(ids)  # no EOS found; nothing to truncate
        # Drop EOS (and everything after it) from the ids, but keep its probability.
        return probs[: cut + 1], ids[:cut]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class CTCTokenizer(BaseTokenizer):
    """Tokenizer for CTC-trained models; uses a single BLANK special token."""

    BLANK = '[B]'

    def __init__(self, charset: str) -> None:
        # BLANK uses index == 0 by default
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(self, labels: list[str], device: Optional[torch.device] = None) -> Tensor:
        """Encode labels into a BLANK-padded batch tensor.

        We use a padded representation since we don't want to use CUDNN's CTC implementation.
        """
        batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> tuple[Tensor, list[int]]:
        """Greedy best-path decoding: collapse repeated tokens, then drop BLANKs.

        Fix: the previous `list(zip(*groupby(...)))[0]` raised IndexError when
        `ids` was empty; building the key list directly handles that case.
        """
        ids = [key for key, _ in groupby(ids.tolist())]  # Remove duplicate tokens
        ids = [x for x in ids if x != self.blank_id]  # Remove BLANKs
        # `probs` is just pass-through since all positions are considered part of the path
        return probs, ids
|
strhub/models/__init__.py
ADDED
|
File without changes
|
strhub/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
strhub/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (161 Bytes). View file
|
|
|
strhub/models/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
strhub/models/__pycache__/base.cpython-312.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
strhub/models/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (9.09 kB). View file
|
|
|
strhub/models/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (8.09 kB). View file
|
|
|
strhub/models/base.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scene Text Recognition Model Hub
|
| 2 |
+
# Copyright 2022 Darwin Bautista
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from nltk import edit_distance
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
from torch.optim import Optimizer
|
| 27 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
| 28 |
+
|
| 29 |
+
import pytorch_lightning as pl
|
| 30 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 31 |
+
from timm.optim import create_optimizer_v2
|
| 32 |
+
|
| 33 |
+
from strhub.data.utils import BaseTokenizer, CharsetAdapter, CTCTokenizer, Tokenizer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class BatchResult:
    """Per-batch evaluation metrics accumulated by BaseSystem._eval_step()."""

    num_samples: int  # number of images evaluated in the batch
    correct: int  # count of exact-match predictions (after charset adaptation)
    ned: float  # summed normalized edit distance (ICDAR 2019 N.E.D. definition)
    confidence: float  # summed sequence confidence (product of per-token probabilities)
    label_length: int  # summed length of the predicted labels
    loss: Tensor  # mean loss for the batch (set to None at test time)
    loss_numel: int  # number of elements the loss was computed from (None at test time)


# One dict per evaluated batch, keyed by 'output' -> BatchResult.
EPOCH_OUTPUT = list[dict[str, BatchResult]]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BaseSystem(pl.LightningModule, ABC):
    """Common training/evaluation scaffolding for scene-text-recognition models.

    Subclasses implement :meth:`forward` and :meth:`forward_logits_loss`; this
    base class provides optimizer/scheduler configuration and metric
    aggregation for validation and testing.
    """

    def __init__(
        self,
        tokenizer: BaseTokenizer,
        charset_test: str,
        batch_size: int,
        lr: float,
        warmup_pct: float,
        weight_decay: float,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        # Maps predictions into the (possibly narrower) test-time charset.
        self.charset_adapter = CharsetAdapter(charset_test)
        self.batch_size = batch_size
        self.lr = lr
        self.warmup_pct = warmup_pct
        self.weight_decay = weight_decay
        # Per-batch validation results; cleared at every validation epoch end.
        self.outputs: EPOCH_OUTPUT = []

    @abstractmethod
    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        """Inference

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            max_length: Max sequence length of the output. If None, will use default.

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
        """
        raise NotImplementedError

    @abstractmethod
    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Like forward(), but also computes the loss (calls forward() internally).

        Args:
            images: Batch of images. Shape: N, Ch, H, W
            labels: Text labels of the images

        Returns:
            logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials)
            loss: mean loss for the batch
            loss_numel: number of elements the loss was calculated from
        """
        raise NotImplementedError

    def configure_optimizers(self):
        """Create the AdamW optimizer and a per-step OneCycle LR schedule."""
        agb = self.trainer.accumulate_grad_batches
        # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP.
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
        lr = lr_scale * self.lr
        optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay)
        sched = OneCycleLR(
            optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
        )
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}}

    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer) -> None:
        # set_to_none=True frees the grad tensors instead of zero-filling them.
        optimizer.zero_grad(set_to_none=True)

    def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]:
        """Shared evaluation logic for validation_step() and test_step()."""
        images, labels = batch

        correct = 0
        total = 0
        ned = 0
        confidence = 0
        label_length = 0
        if validation:
            logits, loss, loss_numel = self.forward_logits_loss(images, labels)
        else:
            # At test-time, we shouldn't specify a max_label_length because the test-time charset used
            # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed
            # based on the transformed label, which could be wrong if the actual gt label contains characters existing
            # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com"
            # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters
            # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated.
            logits = self.forward(images)
            loss = loss_numel = None  # Only used for validation; not needed at test-time.

        probs = logits.softmax(-1)
        preds, probs = self.tokenizer.decode(probs)
        for pred, prob, gt in zip(preds, probs, labels):
            confidence += prob.prod().item()
            pred = self.charset_adapter(pred)
            # Follow ICDAR 2019 definition of N.E.D.
            # Fix: guard against ZeroDivisionError when both pred and gt are
            # empty strings (identical empty strings contribute 0 distance).
            denom = max(len(pred), len(gt))
            ned += edit_distance(pred, gt) / denom if denom else 0
            if pred == gt:
                correct += 1
            total += 1
            label_length += len(pred)
        return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel))

    @staticmethod
    def _aggregate_results(outputs: EPOCH_OUTPUT) -> tuple[float, float, float]:
        """Reduce per-batch results into (accuracy, 1 - mean N.E.D., mean loss)."""
        if not outputs:
            return 0.0, 0.0, 0.0
        total_loss = 0
        total_loss_numel = 0
        total_n_correct = 0
        total_norm_ED = 0
        total_size = 0
        for result in outputs:
            result = result['output']
            # Weight each batch loss by the number of elements it was computed from.
            total_loss += result.loss_numel * result.loss
            total_loss_numel += result.loss_numel
            total_n_correct += result.correct
            total_norm_ED += result.ned
            total_size += result.num_samples
        acc = total_n_correct / total_size
        ned = 1 - total_norm_ED / total_size
        loss = total_loss / total_loss_numel
        return acc, ned, loss

    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        result = self._eval_step(batch, True)
        self.outputs.append(result)
        return result

    def on_validation_epoch_end(self) -> None:
        acc, ned, loss = self._aggregate_results(self.outputs)
        self.outputs.clear()
        self.log('val_accuracy', 100 * acc, sync_dist=True)
        self.log('val_NED', 100 * ned, sync_dist=True)
        self.log('val_loss', loss, sync_dist=True)
        self.log('hp_metric', acc, sync_dist=True)

    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        return self._eval_step(batch, False)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class CrossEntropySystem(BaseSystem):
    """Base system for autoregressive models trained with cross-entropy loss."""

    def __init__(
        self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
    ) -> None:
        tokenizer = Tokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        # Cache the special-token ids for convenient access by subclasses.
        self.bos_id = tokenizer.bos_id
        self.eos_id = tokenizer.eos_id
        self.pad_id = tokenizer.pad_id

    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Compute logits and the PAD-masked cross-entropy loss for a batch."""
        targets = self.tokenizer.encode(labels, self.device)[:, 1:]  # Discard <bos>
        # Longest label in the batch, excluding the <eos> token.
        max_label_len = targets.shape[1] - 1
        logits = self.forward(images, max_label_len)
        flat_logits = logits.flatten(end_dim=1)
        flat_targets = targets.flatten()
        # PAD positions are excluded from the loss via ignore_index.
        loss = F.cross_entropy(flat_logits, flat_targets, ignore_index=self.pad_id)
        loss_numel = (targets != self.pad_id).sum()
        return logits, loss, loss_numel
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class CTCSystem(BaseSystem):
    """Base system for models trained with CTC loss."""

    def __init__(
        self, charset_train: str, charset_test: str, batch_size: int, lr: float, warmup_pct: float, weight_decay: float
    ) -> None:
        tokenizer = CTCTokenizer(charset_train)
        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
        # Cache the BLANK id for use in the loss.
        self.blank_id = tokenizer.blank_id

    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        """Compute logits and the CTC loss for a batch."""
        targets = self.tokenizer.encode(labels, self.device)
        logits = self.forward(images)
        # F.ctc_loss expects time-major log-probabilities: T, N, C.
        log_probs = logits.log_softmax(-1).transpose(0, 1)
        seq_len, batch_len, _ = log_probs.shape
        # Every sample uses the full output sequence length.
        input_lengths = torch.full((batch_len,), seq_len, dtype=torch.long, device=self.device)
        target_lengths = torch.as_tensor([len(label) for label in labels], dtype=torch.long, device=self.device)
        loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True)
        return logits, loss, batch_len
|
strhub/models/modules.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""Shared modules used by CRNN and TRBA"""
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BidirectionalLSTM(nn.Module):
    r"""Bidirectional LSTM followed by a linear projection.

    Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Bidirectional output concatenates both directions, doubling the feature dim.
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """Map visual features to contextual features.

        input : visual feature [batch_size x T x input_size], T = num_steps.
        output : contextual feature [batch_size x T x output_size]
        """
        contextual, _ = self.rnn(input)  # batch_size x T x (2*hidden_size)
        return self.linear(contextual)  # batch_size x T x output_size
|
strhub/models/parseq/__init__.py
ADDED
|
File without changes
|
strhub/models/parseq/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (180 Bytes). View file
|
|
|
strhub/models/parseq/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (168 Bytes). View file
|
|
|
strhub/models/parseq/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (8.97 kB). View file
|
|
|
strhub/models/parseq/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (8.27 kB). View file
|
|
|
strhub/models/parseq/__pycache__/modules.cpython-311.pyc
ADDED
|
Binary file (8.83 kB). View file
|
|
|
strhub/models/parseq/__pycache__/modules.cpython-312.pyc
ADDED
|
Binary file (7.73 kB). View file
|
|
|
strhub/models/parseq/__pycache__/system.cpython-311.pyc
ADDED
|
Binary file (9.4 kB). View file
|
|
|
strhub/models/parseq/__pycache__/system.cpython-312.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|