TungDuong committed
Commit eb20997 · verified · 1 Parent(s): f4dccd7

source code

src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc ADDED
Binary file (7.86 kB).
 
src/Text_Localization/prepare_dataset.py ADDED
@@ -0,0 +1,177 @@
+ import numpy as np
+ import os
+ import matplotlib.pyplot as plt
+ import cv2
+ import xml.etree.ElementTree as ET
+ import shutil
+ import yaml
+
+ from sklearn.model_selection import train_test_split
+
+ location_path = r'Dataset/locations.xml'
+
+
+ def get_gt_bboxes(location_path):
+     """Collect every ground-truth text bbox in the dataset.
+
+     Args:
+         location_path (str): path to the locations XML file
+     Returns:
+         gt_imagepaths (list): image paths
+         gt_imagesizes (list): [height, width] per image
+         gt_locations (list): bboxes of each image
+     """
+     # parse the XML here so the function actually uses its argument
+     # instead of a module-level tree
+     tree = ET.parse(location_path)
+     root = tree.getroot()
+
+     gt_imagepaths = []
+     gt_imagesizes = []
+     gt_locations = []
+
+     for image in root:
+         # get the path to the image
+         image_name = image[0].text
+         image_path = os.path.join('Dataset', image_name)
+         gt_imagepaths.append(image_path)
+
+         # get the image size (the XML stores the width in 'x' and the height in 'y')
+         w = image[1].get('x')
+         h = image[1].get('y')
+         gt_imagesizes.append([h, w])
+
+         # bboxes in the image
+         bbs = []
+         for bbox in image[2]:
+             x = np.int64(float(bbox.get('x')))
+             y = np.int64(float(bbox.get('y')))
+             width = np.int64(float(bbox.get('width')))
+             height = np.int64(float(bbox.get('height')))
+             bbs.append([x, y, width, height])
+
+         gt_locations.append(bbs)
+
+     return gt_imagepaths, gt_imagesizes, gt_locations
+
+ gt_imagepaths, gt_imagesizes, gt_locations = get_gt_bboxes(location_path)
+
+ def visualize_gt_bboxes(image_path, gt_locations):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     for gt_location in gt_locations:
+         x, y, width, height = gt_location
+         x, y, width, height = int(x), int(y), int(width), int(height)
+
+         image = cv2.rectangle(image, (x, y), (x + width, y + height), color=(255, 0, 0), thickness=2)
+
+     plt.imshow(image)
+     plt.axis('off')
+     plt.show()
+
+ def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     image_height, image_width = image.shape[:2]
+
+     # convert normalized YOLO boxes back to pixel corners
+     for data in gt_location_yolo:
+         xc, yc, w, h = data[1:]
+         xmin = int((xc - w / 2) * image_width)
+         ymin = int((yc - h / 2) * image_height)
+         xmax = int((xc + w / 2) * image_width)
+         ymax = int((yc + h / 2) * image_height)
+
+         image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)
+
+     plt.imshow(image)
+     plt.axis('off')
+     plt.show()
+
+
+ def convert_yolo_format(gt_locations, gt_imagesizes):
+     gt_locations_yolo = []
+
+     for image, image_size in zip(gt_locations, gt_imagesizes):
+         gt_location_yolo = []
+         for gt_location in image:
+             x, y, w, h = gt_location
+             image_height, image_width = image_size
+
+             xc = (x + w / 2) / float(image_width)
+             yc = (y + h / 2) / float(image_height)
+             width = w / float(image_width)
+             height = h / float(image_height)
+
+             # class 0 means the box contains text
+             class_id = 0
+             gt_location_yolo.append([class_id, xc, yc, width, height])
+
+         gt_locations_yolo.append(gt_location_yolo)
+
+     return gt_locations_yolo
+
+ gt_locations_yolo = convert_yolo_format(gt_locations, gt_imagesizes)
+
+ def save_data_into_yolo_folder(data, save_dir):
+     # create the target folder structure if it does not exist
+     os.makedirs(save_dir, exist_ok=True)
+     os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)
+     os.makedirs(os.path.join(save_dir, 'labels'), exist_ok=True)
+
+     # write the data into the YOLO folder
+     for dt in data:
+         # copy the image
+         image_path = dt[0]
+         shutil.copy(image_path, os.path.join(save_dir, 'images'))
+
+         # write the labels
+         image_name = os.path.basename(image_path)
+         image_name = os.path.splitext(image_name)[0]
+
+         with open(os.path.join(save_dir, 'labels', f'{image_name}.txt'), "w") as f:
+             for label in dt[1]:
+                 label_str = " ".join(map(str, label))
+                 f.write(f'{label_str}\n')
+
+
+ seed = 42
+ val_size = 0.2
+ test_size = 0.125
+ dataset = [[gt_imagepath, gt_location_yolo] for gt_imagepath, gt_location_yolo in zip(gt_imagepaths, gt_locations_yolo)]
+ train_data, val_data = train_test_split(dataset, test_size=val_size, random_state=seed, shuffle=True)
+ train_data, test_data = train_test_split(train_data, test_size=test_size, random_state=seed, shuffle=True)
+
+ save_yolo_data_dir = 'yolo_data'
+ os.makedirs(save_yolo_data_dir, exist_ok=True)
+ save_data_into_yolo_folder(data=train_data, save_dir=os.path.join(save_yolo_data_dir, 'train'))
+ save_data_into_yolo_folder(data=val_data, save_dir=os.path.join(save_yolo_data_dir, 'val'))
+ save_data_into_yolo_folder(data=test_data, save_dir=os.path.join(save_yolo_data_dir, 'test'))
+
+ class_label = ['text']
+ # create the data.yaml file
+ data_yaml = {
+     'path': '../yolo_data',
+     'train': 'train/images',
+     'test': 'test/images',
+     'val': 'val/images',
+     'nc': 1,
+     'names': class_label
+ }
+
+ yolo_yaml_path = os.path.join(save_yolo_data_dir, 'data.yaml')
+ with open(yolo_yaml_path, "w") as f:
+     yaml.dump(data_yaml, f, default_flow_style=False)
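
For reference, each line this script writes to a labels/*.txt file encodes one box as "class x_center y_center width height", all normalized to [0, 1]. A minimal sketch of the conversion for a single box (the image size and pixel coordinates below are made-up illustrative values, not from the dataset):

# hypothetical 200x50 px box at (100, 150) in a 640x480 image
x, y, w, h = 100, 150, 200, 50
img_w, img_h = 640, 480
xc, yc = (x + w / 2) / img_w, (y + h / 2) / img_h
print(f"0 {xc:.4f} {yc:.4f} {w / img_w:.4f} {h / img_h:.4f}")
# -> 0 0.3125 0.3646 0.3125 0.1042
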
src/Text_Localization/text_localization.py ADDED
@@ -0,0 +1,33 @@
+ import json
+
+ from ultralytics import YOLO
+
+ yolo_yaml_path = 'yolo_data/data.yaml'
+ config_path = 'src/config.json'
+
+ def load_json_config(config_path):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     return config
+
+ config = load_json_config(config_path)
+
+ # load the pretrained model
+ model = YOLO('yolo11m.pt')
+
+ # train the model
+ results = model.train(
+     data=yolo_yaml_path,
+     epochs=config['yolov11']['epochs'],
+     imgsz=config['yolov11']['image_size'],
+     cache=config['yolov11']['cache'],
+     patience=config['yolov11']['patience'],
+     plots=config['yolov11']['plots']
+ )
+
+ # evaluate the saved checkpoint
+ model_path = 'checkpoints/yolov11m.pt'
+ model = YOLO(model_path)
+
+ metrics = model.val()
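
Note that model.train() writes its own checkpoints as it runs; by default Ultralytics saves them under runs/detect/train/weights/, so loading checkpoints/yolov11m.pt above assumes the trained weights were copied there beforehand. A sketch that evaluates the run's own best checkpoint instead (the run directory name is the Ultralytics default and may differ on repeated runs):

from ultralytics import YOLO

best = YOLO('runs/detect/train/weights/best.pt')
metrics = best.val(data='yolo_data/data.yaml')
print(metrics.box.map50)  # mAP at IoU 0.5 on the validation split
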
src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (4.38 kB).
 
src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc ADDED
Binary file (11.1 kB).
 
src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc ADDED
Binary file (3.54 kB).
 
src/Text_Recognization/dataloader.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torchvision
+ import json
+ import cv2
+
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from sklearn.model_selection import train_test_split
+ from src.Text_Recognization.prepare_dataset import *
+
+ # data augmentation
+ data_transforms = {
+     "train": transforms.Compose(
+         [
+             transforms.ToTensor(),
+             transforms.Resize((100, 400)),
+             transforms.ColorJitter(
+                 brightness=0.5,
+                 contrast=0.5,
+                 saturation=0.5
+             ),
+             transforms.GaussianBlur(3),
+             transforms.RandomAffine(
+                 degrees=1,
+                 shear=1
+             ),
+             transforms.RandomPerspective(
+                 distortion_scale=0.3,
+                 p=0.5
+             ),
+             transforms.RandomRotation(degrees=15),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ]
+     ),
+     "val": transforms.Compose(
+         [
+             transforms.ToTensor(),
+             transforms.Resize((100, 400)),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ]
+     )
+ }
+
+ def load_json_config(config_path):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     return config
+
+ # dataset
+ class STRDataset(Dataset):
+     def __init__(self, image_paths, labels, char_to_idx, transforms=None):
+         self.image_paths = image_paths
+         self.labels = labels
+         self.char_to_idx = char_to_idx
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         image = cv2.imread(self.image_paths[idx])
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         if self.transforms:
+             image = self.transforms(image)
+
+         label_encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)
+
+         return image, label_encoded, length
+
+ def get_dataloader():
+     val_size = 0.1
+     test_size = 0.1
+     root_path = 'Dataset'
+     config_path = 'src/config.json'
+
+     # get the image paths and labels
+     image_paths, labels = get_imagepaths_and_labels(root_path)
+     char_to_idx, idx_to_char = build_vocab(root_path)
+
+     config = load_json_config(config_path)
+
+     X_train, X_val, y_train, y_val = train_test_split(image_paths, labels, test_size=val_size, random_state=42, shuffle=True)
+     X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=test_size, random_state=42, shuffle=True)
+
+     train_dataset = STRDataset(X_train, y_train, char_to_idx, transforms=data_transforms['train'])
+     train_loader = DataLoader(train_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)
+
+     # no need to shuffle the evaluation splits
+     val_dataset = STRDataset(X_val, y_val, char_to_idx, transforms=data_transforms['val'])
+     val_loader = DataLoader(val_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)
+
+     test_dataset = STRDataset(X_test, y_test, char_to_idx, transforms=data_transforms['val'])
+     test_loader = DataLoader(test_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)
+
+     return train_loader, val_loader, test_loader
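
A quick way to sanity-check the loaders is to pull a single batch and inspect its shapes, as in this sketch (run from the repo root, after the cropping script has produced Dataset/ocr_dataset/labels.txt):

from src.Text_Recognization.dataloader import get_dataloader

train_loader, val_loader, test_loader = get_dataloader()
images, labels, lengths = next(iter(train_loader))
# expect images of shape (batch_size, 3, 100, 400) after Resize,
# and labels padded to the longest label in the dataset
print(images.shape, labels.shape, lengths.shape)
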
src/Text_Recognization/prepare_dataset.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import random
+ import numpy as np
+ import xml.etree.ElementTree as ET
+ import cv2
+ import json
+ import matplotlib.pyplot as plt
+ import argparse
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def extract_data_from_xml(root_path):
+     words_path = os.path.join(root_path, 'words.xml')
+     tree = ET.parse(words_path)
+     root = tree.getroot()
+
+     image_paths = []
+     image_sizes = []
+     image_labels = []
+     bboxes = []
+
+     for image in root:
+         imagename = image[0].text
+         image_path = os.path.join(root_path, imagename)
+         image_paths.append(image_path)
+
+         image_height = image[1].get('x')
+         image_width = image[1].get('y')
+         image_sizes.append([image_height, image_width])
+
+         bboxes_in_image = []
+         labels_in_bboxes = []
+         for bbox in image[2]:
+             x = float(bbox.get('x'))
+             y = float(bbox.get('y'))
+             width = float(bbox.get('width'))
+             height = float(bbox.get('height'))
+             bboxes_in_image.append([x, y, width, height])
+
+             # get the text in this bbox
+             label = bbox.find('tag').text
+             labels_in_bboxes.append(label)
+
+         bboxes.append(bboxes_in_image)
+         image_labels.append(labels_in_bboxes)
+
+     return image_paths, image_sizes, bboxes, image_labels
+
+ def visualize_gt_bboxes(image_path, gt_locations, gt_labels):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     for gt_location, gt_label in zip(gt_locations, gt_labels):
+         x, y, width, height = gt_location
+         x, y, width, height = int(x), int(y), int(width), int(height)
+
+         image = cv2.rectangle(image, (x, y), (x + width, y + height), color=(255, 0, 0), thickness=2)
+         image = cv2.putText(image, gt_label, (x, y - 10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=3, color=(255, 0, 0), thickness=2)
+
+     plt.imshow(image)
+     plt.axis('off')
+     plt.show()
+
+
+ def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir):
+     """Create a new dataset containing the cropped bboxes and their labels.
+
+     Args:
+         image_paths: paths to the source images
+         image_labels: text labels, one list per image
+         bboxes: bounding boxes, one list per image
+         save_dir: directory for the cropped images and labels.txt
+     Returns:
+         None
+     """
+     os.makedirs(save_dir, exist_ok=True)
+     os.makedirs('unvalid_images', exist_ok=True)
+
+     bboxes_idx = 0
+     unvalid_bboxes = 0
+     new_labels = []  # list to store the labels
+     for image_path, bbox, label in zip(image_paths, bboxes, image_labels):
+         image = cv2.imread(image_path)
+
+         # check for a failed read before converting the color space
+         if image is None:
+             print(image_path)
+             continue
+
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         for bb, lb in zip(bbox, label):
+             x, y, width, height = bb
+             x, y, width, height = int(x), int(y), int(width), int(height)
+
+             # filter out invalid coordinates before cropping
+             if x < 0 or y < 0 or width < 0 or height < 0:
+                 continue
+
+             cropped_text = image[y:y + height, x:x + width]
+             if cropped_text.size == 0:
+                 continue
+
+             # filter out text containing special characters
+             if any(ch in lb.lower() for ch in ('é', 'ñ', '£')):
+                 continue
+
+             # filter out crops that are too dark or too light
+             if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230:
+                 cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
+                 unvalid_bboxes += 1
+                 continue
+
+             # filter out crops that are too small
+             if width < 10 or height < 10:
+                 cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
+                 unvalid_bboxes += 1
+                 continue
+
+             new_image_path = os.path.join(save_dir, f'cropped_image{bboxes_idx}.jpg')
+             cv2.imwrite(new_image_path, cropped_text)
+             new_labels.append(new_image_path + '\t' + lb)
+             bboxes_idx += 1
+
+     # write the labels into a text file
+     with open(os.path.join(save_dir, 'labels.txt'), "w") as f:
+         for new_label in new_labels:
+             f.write(f'{new_label}\n')
+
+
+ def build_vocab(root_dir):
+     img_paths = []
+     labels = []
+
+     # read the labels from the text file
+     with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f:
+         for line in f:
+             img_path, label = line.strip().split("\t")
+             img_paths.append(img_path)
+             labels.append(label)
+
+     # build the vocab
+     vocab = set()
+     for label in labels:
+         for char in label:
+             vocab.add(char)
+
+     # sorted characters, plus 'z' (in case it never appears in the labels)
+     # and the CTC "blank" character
+     vocab = "".join(sorted(vocab))
+     blank_char = '@'
+     vocab = vocab + 'z' + blank_char
+
+     # dictionaries mapping char to idx and idx to char
+     # (index 0 is reserved for padding)
+     char_to_idx = {
+         char: idx + 1 for idx, char in enumerate(vocab)
+     }
+     idx_to_char = {
+         idx: char for char, idx in char_to_idx.items()
+     }
+
+     return char_to_idx, idx_to_char
+
+ def get_imagepaths_and_labels(root_path):
+     img_paths = []
+     labels = []
+
+     # read the labels from the text file
+     with open(os.path.join(root_path, 'ocr_dataset', 'labels.txt'), "r") as f:
+         for line in f:
+             img_path, label = line.strip().split("\t")
+             img_paths.append(img_path)
+             labels.append(label)
+
+     return img_paths, labels
+
+ def encode(label, char_to_idx, labels):
+     # pad every label to the length of the longest label in the dataset
+     max_length_label = np.max([len(lb) for lb in labels])
+
+     # encode the label
+     encoded_label = torch.tensor(
+         [char_to_idx[char.lower()] for char in label],
+         dtype=torch.int32
+     )
+     label_len = len(encoded_label)
+     length = torch.tensor(
+         label_len,
+         dtype=torch.int32
+     )
+     padded_label = F.pad(
+         encoded_label,
+         (0, max_length_label - label_len),
+         value=0
+     )
+     return padded_label, length
+
+ def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'):
+     # CTC-style decoding: stop at padding (0), collapse repeats, drop blanks
+     label = []
+     encoded_label = encoded_label.detach().numpy()
+     for i in range(len(encoded_label)):
+         if encoded_label[i] == 0:
+             break
+         elif (i == 0 or encoded_label[i] != encoded_label[i - 1]) and encoded_label[i] != char_to_idx[blank_char]:
+             label.append(idx_to_char[encoded_label[i]])
+
+     return "".join(label)
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--path", type=str, default=os.getcwd(), help="Path to the root directory")
+     args = parser.parse_args()
+
+     root_path = os.path.join(args.path, 'Dataset')
+
+     image_paths, image_sizes, bboxes, image_labels = extract_data_from_xml(root_path)
+     save_dir = 'Dataset/ocr_dataset'
+     split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir)
+     char_to_idx, idx_to_char = build_vocab(root_path)
+
+ if __name__ == '__main__':
+     main()
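
The encode/decode pair can be exercised in isolation. Note that decode collapses repeated characters, which is what the CTC blank exists to disambiguate, so a raw round trip of a padded label loses doubled letters. A small sketch with toy labels (the vocab is built inline the same way build_vocab does it):

from src.Text_Recognization.prepare_dataset import encode, decode

labels = ['hello', 'world']  # toy labels, not from the dataset
vocab = "".join(sorted(set("".join(labels)))) + '@'  # '@' is the blank
char_to_idx = {c: i + 1 for i, c in enumerate(vocab)}  # index 0 is padding
idx_to_char = {i: c for c, i in char_to_idx.items()}

padded, length = encode('hello', char_to_idx, labels)
print(padded, length)                            # five character ids, length 5
print(decode(padded, idx_to_char, char_to_idx))  # 'helo': the double l collapses
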
src/Text_Recognization/text_recognization.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+
+ from torchvision.models import resnet101
+
+ class BackBone(nn.Module):
+     def __init__(self, num_unfreeze_layers=3):
+         super(BackBone, self).__init__()
+         model = resnet101(weights='IMAGENET1K_V2', progress=True)
+         feature_maps = list(model.children())[:8]
+
+         # adaptive average pooling collapses the height:
+         # (B, 2048, H', W') -> (B, 2048, 1, W')
+         feature_maps.append(nn.AdaptiveAvgPool2d((1, None)))
+         self.backbone = nn.Sequential(*feature_maps)
+
+         # freeze the backbone, then unfreeze only the last parameters
+         # so that num_unfreeze_layers actually takes effect
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+         for param in list(self.backbone.parameters())[-(num_unfreeze_layers + 1):]:
+             param.requires_grad = True
+
+     def forward(self, image):
+         return self.backbone(image)
+
+ class CRNN(nn.Module):
+     def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.2, num_unfreeze_layers=3):
+         super(CRNN, self).__init__()
+         self.backbone = BackBone(num_unfreeze_layers=num_unfreeze_layers)
+
+         self.mapSeq = nn.Sequential(
+             nn.Linear(2048, 512),
+             nn.ReLU(),
+             nn.Dropout(p=dropout)
+         )
+
+         self.gru = nn.GRU(
+             input_size=512,
+             hidden_size=hidden_size,
+             num_layers=n_layers,
+             bidirectional=True,
+             batch_first=True,
+             dropout=dropout if n_layers > 1 else 0
+         )
+
+         self.layer_norm = nn.LayerNorm(hidden_size * 2)
+
+         # dense layers
+         self.out = nn.Sequential(
+             nn.Linear(hidden_size * 2, vocab_size),
+             nn.LogSoftmax(dim=2)
+         )
+
+     def forward(self, x):
+         x = self.backbone(x)
+         # (B, 2048, 1, W') -> (B, W', 2048, 1)
+         x = x.permute(0, 3, 1, 2)
+         # flatten -> (B, W', 2048)
+         x = x.view(x.size(0), x.size(1), -1)
+         x = self.mapSeq(x)
+         x, _ = self.gru(x)
+         x = self.layer_norm(x)
+         x = self.out(x)
+         # (B, W', vocab_size) -> (W', B, vocab_size), the layout CTC loss expects
+         x = x.permute(1, 0, 2)
+
+         return x
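
A dummy forward pass shows the sequence layout the CTC loss consumes; with the 100x400 inputs from the dataloader and ResNet's overall stride of 32, the pooled feature map should be 13 steps wide. A sketch (the vocab size is the value from src/config.json; the pretrained backbone downloads on first run):

import torch
from src.Text_Recognization.text_recognization import CRNN

model = CRNN(vocab_size=73, hidden_size=256, n_layers=3)
dummy = torch.randn(2, 3, 100, 400)  # (batch, channels, height, width)
out = model(dummy)
print(out.shape)  # expected: torch.Size([13, 2, 73]) as (time steps, batch, classes)
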
src/Text_Recognization/trainer.py ADDED
@@ -0,0 +1,162 @@
+ import json
+ import sys
+ import os
+ import argparse
+ import torch
+ import torch.nn as nn
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+
+ sys.path.append(os.getcwd())
+
+ from src.Text_Recognization.text_recognization import *
+ from src.Text_Recognization.prepare_dataset import *
+ from src.Text_Recognization.dataloader import *
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def load_json_config(config_path):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     return config
+
+ def evaluate(model, dataloader, criterion, device):
+     model.eval()
+     losses = []
+
+     with torch.no_grad():
+         for images, labels, labels_len in dataloader:
+             images = images.to(device)
+             labels = labels.to(device)
+
+             outputs = model(images)
+             # every sequence in the batch has the full output length
+             logits_lens = torch.full(
+                 size=(outputs.size(1),),
+                 fill_value=outputs.size(0),
+                 dtype=torch.long
+             ).to(device)
+
+             loss = criterion(outputs, labels, logits_lens, labels_len)
+             losses.append(loss.item())
+
+     eval_loss = sum(losses) / len(losses)
+     return eval_loss
+
+
+ def training_loop(model, train_loader, val_loader, epochs, optimizer, criterion, scheduler, device):
+     model.to(device)
+
+     train_losses = []
+     val_losses = []
+
+     for epoch in range(epochs):
+         model.train()
+
+         batch_losses = []
+         for images, labels, labels_len in tqdm(train_loader):
+             images = images.to(device)
+             labels = labels.to(device)
+
+             optimizer.zero_grad()
+             outputs = model(images)
+
+             logits_lens = torch.full(
+                 size=(outputs.size(1),),
+                 fill_value=outputs.size(0),
+                 dtype=torch.long
+             ).to(device)
+
+             loss = criterion(outputs, labels, logits_lens, labels_len)
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+             optimizer.step()
+
+             batch_losses.append(loss.item())
+
+         train_loss = sum(batch_losses) / len(batch_losses)
+         train_losses.append(train_loss)
+
+         val_loss = evaluate(model, val_loader, criterion, device)
+         val_losses.append(val_loss)
+
+         print(f"epoch: {epoch + 1}/{epochs}\ttrain_loss: {train_loss}\tval_loss: {val_loss}")
+
+         scheduler.step()
+
+     return train_losses, val_losses
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--root_path', type=str, default=os.getcwd(), help='Path to the root directory')
+     parser.add_argument('--checkpoints_path', type=str, default=os.path.join(os.getcwd(), 'checkpoints'), help='Path to the checkpoint directory')
+
+     args = parser.parse_args()
+     config_path = 'src/config.json'
+     dataset_path = os.path.join(args.root_path, 'Dataset')
+     config = load_json_config(config_path)
+
+     # char/idx dictionaries
+     char_to_idx, idx_to_char = build_vocab(dataset_path)
+
+     # model
+     model = CRNN(
+         vocab_size=config['vocab_size'],
+         hidden_size=config['CRNN']['hidden_size'],
+         n_layers=config['CRNN']['n_layers'],
+         dropout=config['CRNN']['dropout'],
+         num_unfreeze_layers=config['CRNN']['unfreeze_layers']
+     )
+
+     # dataloaders
+     train_loader, val_loader, test_loader = get_dataloader()
+
+     # define the hyperparameters
+     criterion = nn.CTCLoss(
+         blank=char_to_idx[config['blank_char']],
+         zero_infinity=True,
+         reduction='mean'
+     )
+     optimizer = torch.optim.Adam(
+         model.parameters(),
+         lr=config['CRNN']['learning_rate'],
+         weight_decay=config['CRNN']['weight_decay']
+     )
+     scheduler = torch.optim.lr_scheduler.StepLR(
+         optimizer=optimizer,
+         step_size=config['CRNN']['scheduler_step_size'],
+         gamma=0.1
+     )
+
+     # training loop
+     train_losses, val_losses = training_loop(
+         model=model,
+         train_loader=train_loader,
+         val_loader=val_loader,
+         epochs=config['CRNN']['epochs'],
+         optimizer=optimizer,
+         criterion=criterion,
+         scheduler=scheduler,
+         device=device
+     )
+
+     # save the model
+     os.makedirs(args.checkpoints_path, exist_ok=True)
+     os.makedirs(os.path.join(args.checkpoints_path, 'losses'), exist_ok=True)
+     torch.save(model.state_dict(), os.path.join(args.checkpoints_path, 'crnn.pt'))
+
+     # plot the losses
+     fig, axis = plt.subplots(1, 2, figsize=(8, 8))
+     axis[0].plot(train_losses, label='train_loss')
+     axis[0].set_xlabel('Epochs')
+     axis[0].set_ylabel('Loss')
+     axis[0].legend()
+
+     axis[1].plot(val_losses, label='val_loss')
+     axis[1].set_xlabel('Epochs')
+     axis[1].set_ylabel('Loss')
+     axis[1].legend()
+
+     plt.savefig(os.path.join(args.checkpoints_path, 'losses', 'losses.png'))
+
+ if __name__ == '__main__':
+     main()
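
For reference, nn.CTCLoss consumes log-probabilities of shape (T, N, C) plus per-sample input and target lengths, which is what CRNN.forward's final permute and the padded labels from encode provide. A self-contained sketch of that contract (toy sizes; the trainer takes the blank index from the vocab rather than 0):

import torch
import torch.nn as nn

T, N, C, S = 13, 2, 73, 5  # time steps, batch, classes, max target length
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(1, C, (N, S), dtype=torch.long)  # 0 is the blank here
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([5, 3], dtype=torch.long)

criterion = nn.CTCLoss(blank=0, zero_infinity=True, reduction='mean')
print(criterion(log_probs, targets, input_lengths, target_lengths))
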
src/config.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "yolov11": {
+         "epochs": 100,
+         "image_size": 640,
+         "cache": true,
+         "patience": 20,
+         "plots": true
+     },
+     "CRNN": {
+         "batch_size": 64,
+         "epochs": 100,
+         "hidden_size": 256,
+         "n_layers": 3,
+         "dropout": 0.2,
+         "unfreeze_layers": 3,
+         "learning_rate": 5e-4,
+         "weight_decay": 1e-5,
+         "scheduler_step_size": 30
+     },
+     "blank_char": "@",
+     "vocab_size": 73
+ }
src/predict.py ADDED
@@ -0,0 +1,137 @@
+ import os
+ import sys
+ import json
+ import cv2
+ import argparse
+ import matplotlib.pyplot as plt
+
+ import torch
+ from torchvision import transforms
+
+ sys.path.append(os.getcwd())
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ from ultralytics import YOLO
+ from src.Text_Recognization.text_recognization import *
+ from src.Text_Recognization.prepare_dataset import *
+
+ # config
+ def load_json_config(config_path):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     return config
+
+ config = load_json_config('src/config.json')
+
+ # char/idx dictionaries
+ char_to_idx, idx_to_char = build_vocab('Dataset')
+
+ # text detection model
+ text_det_model_path = 'checkpoints/yolov11m.pt'
+ yolo = YOLO(text_det_model_path)
+
+ # text recognition model (the extended-vocab checkpoint was trained with 74 classes)
+ text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'
+ crnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
+ crnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
+
+ def text_detection(img_path, text_det_model):
+     text_det_results = text_det_model(img_path, verbose=False)[0]
+
+     bboxes = text_det_results.boxes.xyxy.tolist()
+     classes = text_det_results.boxes.cls.tolist()
+     names = text_det_results.names
+     confs = text_det_results.boxes.conf.tolist()
+
+     return bboxes, classes, names, confs
+
+ def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # draw the corner-format (xyxy) boxes
+     for data in gt_location_yolo:
+         xmin, ymin, xmax, ymax = data
+         xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
+
+         image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)
+
+     plt.imshow(image)
+     plt.axis('off')
+     plt.show()
+
+ def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device):
+     transformed_image = data_transforms(image)
+     transformed_image = transformed_image.unsqueeze(0).to(device)
+     text_reg_model.to(device)
+     text_reg_model.eval()
+
+     with torch.no_grad():
+         preds = text_reg_model(transformed_image)
+         _, idx = torch.max(preds, dim=2)
+         # move to CPU so decode can call .numpy()
+         idx = idx.view(-1).cpu()
+         text = decode(idx, idx_to_char, char_to_idx)
+
+     return text, idx
+
+ def visualize_detection(image, detections):
+     plt.figure(figsize=(10, 8))
+
+     for bbox, detected_class, conf, text, _ in detections:
+         x1, y1, x2, y2 = bbox
+         x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+
+         image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
+         image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
+
+     plt.imshow(image)
+     plt.axis('off')
+     plt.show()
+     return image
+
+ data_transforms = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((100, 400)),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ def prediction(image, text_det_model=yolo, text_reg_model=crnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
+     # detection
+     bboxes, classes, names, confs = text_detection(image, text_det_model)
+
+     # recognition on each detected box
+     predictions = []
+     for bbox, cls, conf in zip(bboxes, classes, confs):
+         x1, y1, x2, y2 = bbox
+         x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+         detected_text = image[y1:y2, x1:x2]
+         text, encoded_text = text_recognization(detected_text, data_transforms, text_reg_model, idx_to_char, device)
+         predictions.append((bbox, cls, conf, text, encoded_text))
+         print(bbox, cls, conf, text)
+
+     return predictions
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image_path', type=str, help='Path to the image')
+     parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
+     args = parser.parse_args()
+
+     image = cv2.imread(args.image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     detections = prediction(image)
+     image = visualize_detection(image, detections)
+
+     if args.save_path:
+         print(f"Saving the image to {os.path.join(args.save_path, 'predicted_image.jpg')}")
+         cv2.imwrite(os.path.join(args.save_path, 'predicted_image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+
+ if __name__ == '__main__':
+     main()
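
Assuming both checkpoints exist under checkpoints/, the end-to-end pipeline runs from the repo root; the image and output paths below are placeholders:

python src/predict.py --image_path Dataset/example.jpg --save_path outputs

Note that --save_path should point at an existing directory, since the script joins it with the output file name and cv2.imwrite fails quietly when the directory is missing.
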