TungDuong committed on
Commit f4dccd7 · verified · 1 Parent(s): 5215b56

Delete src

src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc DELETED
Binary file (7.86 kB)
 
src/Text_Localization/prepare_dataset.py DELETED
@@ -1,177 +0,0 @@
- import numpy as np
- import os
- import matplotlib.pyplot as plt
- import cv2
- import xml.etree.ElementTree as ET
- import shutil
- import yaml
-
- from sklearn.model_selection import train_test_split
-
- location_path = r'Dataset/locations.xml'
-
-
- def get_gt_bboxes(location_path):
-     """Get all the ground-truth text bboxes in the dataset.
-
-     Args:
-         location_path: path to the locations.xml annotation file
-     Returns:
-         gt_imagepaths (list): path of each image
-         gt_imagesizes (list): [height, width] of each image
-         gt_locations (list): [x, y, width, height] bboxes of each image
-     """
-     # Parse the file passed in (the original read a module-level `root`
-     # and silently ignored this argument)
-     tree = ET.parse(location_path)
-     root = tree.getroot()
-
-     gt_imagepaths = []
-     gt_imagesizes = []
-     gt_locations = []
-
-     for image in root:
-         # get path to image
-         image_name = image[0].text
-         image_path = os.path.join('Dataset', image_name)
-         gt_imagepaths.append(image_path)
-
-         # get the image size (stored in the 'x'/'y' attributes)
-         w = image[1].get('x')
-         h = image[1].get('y')
-         gt_imagesizes.append([h, w])
-
-         # bboxes in the image
-         bbs = []
-         for bbox in image[2]:
-             x = np.int64(float(bbox.get('x')))
-             y = np.int64(float(bbox.get('y')))
-             width = np.int64(float(bbox.get('width')))
-             height = np.int64(float(bbox.get('height')))
-             bbs.append([x, y, width, height])
-
-         gt_locations.append(bbs)
-
-     return gt_imagepaths, gt_imagesizes, gt_locations
-
- gt_imagepaths, gt_imagesizes, gt_locations = get_gt_bboxes(location_path)
-
- def visualize_gt_bboxes(image_path, gt_locations):
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     for gt_location in gt_locations:
-         x, y, width, height = map(int, gt_location)
-
-         image = cv2.rectangle(image, (x, y), (x + width, y + height), color=(255, 0, 0), thickness=2)
-
-     plt.imshow(image)
-     plt.axis('off')
-     plt.show()
-
- def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     image_height, image_width = image.shape[:2]
-
-     # Convert normalized (xc, yc, w, h) back to pixel corners
-     for data in gt_location_yolo:
-         xc, yc, w, h = data[1:]
-         xmin = int((xc - w / 2) * image_width)
-         ymin = int((yc - h / 2) * image_height)
-         xmax = int((xc + w / 2) * image_width)
-         ymax = int((yc + h / 2) * image_height)
-
-         image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)
-
-     plt.imshow(image)
-     plt.axis('off')
-     plt.show()
-
-
- def convert_yolo_format(gt_locations, gt_imagesizes):
-     gt_locations_yolo = []
-
-     for image, image_size in zip(gt_locations, gt_imagesizes):
-         gt_location_yolo = []
-         for gt_location in image:
-             x, y, w, h = gt_location
-             image_height, image_width = image_size
-
-             # normalized box center and size
-             xc = (x + w / 2) / float(image_width)
-             yc = (y + h / 2) / float(image_height)
-             width = w / float(image_width)
-             height = h / float(image_height)
-
-             # class 0 means the box contains text
-             class_id = 0
-             gt_location_yolo.append([class_id, xc, yc, width, height])
-
-         gt_locations_yolo.append(gt_location_yolo)
-
-     return gt_locations_yolo
-
- gt_locations_yolo = convert_yolo_format(gt_locations, gt_imagesizes)
-
- def save_data_into_yolo_folder(data, src_img_dir, save_dir):
-     # Note: src_img_dir is currently unused; images are copied from their original paths
-     os.makedirs(save_dir, exist_ok=True)
-
-     # Make images and labels folders
-     os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)
-     os.makedirs(os.path.join(save_dir, 'labels'), exist_ok=True)
-
-     # Write data into the YOLO folder
-     for dt in data:
-         # copy the image
-         image_path = dt[0]
-         shutil.copy(image_path, os.path.join(save_dir, 'images'))
-
-         # write the labels, one "class xc yc w h" row per box
-         image_name = os.path.basename(image_path)
-         image_name = os.path.splitext(image_name)[0]
-
-         with open(os.path.join(save_dir, 'labels', f'{image_name}.txt'), "w") as f:
-             for label in dt[1]:
-                 label_str = " ".join(map(str, label))
-                 f.write(f'{label_str}\n')
-
-
- seed = 42
- val_size = 0.2
- test_size = 0.125
- dataset = [[gt_imagepath, gt_location_yolo] for gt_imagepath, gt_location_yolo in zip(gt_imagepaths, gt_locations_yolo)]
- train_data, val_data = train_test_split(dataset, test_size=val_size, random_state=seed, shuffle=True)
- train_data, test_data = train_test_split(train_data, test_size=test_size, random_state=seed, shuffle=True)
-
- save_yolo_data_dir = 'yolo_data'
- os.makedirs(save_yolo_data_dir, exist_ok=True)
- save_data_into_yolo_folder(
-     data=train_data,
-     src_img_dir=save_yolo_data_dir,
-     save_dir=os.path.join(save_yolo_data_dir, 'train')
- )
- save_data_into_yolo_folder(
-     data=val_data,
-     src_img_dir=save_yolo_data_dir,
-     save_dir=os.path.join(save_yolo_data_dir, 'val')
- )
- save_data_into_yolo_folder(
-     data=test_data,
-     src_img_dir=save_yolo_data_dir,
-     save_dir=os.path.join(save_yolo_data_dir, 'test')
- )
-
- class_label = ['text']
- # Create data.yaml file
- data_yaml = {
-     'path': '../yolo_data',
-     'train': 'train/images',
-     'test': 'test/images',
-     'val': 'val/images',
-     'nc': 1,
-     'names': class_label
- }
-
- yolo_yaml_path = os.path.join(save_yolo_data_dir, 'data.yaml')
- with open(yolo_yaml_path, "w") as f:
-     yaml.dump(data_yaml, f, default_flow_style=False)
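
A quick sanity check of the center-format conversion in the deleted script, a minimal sketch (the image size and box below are made-up values, not from the dataset):

# Sketch: verify convert_yolo_format round-trips with the corner math
# used in visualize_gt_bboxes_yolo. 640x480 and the box are assumptions.
boxes = [[[50, 60, 200, 100]]]   # one image with one [x, y, w, h] box
sizes = [[480, 640]]             # [height, width]
yolo = convert_yolo_format(boxes, sizes)
class_id, xc, yc, w, h = yolo[0][0]
assert round((xc - w / 2) * 640) == 50   # recovers xmin
assert round((yc - h / 2) * 480) == 60   # recovers ymin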
 
src/Text_Localization/text_localization.py DELETED
@@ -1,33 +0,0 @@
- import json
-
- from ultralytics import YOLO
-
- yolo_yaml_path = 'yolo_data/data.yaml'
- config_path = 'src/config.json'
-
- def load_json_config(config_path):
-     with open(config_path, "r") as f:
-         config = json.load(f)
-
-     return config
-
- config = load_json_config(config_path)
-
- # Load pretrained model
- model = YOLO('yolo11m.pt')
-
- # Train model
- results = model.train(
-     data=yolo_yaml_path,
-     epochs=config['yolov11']['epochs'],
-     imgsz=config['yolov11']['image_size'],
-     cache=config['yolov11']['cache'],
-     patience=config['yolov11']['patience'],
-     plots=config['yolov11']['plots']
- )
-
- # Evaluate the trained checkpoint
- model_path = 'checkpoints/yolov11m.pt'
- model = YOLO(model_path)
-
- metrics = model.val()
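
Once trained, the checkpoint can be used for detection directly; a minimal sketch (the image path is a placeholder):

# Sketch: run the trained detector on one image (path is an assumption).
from ultralytics import YOLO

model = YOLO('checkpoints/yolov11m.pt')
results = model('Dataset/some_image.jpg', verbose=False)[0]
print(results.boxes.xyxy)   # corner-format boxes, one row per detection
print(results.boxes.conf)   # confidence per box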
 
src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc DELETED
Binary file (4.38 kB)
 
src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc DELETED
Binary file (11.1 kB)
 
src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc DELETED
Binary file (3.54 kB)
 
src/Text_Recognization/dataloader.py DELETED
@@ -1,96 +0,0 @@
- import json
- import cv2
-
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from sklearn.model_selection import train_test_split
- from src.Text_Recognization.prepare_dataset import *   # provides encode, build_vocab, ...
-
- # data augmentation
- data_transforms = {
-     "train": transforms.Compose(
-         [
-             transforms.ToTensor(),
-             transforms.Resize((100, 400)),
-             transforms.ColorJitter(
-                 brightness=0.5,
-                 contrast=0.5,
-                 saturation=0.5
-             ),
-             transforms.GaussianBlur(3),
-             transforms.RandomAffine(
-                 degrees=1,
-                 shear=1
-             ),
-             transforms.RandomPerspective(
-                 distortion_scale=0.3,
-                 p=0.5
-             ),
-             transforms.RandomRotation(degrees=15),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-         ]
-     ),
-     "val": transforms.Compose(
-         [
-             transforms.ToTensor(),
-             transforms.Resize((100, 400)),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-         ]
-     )
- }
-
- def load_json_config(config_path):
-     with open(config_path, "r") as f:
-         config = json.load(f)
-
-     return config
-
- # Dataset
- class STRDataset(Dataset):
-     def __init__(self, image_paths, labels, char_to_idx, transforms=None):
-         self.image_paths = image_paths
-         self.labels = labels
-         self.char_to_idx = char_to_idx
-         self.transforms = transforms
-
-     def __len__(self):
-         return len(self.image_paths)
-
-     def __getitem__(self, idx):
-         image = cv2.imread(self.image_paths[idx])
-         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-         if self.transforms:
-             image = self.transforms(image)
-
-         # passing all labels lets encode() pad to the global max label length
-         label_encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)
-
-         return image, label_encoded, length
-
- def get_dataloader():
-     val_size = 0.1
-     test_size = 0.1
-     root_path = 'Dataset'
-     config_path = 'src/config.json'
-
-     # get image paths and labels
-     image_paths, labels = get_imagepaths_and_labels(root_path)
-     char_to_idx, idx_to_char = build_vocab(root_path)
-
-     config = load_json_config(config_path)
-
-     X_train, X_val, y_train, y_val = train_test_split(image_paths, labels, test_size=val_size, random_state=42, shuffle=True)
-     X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=test_size, random_state=42, shuffle=True)
-     train_dataset = STRDataset(X_train, y_train, char_to_idx, transforms=data_transforms['train'])
-     train_loader = DataLoader(train_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)
-
-     # evaluation splits do not need shuffling
-     val_dataset = STRDataset(X_val, y_val, char_to_idx, transforms=data_transforms['val'])
-     val_loader = DataLoader(val_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)
-
-     test_dataset = STRDataset(X_test, y_test, char_to_idx, transforms=data_transforms['val'])
-     test_loader = DataLoader(test_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)
-
-     return train_loader, val_loader, test_loader
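
A minimal sketch of how these loaders are consumed (batch shapes assume the (100, 400) resize above):

# Sketch: inspect one training batch from the deleted module's loaders.
train_loader, val_loader, test_loader = get_dataloader()
images, labels, lengths = next(iter(train_loader))
print(images.shape)    # (batch_size, 3, 100, 400)
print(labels.shape)    # (batch_size, max_label_length), zero-padded
print(lengths.shape)   # (batch_size,) true label lengths, used by the CTC loss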
 
src/Text_Recognization/prepare_dataset.py DELETED
@@ -1,221 +0,0 @@
- import os
- import numpy as np
- import xml.etree.ElementTree as ET
- import cv2
- import matplotlib.pyplot as plt
- import argparse
-
- import torch
- import torch.nn.functional as F
-
-
- def extract_data_from_xml(root_path):
-     words_path = os.path.join(root_path, 'words.xml')
-     tree = ET.parse(words_path)
-     root = tree.getroot()
-
-     image_paths = []
-     image_sizes = []
-     image_labels = []
-     bboxes = []
-
-     for image in root:
-         imagename = image[0].text
-         image_path = os.path.join(root_path, imagename)
-         image_paths.append(image_path)
-
-         # image size is stored in the 'x'/'y' attributes
-         image_height = image[1].get('x')
-         image_width = image[1].get('y')
-         image_sizes.append([image_height, image_width])
-
-         bboxes_in_image = []
-         labels_in_bboxes = []
-         for bbox in image[2]:
-             x = float(bbox.get('x'))
-             y = float(bbox.get('y'))
-             width = float(bbox.get('width'))
-             height = float(bbox.get('height'))
-             bboxes_in_image.append([x, y, width, height])
-
-             # get the text in this bbox
-             labels = bbox.find('tag').text
-             labels_in_bboxes.append(labels)
-
-         bboxes.append(bboxes_in_image)
-         image_labels.append(labels_in_bboxes)
-
-     return image_paths, image_sizes, bboxes, image_labels
-
- def visualize_gt_bboxes(image_path, gt_locations, gt_labels):
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     for gt_location, gt_label in zip(gt_locations, gt_labels):
-         x, y, width, height = map(int, gt_location)
-
-         image = cv2.rectangle(image, (x, y), (x + width, y + height), color=(255, 0, 0), thickness=2)
-         image = cv2.putText(image, gt_label, (x, y - 10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=3, color=(255, 0, 0), thickness=2)
-
-     plt.imshow(image)
-     plt.axis('off')
-     plt.show()
-
-
- def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir):
-     """Create a new dataset containing cropped bboxes and their labels.
-
-     Args:
-         image_paths: paths to the source images
-         image_labels: text label of each bbox, per image
-         bboxes: [x, y, width, height] bboxes, per image
-         save_dir: directory the cropped dataset is written to
-     Returns:
-         None
-     """
-     os.makedirs(save_dir, exist_ok=True)
-     os.makedirs('unvalid_images', exist_ok=True)
-
-     bboxes_idx = 0
-     unvalid_bboxes = 0
-     new_labels = []  # list to store labels
-     for image_path, bbox, label in zip(image_paths, bboxes, image_labels):
-         # check for a failed read before using the image
-         # (kept in BGR so cv2.imwrite stores the colors correctly)
-         image = cv2.imread(image_path)
-         if image is None:
-             print(image_path)
-             continue
-
-         for bb, lb in zip(bbox, label):
-             x, y, width, height = map(int, bb)
-
-             # filter out invalid coordinates before cropping
-             if x < 0 or y < 0 or width < 0 or height < 0:
-                 continue
-
-             # filter out text containing special characters
-             if any(c in lb.lower() for c in ('é', 'ñ', '£')):
-                 continue
-
-             cropped_text = image[y:y + height, x:x + width]
-             if cropped_text.size == 0:
-                 continue
-
-             # filter out crops that are too light or too dark
-             if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230:
-                 cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
-                 unvalid_bboxes += 1
-                 continue
-
-             # filter out crops that are too small
-             if width < 10 or height < 10:
-                 cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
-                 unvalid_bboxes += 1
-                 continue
-
-             new_image_path = os.path.join(save_dir, f'cropped_image{bboxes_idx}.jpg')
-             cv2.imwrite(new_image_path, cropped_text)
-             new_labels.append(new_image_path + '\t' + lb)
-             bboxes_idx += 1
-
-     # write labels into a text file
-     with open(os.path.join(save_dir, 'labels.txt'), "w") as f:
-         for new_label in new_labels:
-             f.write(f'{new_label}\n')
-
-
- def build_vocab(root_dir):
-     img_paths = []
-     labels = []
-
-     # read labels from the text file
-     with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f:
-         for label in f:
-             labels.append(label.strip().split("\t")[1])
-             img_paths.append(label.strip().split("\t")[0])
-
-     # build the vocab from all characters that appear in the labels
-     vocab = set()
-     for label in labels:
-         for char in label:
-             vocab.add(char)
-
-     # extend the vocab with 'z' and append the "blank" character for CTC
-     vocab = "".join(sorted(vocab))
-     blank_char = '@'
-     vocab = vocab + 'z'
-     vocab = vocab + blank_char
-
-     # dictionaries converting char to idx and back (0 is reserved for padding)
-     char_to_idx = {
-         char: idx + 1 for idx, char in enumerate(vocab)
-     }
-     idx_to_char = {
-         idx: char for char, idx in char_to_idx.items()
-     }
-
-     return char_to_idx, idx_to_char
-
- def get_imagepaths_and_labels(root_path):
-     img_paths = []
-     labels = []
-
-     # read labels from the text file
-     with open(os.path.join(root_path, 'ocr_dataset', 'labels.txt'), "r") as f:
-         for label in f:
-             labels.append(label.strip().split("\t")[1])
-             img_paths.append(label.strip().split("\t")[0])
-
-     return img_paths, labels
-
- def encode(label, char_to_idx, labels):
-     # pad every label to the length of the longest label in the dataset
-     max_length_label = np.max([len(lb) for lb in labels])
-
-     encoded_label = torch.tensor(
-         [char_to_idx[char.lower()] for char in label],
-         dtype=torch.int32
-     )
-     label_len = len(encoded_label)
-     length = torch.tensor(
-         label_len,
-         dtype=torch.int32
-     )
-     padded_label = F.pad(
-         encoded_label,
-         (0, max_length_label - label_len),
-         value=0
-     )
-     return padded_label, length
-
- def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'):
-     # greedy CTC decoding: stop at padding, collapse repeats, drop blanks
-     label = []
-     encoded_label = encoded_label.detach().numpy()
-     for i in range(len(encoded_label)):
-         if encoded_label[i] == 0:
-             break
-         elif (i == 0 or encoded_label[i] != encoded_label[i - 1]) and encoded_label[i] != char_to_idx[blank_char]:
-             label.append(idx_to_char[encoded_label[i]])
-
-     return "".join(label)
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--path", type=str, default=os.getcwd(), help="Path to the root directory")
-     args = parser.parse_args()
-
-     root_path = os.path.join(args.path, 'Dataset')
-
-     image_paths, image_sizes, bboxes, image_labels = extract_data_from_xml(root_path)
-     save_dir = 'Dataset/ocr_dataset'
-     split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir)
-     char_to_idx, idx_to_char = build_vocab(root_path)
-
- if __name__ == '__main__':
-     main()
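
A toy round trip through encode/decode (the three-character vocab below is made up, not the dataset's):

# Sketch: encode/decode round trip with an assumed toy vocab.
char_to_idx = {'@': 1, 'a': 2, 'b': 3}           # 0 is reserved for padding
idx_to_char = {i: c for c, i in char_to_idx.items()}
labels = ['ab', 'ba']                            # padding length = longest label
encoded, length = encode('ab', char_to_idx, labels)
print(encoded, length)                           # tensor([2, 3]) and length 2
print(decode(encoded, idx_to_char, char_to_idx)) # 'ab'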
 
src/Text_Recognization/text_recognization.py DELETED
@@ -1,64 +0,0 @@
- import torch
- import torch.nn as nn
-
- from torchvision.models import resnet101
-
- class BackBone(nn.Module):
-     def __init__(self, num_unfreeze_layers=3):
-         super(BackBone, self).__init__()
-         model = resnet101(weights='IMAGENET1K_V2', progress=True)
-         feature_maps = list(model.children())[:8]
-
-         # AdaptiveAvgPool collapses the height axis: (B, 2048, H', W') -> (B, 2048, 1, W')
-         feature_maps.append(nn.AdaptiveAvgPool2d((1, None)))
-         self.backbone = nn.Sequential(*feature_maps)
-
-         # freeze the whole backbone, then unfreeze the last few parameter tensors
-         # (the original only set requires_grad = True, which is a no-op)
-         for param in self.backbone.parameters():
-             param.requires_grad = False
-         for param in list(self.backbone.parameters())[-(num_unfreeze_layers + 1):]:
-             param.requires_grad = True
-
-     def forward(self, image):
-         return self.backbone(image)
-
- class CRNN(nn.Module):
-     def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.2, num_unfreeze_layers=3):
-         super(CRNN, self).__init__()
-         self.backbone = BackBone(num_unfreeze_layers=num_unfreeze_layers)
-
-         self.mapSeq = nn.Sequential(
-             nn.Linear(2048, 512),
-             nn.ReLU(),
-             nn.Dropout(p=dropout)
-         )
-
-         self.gru = nn.GRU(
-             input_size=512,
-             hidden_size=hidden_size,
-             num_layers=n_layers,
-             bidirectional=True,
-             batch_first=True,
-             dropout=dropout if n_layers > 1 else 0
-         )
-
-         self.layer_norm = nn.LayerNorm(hidden_size * 2)
-
-         # dense output layers
-         self.out = nn.Sequential(
-             nn.Linear(hidden_size * 2, vocab_size),
-             nn.LogSoftmax(dim=2)
-         )
-
-     def forward(self, x):
-         x = self.backbone(x)
-         # (B, 2048, 1, W') -> (B, W', 2048, 1)
-         x = x.permute(0, 3, 1, 2)
-         # flatten -> (B, W', 2048)
-         x = x.view(x.size(0), x.size(1), -1)
-         x = self.mapSeq(x)
-         x, _ = self.gru(x)
-         x = self.layer_norm(x)
-         x = self.out(x)
-         # (B, W', vocab_size) -> (W', B, vocab_size), the (T, N, C) layout CTC expects
-         x = x.permute(1, 0, 2)
-
-         return x
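
A quick shape check: with the (100, 400) inputs used in dataloader.py, ResNet-101 leaves a 4x13 feature map, so the output sequence length works out to 13. A minimal sketch (downloads the ImageNet weights on first run; vocab_size=74 is the extended-vocab setting):

# Sketch: dummy forward pass to confirm the (T, N, C) output layout.
import torch

model = CRNN(vocab_size=74, hidden_size=256, n_layers=3)
x = torch.randn(2, 3, 100, 400)   # dummy batch of two images
print(model(x).shape)             # torch.Size([13, 2, 74]) -> (T, N, C)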
 
src/Text_Recognization/trainer.py DELETED
@@ -1,162 +0,0 @@
- import json
- import sys
- import os
- import argparse
- import torch
- import torch.nn as nn
- import matplotlib.pyplot as plt
- from tqdm import tqdm
-
- sys.path.append(os.getcwd())
-
- from src.Text_Recognization.text_recognization import *
- from src.Text_Recognization.prepare_dataset import *
- from src.Text_Recognization.dataloader import *
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- def load_json_config(config_path):
-     with open(config_path, "r") as f:
-         config = json.load(f)
-
-     return config
-
- def evaluate(model, dataloader, criterion, device):
-     model.eval()
-     losses = []
-
-     with torch.no_grad():
-         for images, labels, labels_len in dataloader:
-             images = images.to(device)
-             labels = labels.to(device)
-
-             outputs = model(images)
-             # every sample uses the full sequence length T = outputs.size(0)
-             logits_lens = torch.full(
-                 size=(outputs.size(1),),
-                 fill_value=outputs.size(0),
-                 dtype=torch.long
-             ).to(device)
-
-             loss = criterion(outputs, labels, logits_lens, labels_len)
-             losses.append(loss.item())
-
-     eval_loss = sum(losses) / len(losses)
-     return eval_loss
-
-
- def training_loop(model, train_loader, val_loader, epochs, optimizer, criterion, scheduler, device):
-     model.to(device)
-
-     train_losses = []
-     val_losses = []
-
-     for epoch in range(epochs):
-         model.train()
-
-         batch_losses = []
-         for images, labels, labels_len in tqdm(train_loader):
-             images = images.to(device)
-             labels = labels.to(device)
-
-             optimizer.zero_grad()
-             outputs = model(images)
-
-             logits_lens = torch.full(
-                 size=(outputs.size(1),),
-                 fill_value=outputs.size(0),
-                 dtype=torch.long
-             ).to(device)
-
-             loss = criterion(outputs, labels, logits_lens, labels_len)
-
-             loss.backward()
-             torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-             optimizer.step()
-
-             batch_losses.append(loss.item())
-
-         train_loss = sum(batch_losses) / len(batch_losses)
-         train_losses.append(train_loss)
-
-         val_loss = evaluate(model, val_loader, criterion, device)
-         val_losses.append(val_loss)
-
-         print(f"epoch: {epoch+1}/{epochs}\ttrain_loss: {train_loss}\tval_loss: {val_loss}")
-
-         scheduler.step()
-
-     return train_losses, val_losses
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--root_path', type=str, default=os.getcwd(), help='Path to the root directory')
-     parser.add_argument('--checkpoints_path', type=str, default=os.path.join(os.getcwd(), 'checkpoints'), help='Path to the checkpoint directory')
-
-     args = parser.parse_args()
-     config_path = 'src/config.json'
-     dataset_path = os.path.join(args.root_path, 'Dataset')
-     config = load_json_config(config_path)
-
-     # char/idx dictionaries
-     char_to_idx, idx_to_char = build_vocab(dataset_path)
-
-     # model
-     model = CRNN(vocab_size=config['vocab_size'], hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
-
-     # dataloaders
-     train_loader, val_loader, test_loader = get_dataloader()
-
-     # define hyperparameters
-     criterion = nn.CTCLoss(
-         blank=char_to_idx[config['blank_char']],
-         zero_infinity=True,
-         reduction='mean'
-     )
-     optimizer = torch.optim.Adam(
-         model.parameters(),
-         lr=config['CRNN']['learning_rate'],
-         weight_decay=config['CRNN']['weight_decay']
-     )
-     scheduler = torch.optim.lr_scheduler.StepLR(
-         optimizer=optimizer,
-         step_size=config['CRNN']['scheduler_step_size'],
-         gamma=0.1
-     )
-
-     # training loop
-     train_losses, val_losses = training_loop(
-         model=model,
-         train_loader=train_loader,
-         val_loader=val_loader,
-         epochs=config['CRNN']['epochs'],
-         optimizer=optimizer,
-         criterion=criterion,
-         scheduler=scheduler,
-         device=device
-     )
-
-     # save model
-     os.makedirs(args.checkpoints_path, exist_ok=True)
-     os.makedirs(os.path.join(args.checkpoints_path, 'losses'), exist_ok=True)
-     torch.save(model.state_dict(), os.path.join(args.checkpoints_path, 'crnn.pt'))
-
-     # plot losses (axes kept visible so the labels show)
-     fig, axis = plt.subplots(1, 2, figsize=(8, 8))
-     axis[0].plot(train_losses, label='train_loss')
-     axis[0].set_xlabel('Epochs')
-     axis[0].set_ylabel('Loss')
-     axis[0].legend()
-
-     axis[1].plot(val_losses, label='val_loss')
-     axis[1].set_xlabel('Epochs')
-     axis[1].set_ylabel('Loss')
-     axis[1].legend()
-
-     plt.savefig(os.path.join(args.checkpoints_path, 'losses', 'losses.png'))
-
- if __name__ == '__main__':
-     main()
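
For reference, a toy sketch of the tensor shapes nn.CTCLoss consumes in the loop above (all values made up):

# Sketch: the CTC loss call with assumed toy shapes.
import torch
import torch.nn as nn

T, N, C = 13, 2, 74                        # sequence length, batch size, vocab size
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(2, C, (N, 5))      # padded labels; 0 (padding) and 1 (blank) excluded
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([5, 3])      # true label lengths
loss = nn.CTCLoss(blank=1, zero_infinity=True)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())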
 
src/__pycache__/pipeline_end2end.cpython-312.pyc DELETED
Binary file (6.49 kB)
 
src/__pycache__/predict.cpython-312.pyc DELETED
Binary file (7.9 kB)
 
src/app.py DELETED
@@ -1,46 +0,0 @@
- import gradio as gr
- import numpy as np
- import os
- import cv2
- import sys
-
- sys.path.append(os.getcwd())
- from src.predict import *
-
- def visualize_image(image, detections):
-     for bbox, detected_class, conf, text, _ in detections:
-         x1, y1, x2, y2 = map(int, bbox)
-
-         image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
-         image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
-
-     return image
-
- def pipeline(image):
-     image = np.array(image)
-
-     predictions = prediction(image)
-
-     # filter out low-confidence boxes
-     filter_predictions = []
-     for bbox, cls, conf, text, encoded_text in predictions:
-         if conf > 0.7:
-             filter_predictions.append([bbox, cls, conf, text, encoded_text])
-
-     image = visualize_image(image, filter_predictions)
-     return image
-
- demo = gr.Interface(
-     fn=pipeline,
-     inputs=gr.Image(type="pil", label="Input Image"),
-     outputs="image",
-     title="Scene Text Recognition",
-     description="Recognize text in scene images"
- )
-
- demo.launch()
 
src/config.json DELETED
@@ -1,24 +0,0 @@
- {
-     "yolov11": {
-         "epochs": 100,
-         "image_size": 640,
-         "cache": true,
-         "patience": 20,
-         "plots": true
-     },
-     "CRNN": {
-         "batch_size": 64,
-         "epochs": 100,
-         "hidden_size": 256,
-         "n_layers": 3,
-         "dropout": 0.2,
-         "unfreeze_layers": 3,
-         "learning_rate": 5e-4,
-         "weight_decay": 1e-5,
-         "scheduler_step_size": 30
-     },
-     "blank_char": "@",
-     "vocab_size": 73
- }
 
src/predict.py DELETED
@@ -1,137 +0,0 @@
- import os
- import sys
- import json
- import cv2
- import argparse
- import matplotlib.pyplot as plt
-
- import torch
- from torchvision import transforms
-
- sys.path.append(os.getcwd())
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- from ultralytics import YOLO
- from src.Text_Recognization.text_recognization import *
- from src.Text_Recognization.prepare_dataset import *
-
- # config
- def load_json_config(config_path):
-     with open(config_path, "r") as f:
-         config = json.load(f)
-
-     return config
-
- config = load_json_config('src/config.json')
-
- # char/idx dictionaries
- char_to_idx, idx_to_char = build_vocab('Dataset')
-
- # text detection model
- text_det_model_path = 'checkpoints/yolov11m.pt'
- yolo = YOLO(text_det_model_path)
-
- # text recognition model (extended-vocab checkpoint, hence vocab_size=74)
- text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'
-
- crnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
- crnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
-
- def text_detection(image, text_det_model):
-     # ultralytics accepts both file paths and numpy arrays
-     text_det_results = text_det_model(image, verbose=False)[0]
-
-     bboxes = text_det_results.boxes.xyxy.tolist()
-     classes = text_det_results.boxes.cls.tolist()
-     names = text_det_results.names
-     confs = text_det_results.boxes.conf.tolist()
-
-     return bboxes, classes, names, confs
-
- def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     # boxes are already corner-format (xmin, ymin, xmax, ymax) pixels
-     for data in gt_location_yolo:
-         xmin, ymin, xmax, ymax = map(int, data)
-
-         image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)
-
-     plt.imshow(image)
-     plt.axis('off')
-     plt.show()
-
- def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device):
-     transformed_image = data_transforms(image)
-     transformed_image = transformed_image.unsqueeze(0).to(device)
-     text_reg_model.to(device)
-     text_reg_model.eval()
-
-     with torch.no_grad():
-         preds = text_reg_model(transformed_image)
-         _, idx = torch.max(preds, dim=2)
-         idx = idx.view(-1).cpu()   # decode() converts to numpy, so move off the GPU
-         text = decode(idx, idx_to_char, char_to_idx)
-
-     return text, idx
-
- def visualize_detection(image, detections):
-     plt.figure(figsize=(10, 8))
-
-     for bbox, detected_classes, conf, text, _ in detections:
-         x1, y1, x2, y2 = map(int, bbox)
-
-         image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
-         image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
-
-     plt.imshow(image)
-     plt.axis('off')
-     plt.show()
-     return image
-
- data_transforms = transforms.Compose([
-     transforms.ToTensor(),
-     transforms.Resize((100, 400)),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
-
- def prediction(image, text_det_model=yolo, text_reg_model=crnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
-     # detection
-     bboxes, classes, names, confs = text_detection(image, text_det_model)
-
-     # recognition on each detected crop
-     predictions = []
-     for bbox, cls, conf in zip(bboxes, classes, confs):
-         x1, y1, x2, y2 = map(int, bbox)
-         detected_text = image[y1:y2, x1:x2]
-         text, encoded_text = text_recognization(detected_text, data_transforms, text_reg_model, idx_to_char, device)
-         predictions.append((bbox, cls, conf, text, encoded_text))
-         print(bbox, cls, conf, text)
-
-     return predictions
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--image_path', type=str, help='Path to the image')
-     parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
-     args = parser.parse_args()
-     image_path = args.image_path
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     detections = prediction(image)
-     image = visualize_detection(image, detections)
-
-     if args.save_path:
-         print(f"Saving the image to {os.path.join(args.save_path, 'predicted_image.jpg')}")
-         cv2.imwrite(os.path.join(args.save_path, 'predicted_image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
-
- if __name__ == '__main__':
-     main()
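
A minimal programmatic use of the pipeline above (the image path is a placeholder):

# Sketch: end-to-end detection + recognition on one image (path is an assumption).
import cv2

image = cv2.cvtColor(cv2.imread('Dataset/some_image.jpg'), cv2.COLOR_BGR2RGB)
for bbox, cls, conf, text, _ in prediction(image):
    print(bbox, f'{conf:.2f}', text)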