Spaces:
Runtime error
Runtime error
source code
Browse files- src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc +0 -0
- src/Text_Localization/prepare_dataset.py +177 -0
- src/Text_Localization/text_localization.py +33 -0
- src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc +0 -0
- src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc +0 -0
- src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc +0 -0
- src/Text_Recognization/dataloader.py +96 -0
- src/Text_Recognization/prepare_dataset.py +221 -0
- src/Text_Recognization/text_recognization.py +64 -0
- src/Text_Recognization/trainer.py +162 -0
- src/config.json +24 -0
- src/predict.py +137 -0
src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc
ADDED
|
Binary file (7.86 kB). View file
|
|
|
src/Text_Localization/prepare_dataset.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import cv2
|
| 5 |
+
import xml.etree.ElementTree as ET
|
| 6 |
+
import shutil
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
|
| 11 |
+
location_path = r'Dataset/locations.xml'
|
| 12 |
+
tree = ET.parse(location_path)
|
| 13 |
+
root = tree.getroot()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_gt_bboxes(location_path):
    """Collect every ground-truth text bounding box from the annotation XML.

    Args:
        location_path: path to the XML annotation file
            (e.g. ``Dataset/locations.xml``).

    Returns:
        gt_imagepaths (list[str]): path of each annotated image under ``Dataset/``.
        gt_imagesizes (list): ``[h, w]`` per image, taken from the ``y``/``x``
            attributes of the resolution element (values kept as stored in the
            XML; presumably x=width and y=height -- TODO confirm schema).
        gt_locations (list): per-image list of ``[x, y, width, height]`` boxes
            as int64 values.
    """
    # Bug fix: parse the file given as the argument. The original ignored the
    # parameter and iterated the module-level `root` instead.
    tree = ET.parse(location_path)
    root = tree.getroot()

    gt_imagepaths = []
    gt_imagesizes = []
    gt_locations = []

    for image in root:
        # child 0 holds the image file name
        image_name = image[0].text
        gt_imagepaths.append(os.path.join('Dataset', image_name))

        # child 1 holds the image size as x/y attributes (raw strings)
        w = image[1].get('x')
        h = image[1].get('y')
        gt_imagesizes.append([h, w])

        # child 2 holds one element per tagged rectangle
        bbs = []
        for bbox in image[2]:
            x = np.int64(float(bbox.get('x')))
            y = np.int64(float(bbox.get('y')))
            width = np.int64(float(bbox.get('width')))
            height = np.int64(float(bbox.get('height')))
            bbs.append([x, y, width, height])

        gt_locations.append(bbs)

    return gt_imagepaths, gt_imagesizes, gt_locations
|
| 52 |
+
|
| 53 |
+
gt_imagepaths, gt_imagesizes, gt_locations = get_gt_bboxes(location_path)
|
| 54 |
+
|
| 55 |
+
def visualize_gt_bboxes(image_path, gt_locations):
    """Display an image with its ground-truth text boxes drawn in red."""
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

    for box in gt_locations:
        x, y, w, h = (int(v) for v in box)
        img = cv2.rectangle(img, (x, y), (x + w, y + h), color=(255, 0, 0), thickness=2)

    plt.imshow(img)
    plt.axis('off')
    plt.show()
|
| 68 |
+
|
| 69 |
+
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Display an image with YOLO-format (class, xc, yc, w, h) boxes drawn in red."""
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    img_h, img_w = img.shape[:2]

    for row in gt_location_yolo:
        xc, yc, w, h = row[1:]
        # Normalized center/size -> absolute corner coordinates.
        xmin = int((xc - w / 2) * img_w)
        ymin = int((yc - h / 2) * img_h)
        xmax = int((xc + w / 2) * img_w)
        ymax = int((yc + h / 2) * img_h)
        img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)

    plt.imshow(img)
    plt.axis('off')
    plt.show()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def convert_yolo_format(gt_locations, gt_imagesizes):
    """Convert pixel boxes ``[x, y, w, h]`` to YOLO rows ``[class, xc, yc, w, h]``.

    Coordinates are normalized by the image size; the single class id 0 means
    "contains text".
    """
    gt_locations_yolo = []

    for boxes, (img_h, img_w) in zip(gt_locations, gt_imagesizes):
        rows = []
        for x, y, w, h in boxes:
            # Box center and extent, normalized to [0, 1].
            xc = (x + w / 2) / float(img_w)
            yc = (y + h / 2) / float(img_h)
            norm_w = w / float(img_w)
            norm_h = h / float(img_h)
            rows.append([0, xc, yc, norm_w, norm_h])
        gt_locations_yolo.append(rows)

    return gt_locations_yolo
|
| 111 |
+
|
| 112 |
+
gt_locations_yolo = convert_yolo_format(gt_locations, gt_imagesizes)
|
| 113 |
+
|
| 114 |
+
def save_data_into_yolo_folder(data, src_img_dir, save_dir):
    """Copy images and write YOLO label files into ``save_dir``.

    Args:
        data: list of ``[image_path, rows]`` pairs, rows being YOLO label rows.
        src_img_dir: unused here; kept for interface compatibility.
        save_dir: split directory; ``images/`` and ``labels/`` are created inside.
    """
    images_dir = os.path.join(save_dir, 'images')
    labels_dir = os.path.join(save_dir, 'labels')
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    for image_path, rows in data:
        # Copy the image file into the split's images folder.
        shutil.copy(image_path, images_dir)

        # One label file per image, same stem, one "cls xc yc w h" row per box.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        with open(os.path.join(labels_dir, f'{stem}.txt'), "w") as f:
            for row in rows:
                label_str = " ".join(map(str, row))
                f.write(f'{label_str}\n')
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---- Train / val / test split and YOLO dataset export ----
# 0.2 of the data goes to validation; 0.125 of the remaining 0.8 (i.e. 0.1 of
# the total) goes to test.
# Bug fix: the original declared `seed = 0` but never used it while the splits
# hard-coded random_state=42. Keep the actual behavior (42) and route it
# through `seed` so the value is stated once.
seed = 42
val_size = 0.2
test_size = 0.125
dataset = [[gt_imagepath, gt_location_yolo]
           for gt_imagepath, gt_location_yolo in zip(gt_imagepaths, gt_locations_yolo)]
train_data, val_data = train_test_split(dataset, test_size=val_size, random_state=seed, shuffle=True)
train_data, test_data = train_test_split(train_data, test_size=test_size, random_state=seed, shuffle=True)

save_yolo_data_dir = 'yolo_data'
os.makedirs(save_yolo_data_dir, exist_ok=True)
# Write each split into yolo_data/<split>/{images,labels}.
for split_name, split_data in (('train', train_data), ('val', val_data), ('test', test_data)):
    save_data_into_yolo_folder(
        data=split_data,
        src_img_dir=save_yolo_data_dir,
        save_dir=os.path.join(save_yolo_data_dir, split_name)
    )

class_label = ['text']
# data.yaml consumed by Ultralytics YOLO training.
data_yaml = {
    "path": '../yolo_data',
    'train': 'train/images',
    'test': 'test/images',
    'val': 'val/images',
    'nc': 1,                 # one class: text
    'names': class_label
}

yolo_yaml_path = os.path.join(save_yolo_data_dir, 'data.yaml')
with open(yolo_yaml_path, "w") as f:
    yaml.dump(data_yaml, f, default_flow_style=False)
|
src/Text_Localization/text_localization.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ultralytics
|
| 2 |
+
import json
|
| 3 |
+
from ultralytics import YOLO
|
| 4 |
+
|
| 5 |
+
yolo_yaml_path = 'yolo_data/data.yaml'
|
| 6 |
+
config_path = 'src/config.json'
|
| 7 |
+
|
| 8 |
+
def load_json_config(config_path):
    """Read a JSON configuration file and return it as a dict."""
    with open(config_path, "r") as f:
        return json.load(f)
|
| 13 |
+
|
| 14 |
+
config = load_json_config(config_path)

# ---- Training ----
# Start from pretrained YOLOv11-medium weights.
model = YOLO('yolo11m.pt')
yolo_cfg = config['yolov11']
results = model.train(
    data=yolo_yaml_path,
    epochs=yolo_cfg['epochs'],
    imgsz=yolo_cfg['image_size'],
    cache=yolo_cfg['cache'],
    patience=yolo_cfg['patience'],
    plots=yolo_cfg['plots']
)

# ---- Evaluation ----
# Reload the saved checkpoint and run validation on it.
model_path = 'checkpoints/yolov11m.pt'
model = YOLO(model_path)
metrics = model.val()
|
src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc
ADDED
|
Binary file (3.54 kB). View file
|
|
|
src/Text_Recognization/dataloader.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
from src.Text_Recognization.prepare_dataset import *
|
| 10 |
+
|
| 11 |
+
# data augmentation
|
| 12 |
+
# ImageNet channel statistics used by the pretrained backbone.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
# Target (height, width) fed to the recognition model.
_INPUT_SIZE = (100, 400)

# Augmentation pipelines: photometric + geometric noise for training,
# plain resize + normalize for validation/test.
data_transforms = {
    "train": transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(_INPUT_SIZE),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.GaussianBlur(3),
        transforms.RandomAffine(degrees=1, shear=1),
        transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]),
    "val": transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(_INPUT_SIZE),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]),
}
|
| 43 |
+
|
| 44 |
+
def load_json_config(config_path):
    """Load the project JSON config from ``config_path``."""
    with open(config_path, "r") as handle:
        config = json.load(handle)
    return config
|
| 49 |
+
|
| 50 |
+
# Dataloader
|
| 51 |
+
class STRDataset(Dataset):
    """Scene-text-recognition dataset: cropped word image -> encoded label.

    Each item is ``(image_tensor, padded_label, label_length)``; the label is
    encoded and padded via ``encode`` from prepare_dataset.
    """

    def __init__(self, image_paths, labels, char_to_idx, transforms=None):
        self.image_paths = image_paths
        self.labels = labels
        self.char_to_idx = char_to_idx
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # OpenCV loads BGR; convert to RGB before the torchvision pipeline.
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        if self.transforms:
            img = self.transforms(img)

        # NOTE(review): passing self.labels makes `encode` rescan every label
        # for the max length on each item -- consider precomputing the max.
        encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)
        return img, encoded, length
|
| 71 |
+
|
| 72 |
+
def get_dataloader():
    """Build train/val/test DataLoaders for the recognition model.

    Splits 10% off for validation, then 10% of the remainder for test
    (fixed random_state=42 so the splits are reproducible).
    """
    val_size = 0.1
    test_size = 0.1
    root_path = 'Dataset'
    config_path = 'src/config.json'

    # Image paths, transcriptions and the character vocabulary.
    image_paths, labels = get_imagepaths_and_labels(root_path)
    char_to_idx, idx_to_char = build_vocab(root_path)

    config = load_json_config(config_path)
    batch_size = config['CRNN']['batch_size']

    X_train, X_val, y_train, y_val = train_test_split(
        image_paths, labels, test_size=val_size, random_state=42, shuffle=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X_train, y_train, test_size=test_size, random_state=42, shuffle=True)

    def _make_loader(paths, labs, split):
        # 'val' transforms are reused for the test split on purpose.
        ds = STRDataset(paths, labs, char_to_idx, transforms=data_transforms[split])
        return DataLoader(ds, batch_size=batch_size, shuffle=True)

    train_loader = _make_loader(X_train, y_train, 'train')
    val_loader = _make_loader(X_val, y_val, 'val')
    test_loader = _make_loader(X_test, y_test, 'val')

    return train_loader, val_loader, test_loader
|
src/Text_Recognization/prepare_dataset.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import xml.etree.ElementTree as ET
|
| 5 |
+
import cv2
|
| 6 |
+
import json
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def extract_data_from_xml(root_path):
    """Parse ``words.xml`` under ``root_path`` into paths, sizes, boxes and texts.

    Returns:
        image_paths: full path per annotated image.
        image_sizes: ``[image_height, image_width]`` read from the ``x``/``y``
            resolution attributes (raw strings). NOTE(review): 'x' is treated
            as height here but as width in the localization script -- confirm
            against the XML schema.
        bboxes: per-image list of ``[x, y, width, height]`` floats.
        image_labels: per-image list of transcriptions (the <tag> text).
    """
    tree = ET.parse(os.path.join(root_path, 'words.xml'))
    root = tree.getroot()

    image_paths, image_sizes, bboxes, image_labels = [], [], [], []

    for image in root:
        image_paths.append(os.path.join(root_path, image[0].text))

        image_height = image[1].get('x')
        image_width = image[1].get('y')
        image_sizes.append([image_height, image_width])

        boxes_here, labels_here = [], []
        for rect in image[2]:
            coords = [float(rect.get(key)) for key in ('x', 'y', 'width', 'height')]
            boxes_here.append(coords)
            # Transcription of this rectangle lives in its <tag> child.
            labels_here.append(rect.find('tag').text)

        bboxes.append(boxes_here)
        image_labels.append(labels_here)

    return image_paths, image_sizes, bboxes, image_labels
|
| 51 |
+
|
| 52 |
+
def visualize_gt_bboxes(image_path, gt_locations, gt_labels):
    """Display an image with its boxes and transcriptions drawn in red."""
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

    for box, text in zip(gt_locations, gt_labels):
        x, y, w, h = (int(v) for v in box)
        img = cv2.rectangle(img, (x, y), (x + w, y + h), color=(255, 0, 0), thickness=2)
        # Put the transcription just above the box's top-left corner.
        img = cv2.putText(img, text, (x, y - 10), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                          fontScale=3, color=(255, 0, 0), thickness=2)

    plt.imshow(img)
    plt.axis('off')
    plt.show()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir):
    """Crop every annotated text box into its own image file.

    Writes valid crops to ``save_dir`` plus a ``labels.txt`` mapping
    ``<cropped path>\\t<text>`` per line; rejected crops are dumped into
    ``unvalid_images/`` for inspection.

    Args:
        image_paths: source image path per image.
        image_labels: per-image list of transcriptions.
        bboxes: per-image list of ``[x, y, width, height]`` boxes.
        save_dir: output directory for the cropped OCR dataset.
    """
    os.makedirs(save_dir, exist_ok=True)
    unvalid_dir = 'unvalid_images'
    os.makedirs(unvalid_dir, exist_ok=True)

    # Transcriptions containing these characters are outside the target vocab.
    special_chars = {'é', 'ñ', '£'}

    bboxes_idx = 0
    unvalid_bboxes = 0
    new_labels = []

    for image_path, bbox, label in zip(image_paths, bboxes, image_labels):
        image = cv2.imread(image_path)
        # Bug fix: check the read result BEFORE converting color spaces --
        # cv2.cvtColor raises on None, so the original None-guard (placed
        # after the conversion) was unreachable.
        if image is None:
            print(image_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for bb, lb in zip(bbox, label):
            x, y, width, height = (int(v) for v in bb)

            # Reject boxes with negative coordinates or sizes before cropping.
            if x < 0 or y < 0 or width < 0 or height < 0:
                continue

            # Skip transcriptions containing unsupported special characters.
            if any(ch in special_chars for ch in lb.lower()):
                continue

            cropped_text = image[y:y+height, x:x+width]

            # Reject crops that are nearly all-dark or all-bright.
            if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230:
                # Portability fix: os.path.join instead of a hard-coded '\\'.
                cv2.imwrite(os.path.join(unvalid_dir, f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
                unvalid_bboxes += 1
                continue

            # Reject crops that are too small to read.
            if width < 10 or height < 10:
                cv2.imwrite(os.path.join(unvalid_dir, f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
                unvalid_bboxes += 1
                continue

            new_image_path = os.path.join(save_dir, f'cropped_image{bboxes_idx}.jpg')
            cv2.imwrite(new_image_path, cropped_text)
            new_labels.append(new_image_path + '\t' + lb)
            bboxes_idx += 1

    # Write the path->transcription index for the cropped dataset.
    with open(os.path.join(save_dir, 'labels.txt'), "w") as f:
        for new_label in new_labels:
            f.write(f'{new_label}\n')
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def build_vocab(root_dir):
    """Build char<->index maps from the cropped-dataset labels file.

    Reads ``<root_dir>/ocr_dataset/labels.txt`` (``path\\tlabel`` per line),
    collects every character that occurs, appends 'z' and the blank char '@',
    and maps characters to 1-based indices (0 is reserved for padding).

    NOTE(review): if 'z' already occurs in the labels, the duplicate makes the
    later index win inside char_to_idx -- confirm the extra ``+ 'z'`` is
    intended.

    Returns:
        (char_to_idx, idx_to_char): forward and reverse lookup dicts.
    """
    labels = []
    with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f:
        for line in f:
            labels.append(line.strip().split("\t")[1])

    # Every distinct character seen in the transcriptions.
    chars = set()
    for label in labels:
        chars.update(label)

    blank_char = '@'
    vocab = "".join(sorted(chars)) + 'z' + blank_char

    char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}
    idx_to_char = {idx: char for char, idx in char_to_idx.items()}

    return char_to_idx, idx_to_char
|
| 163 |
+
|
| 164 |
+
def get_imagepaths_and_labels(root_path):
    """Read cropped-image paths and their transcriptions from labels.txt."""
    img_paths = []
    labels = []

    labels_file = os.path.join(root_path, 'ocr_dataset', 'labels.txt')
    with open(labels_file, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            img_paths.append(fields[0])
            labels.append(fields[1])

    return img_paths, labels
|
| 175 |
+
|
| 176 |
+
def encode(label, char_to_idx, labels):
    """Encode a text label to a fixed-length index tensor for CTC training.

    Characters are lowercased and mapped through ``char_to_idx``; the sequence
    is right-padded with 0 (the padding index) up to the length of the longest
    label in ``labels``.

    Returns:
        (padded_label, length): int32 tensors, ``length`` being the unpadded size.
    """
    max_length_label = np.max([len(lb) for lb in labels])

    indices = [char_to_idx[ch.lower()] for ch in label]
    encoded_label = torch.tensor(indices, dtype=torch.int32)
    label_len = len(indices)
    length = torch.tensor(label_len, dtype=torch.int32)

    # Pad on the right so every label in a batch has the same width.
    padded_label = F.pad(encoded_label, (0, max_length_label - label_len), value=0)
    return padded_label, length
|
| 195 |
+
|
| 196 |
+
def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'):
    """Greedy CTC decode: collapse repeats, drop blanks, stop at padding 0."""
    blank_idx = char_to_idx[blank_char]
    seq = encoded_label.detach().numpy()

    chars = []
    prev = None
    for idx in seq:
        if idx == 0:
            # 0 is the padding index; everything after it is padding too.
            break
        # Emit only when the symbol differs from its predecessor and isn't blank.
        if idx != prev and idx != blank_idx:
            chars.append(idx_to_char[idx])
        prev = idx

    return "".join(chars)
|
| 207 |
+
|
| 208 |
+
def main():
    """CLI entry point: crop the OCR dataset out of the annotated images."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default=os.getcwd(), help="Path to the root directory")
    args = parser.parse_args()

    root_path = os.path.join(args.path, 'Dataset')

    # Read annotations, crop every box into Dataset/ocr_dataset, then build
    # the character vocabulary from the resulting labels file.
    image_paths, image_sizes, bboxes, image_labels = extract_data_from_xml(root_path)
    split_bboxes_from_image(image_paths, image_labels, bboxes, 'Dataset/ocr_dataset')
    char_to_idx, idx_to_char = build_vocab(root_path)
|
src/Text_Recognization/text_recognization.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision
|
| 4 |
+
|
| 5 |
+
from torchvision.models import resnet101
|
| 6 |
+
|
| 7 |
+
class BackBone(nn.Module):
    """ResNet-101 feature extractor producing a (B, 2048, 1, W') feature map.

    Only the tail of the network is trainable; the rest stays frozen at the
    pretrained ImageNet weights.
    """

    def __init__(self, num_unfreeze_layers=3):
        super(BackBone, self).__init__()
        model = resnet101(weights='IMAGENET1K_V2', progress=True)
        # Keep everything up to (and including) layer4; drop avgpool/fc.
        feature_maps = list(model.children())[:8]

        # Collapse the spatial height to 1 while keeping the width, so the
        # map can be read as a left-to-right sequence:
        # (B, 2048, H, W) -> (B, 2048, 1, W).
        feature_maps.append(nn.AdaptiveAvgPool2d((1, None)))
        self.backbone = nn.Sequential(*feature_maps)

        # Bug fix: pretrained parameters default to requires_grad=True, so the
        # original "unfreeze last N" loop was a no-op and the entire network
        # trained. Freeze everything first, then unfreeze the tail.
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in list(self.backbone.parameters())[-(num_unfreeze_layers + 1):]:
            param.requires_grad = True

    def forward(self, image):
        return self.backbone(image)
|
| 22 |
+
|
| 23 |
+
class CRNN(nn.Module):
    """CNN backbone + bidirectional GRU emitting per-timestep log-probs.

    The forward output has shape (T, B, vocab_size) -- the layout expected by
    ``nn.CTCLoss``.
    """

    def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.2, num_unfreeze_layers=3):
        super(CRNN, self).__init__()
        self.backbone = BackBone(num_unfreeze_layers=num_unfreeze_layers)

        # Project the 2048-dim CNN features down to the GRU input size.
        self.mapSeq = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )

        self.gru = nn.GRU(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=n_layers,
            bidirectional=True,
            batch_first=True,
            # nn.GRU applies dropout only between stacked layers.
            dropout=dropout if n_layers > 1 else 0,
        )

        self.layer_norm = nn.LayerNorm(hidden_size * 2)

        # Classification head over the vocabulary; log-probs for CTC.
        self.out = nn.Sequential(
            nn.Linear(hidden_size * 2, vocab_size),
            nn.LogSoftmax(dim=2),
        )

    def forward(self, x):
        features = self.backbone(x)              # (B, 2048, 1, W)
        features = features.permute(0, 3, 1, 2)  # (B, W, 2048, 1)
        # Flatten the trailing singleton -> (B, W, 2048).
        features = features.view(features.size(0), features.size(1), -1)

        seq = self.mapSeq(features)
        seq, _ = self.gru(seq)
        seq = self.layer_norm(seq)
        seq = self.out(seq)                      # (B, W, vocab_size)

        # Time-major layout for CTC: (W, B, vocab_size).
        return seq.permute(1, 0, 2)
|
src/Text_Recognization/trainer.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torchvision
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
sys.path.append(os.getcwd())
|
| 11 |
+
|
| 12 |
+
from src.Text_Recognization.text_recognization import *
|
| 13 |
+
from src.Text_Recognization.prepare_dataset import *
|
| 14 |
+
from src.Text_Recognization.dataloader import *
|
| 15 |
+
|
| 16 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
+
|
| 18 |
+
def load_json_config(config_path):
    """Parse the JSON training configuration at ``config_path``."""
    with open(config_path, "r") as fp:
        return json.load(fp)
| 23 |
+
|
| 24 |
+
def evaluate(model, dataloader, criterion, device):
    """Compute the mean loss of ``model`` over ``dataloader`` without gradients."""
    model.eval()
    losses = []

    with torch.no_grad():
        for images, labels, labels_len in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # CTC needs a logit length per batch element; outputs is laid out
            # (T, B, C), so every element gets the full T timesteps.
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
            ).to(device)

            batch_loss = criterion(outputs, labels, logits_lens, labels_len)
            losses.append(batch_loss.item())

    return sum(losses) / len(losses)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def training_loop(model, train_loader, val_loader, learning_rate, epochs, optimizer, criterion, scheduler, device):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Note: ``learning_rate`` is unused here (the optimizer already carries it);
    the parameter is kept for interface compatibility.

    Returns:
        (train_losses, val_losses): per-epoch mean losses.
    """
    model.to(device)

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        epoch_losses = []

        for images, labels, labels_len in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            # Full sequence length T for every element of the batch;
            # outputs is (T, B, C).
            logits_lens = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long,
            ).to(device)

            loss = criterion(outputs, labels, logits_lens, labels_len)
            loss.backward()
            # Clip to tame exploding gradients in the recurrent layers.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            epoch_losses.append(loss.item())

        train_loss = sum(epoch_losses) / len(epoch_losses)
        train_losses.append(train_loss)

        val_loss = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f"epoch: {epoch+1}/{epochs}\ttrain_loss:{train_loss}\tval_loss:{val_loss}")
        scheduler.step()

    return train_losses, val_losses
|
| 89 |
+
|
| 90 |
+
def main():
    """Train the CRNN recognizer, save its weights, and plot loss curves.

    CLI args:
        --root_path: directory containing the `Dataset` folder.
        --checkpoints_path: where the model weights and loss plot are saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, default=os.getcwd(), help='Path to the root directory')
    parser.add_argument('--checkpoints_path', type=str, default=os.path.join(os.getcwd(), 'checkpoints'), help='Path to the checkpoint directory')

    args = parser.parse_args()
    config_path = 'src/config.json'
    dataset_path = os.path.join(args.root_path, 'Dataset')
    config = load_json_config(config_path)

    # dictionary char and idx
    char_to_idx, idx_to_char = build_vocab(dataset_path)

    # model
    model = CRNN(vocab_size=config['vocab_size'], hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])

    # dataloader
    train_loader, val_loader, test_loader = get_dataloader()

    # define hyper parameters
    criterion = nn.CTCLoss(
        blank=char_to_idx[config['blank_char']],
        zero_infinity=True,
        reduction='mean'
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['CRNN']['learning_rate'],
        weight_decay=config['CRNN']['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        step_size=config['CRNN']['scheduler_step_size'],
        gamma=0.1
    )

    # training loop
    train_losses, val_losses = training_loop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=config['CRNN']['learning_rate'],
        epochs=config['CRNN']['epochs'],
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device
    )

    # save model
    # BUGFIX: previously the directories were only created when
    # checkpoints_path did not exist at all, so an existing checkpoints dir
    # without a 'losses' subdirectory made plt.savefig below fail.
    os.makedirs(args.checkpoints_path, exist_ok=True)
    os.makedirs(os.path.join(args.checkpoints_path, 'losses'), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.checkpoints_path, 'crnn.pt'))

    # draw losses
    # BUGFIX: removed axis('off'), which hid the ticks, frame, and the very
    # xlabel/ylabel configured just above — defeating the loss plot.
    fig, axis = plt.subplots(1, 2, figsize=(8, 8))
    axis[0].plot(train_losses, label='train_loss')
    axis[0].set_xlabel('Epochs')
    axis[0].set_ylabel('Loss')
    axis[0].legend()

    axis[1].plot(val_losses, label='val_loss')
    axis[1].set_xlabel('Epochs')
    axis[1].set_ylabel('Loss')
    axis[1].legend()

    plt.savefig(os.path.join(args.checkpoints_path, 'losses', 'losses.png'))
    plt.close(fig)

if __name__ == '__main__':
    main()
|
src/config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"yolov11":
|
| 3 |
+
{
|
| 4 |
+
"epochs": 100,
|
| 5 |
+
"image_size": 640,
|
| 6 |
+
"cache": true,
|
| 7 |
+
"patience": 20,
|
| 8 |
+
"plots": true
|
| 9 |
+
},
|
| 10 |
+
"CRNN":
|
| 11 |
+
{
|
| 12 |
+
"batch_size": 64,
|
| 13 |
+
"epochs": 100,
|
| 14 |
+
"hidden_size": 256,
|
| 15 |
+
"n_layers": 3,
|
| 16 |
+
"dropout": 0.2,
|
| 17 |
+
"unfreeze_layers": 3,
|
| 18 |
+
"learning_rate": 5e-4,
|
| 19 |
+
"weight_decay": 1e-5,
|
| 20 |
+
"scheduler_step_size": 30
|
| 21 |
+
},
|
| 22 |
+
"blank_char": "@",
|
| 23 |
+
"vocab_size": 73
|
| 24 |
+
}
|
src/predict.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import cv2
|
| 5 |
+
import argparse
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
import ultralytics
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torchvision
|
| 12 |
+
from torchvision.models import resnet101
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
|
| 15 |
+
sys.path.append(os.getcwd())
|
| 16 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 17 |
+
|
| 18 |
+
from ultralytics import YOLO
|
| 19 |
+
from src.Text_Recognization.text_recognization import *
|
| 20 |
+
from src.Text_Recognization.prepare_dataset import *
|
| 21 |
+
|
| 22 |
+
# config
|
| 23 |
+
def load_json_config(config_path):
    """Read a JSON configuration file and return its contents as a dict."""
    with open(config_path, "r") as handle:
        return json.load(handle)
|
| 28 |
+
|
| 29 |
+
# Global runtime configuration (src/config.json) loaded once at import time.
config = load_json_config('src/config.json')

# char to idx
# Vocabulary mappings built from the dataset directory; used for CTC decoding.
char_to_idx, idx_to_char = build_vocab('Dataset')

# text detection model
text_det_model_path = 'checkpoints/yolov11m.pt'
yolo = YOLO(text_det_model_path)

# text recognition model
text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'

# rcnn model
# NOTE(review): vocab_size is hard-coded to 74 while config['vocab_size'] is 73;
# presumably the "extend_vocab" checkpoint adds one extra character — confirm
# this matches the vocabulary produced by build_vocab for this checkpoint.
rcnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
# Weights are loaded on CPU so the module imports on machines without a GPU.
rcnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
|
| 44 |
+
|
| 45 |
+
def text_detection(img_path, text_det_model):
    """Run the YOLO detector on one image and unpack its first result.

    Args:
        img_path: image path or array accepted by the ultralytics model.
        text_det_model: a loaded ultralytics YOLO model.

    Returns:
        (bboxes, classes, names, confs): xyxy boxes, class ids, the model's
        class-name mapping, and confidence scores, as plain Python lists/dict.
    """
    result = text_det_model(img_path, verbose=False)[0]

    boxes = result.boxes.xyxy.tolist()
    class_ids = result.boxes.cls.tolist()
    class_names = result.names
    scores = result.boxes.conf.tolist()

    return boxes, class_ids, class_names, scores
|
| 54 |
+
|
| 55 |
+
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Draw ground-truth xyxy boxes on an image and display it with matplotlib.

    Args:
        image_path: path to the image on disk (read via OpenCV).
        gt_location_yolo: iterable of (xmin, ymin, xmax, ymax) boxes.
    """
    img = cv2.imread(image_path)
    # OpenCV loads BGR; convert so matplotlib shows correct colors.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Convert to original format
    for box in gt_location_yolo:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)

    plt.imshow(img)
    plt.axis('off')
    plt.show()
|
| 69 |
+
|
| 70 |
+
def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device, char_to_idx=char_to_idx):
    """Recognize the text in a cropped word image with the CRNN model.

    Args:
        image: RGB crop containing a single text region.
        data_transforms: torchvision transform pipeline producing model input.
        text_reg_model: trained CRNN; output argmax'd over dim=2 (vocab).
        idx_to_char: index-to-character mapping for decoding.
        device: torch device to run inference on.
        char_to_idx: character-to-index mapping passed to decode().
            CONSISTENCY FIX: previously this was read from the module-level
            global even though idx_to_char was a parameter; it is now an
            explicit trailing keyword parameter with the same default, so
            existing callers are unaffected.

    Returns:
        (text, idx): the decoded string and the raw argmax index sequence.
    """
    transformed_image = data_transforms(image)
    transformed_image = transformed_image.unsqueeze(0).to(device)
    text_reg_model.to(device)
    text_reg_model.eval()

    with torch.no_grad():
        preds = text_reg_model(transformed_image)
        _, idx = torch.max(preds, dim=2)
        idx = idx.view(-1)
        text = decode(idx, idx_to_char, char_to_idx)

    return text, idx
|
| 83 |
+
|
| 84 |
+
def visualize_detection(image, detections):
    """Draw predicted boxes, confidences, and recognized text on the image.

    Args:
        image: RGB image array the detections refer to (drawn on in place).
        detections: iterable of (bbox, class, conf, text, encoded_text).

    Returns:
        The annotated image array.
    """
    plt.figure(figsize=(10, 8))

    for bbox, detected_classes, conf, text, _ in detections:
        x1, y1, x2, y2 = (int(coord) for coord in bbox)

        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        # Label each box with its confidence score and the recognized text.
        image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image
|
| 98 |
+
|
| 99 |
+
# Preprocessing for recognition crops: tensorize, resize to the CRNN input
# size (height 100, width 400), and normalize with ImageNet statistics.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 400)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
|
| 104 |
+
|
| 105 |
+
def prediction(image, text_det_model=yolo, text_reg_model=rcnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
    """Detect text regions in `image`, then recognize the text in each crop.

    Args:
        image: RGB image array.
        text_det_model / text_reg_model: detector and recognizer (module
            defaults: the loaded YOLO and CRNN models).
        idx_to_char / char_to_idx: vocabulary mappings.
        data_transforms: preprocessing pipeline for recognition crops.
        device: torch device for recognition.

    Returns:
        List of (bbox, class, conf, text, encoded_text) tuples.
    """
    # detection
    bboxes, classes, names, confs = text_detection(image, text_det_model)

    results = []
    for box, class_id, score in zip(bboxes, classes, confs):
        x1, y1, x2, y2 = (int(coord) for coord in box)
        # Crop the detected region out of the full image for recognition.
        crop = image[y1:y2, x1:x2]
        text, encoded_text = text_recognization(crop, data_transforms, text_reg_model, idx_to_char, device)
        results.append((box, class_id, score, text, encoded_text))
        print(box, class_id, score, text)

    return results
+
def main():
    """CLI entry point: run detection + recognition on one image.

    CLI args:
        --image_path: path to the input image (required).
        --save_path: optional directory to save the annotated image into.
    """
    parser = argparse.ArgumentParser()
    # ROBUSTNESS: --image_path is now required; previously omitting it passed
    # None into cv2.imread and crashed with an unrelated error.
    parser.add_argument('--image_path', type=str, required=True, help='Path to the image')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
    args = parser.parse_args()
    image_path = args.image_path
    image = cv2.imread(image_path)
    # BUGFIX: cv2.imread returns None (no exception) for missing/unreadable
    # files; fail early with a clear message instead of a cryptic cvtColor error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    detections = prediction(image)
    image = visualize_detection(image, detections)

    if args.save_path:
        # BUGFIX: cv2.imwrite silently fails when the directory does not exist.
        os.makedirs(args.save_path, exist_ok=True)
        print(f"Saving the image to {os.path.join(args.save_path, 'predicted_image.jpg')}")
        cv2.imwrite(os.path.join(args.save_path, 'predicted_image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

if __name__ == '__main__':
    main()
|