Delete src
- src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc +0 -0
- src/Text_Localization/prepare_dataset.py +0 -177
- src/Text_Localization/text_localization.py +0 -33
- src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc +0 -0
- src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc +0 -0
- src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc +0 -0
- src/Text_Recognization/dataloader.py +0 -96
- src/Text_Recognization/prepare_dataset.py +0 -221
- src/Text_Recognization/text_recognization.py +0 -64
- src/Text_Recognization/trainer.py +0 -162
- src/__pycache__/pipeline_end2end.cpython-312.pyc +0 -0
- src/__pycache__/predict.cpython-312.pyc +0 -0
- src/app.py +0 -46
- src/config.json +0 -24
- src/predict.py +0 -137
src/Text_Localization/__pycache__/prepare_dataset.cpython-312.pyc
DELETED
Binary file (7.86 kB)
src/Text_Localization/prepare_dataset.py
DELETED
@@ -1,177 +0,0 @@
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
import xml.etree.ElementTree as ET
import shutil
import yaml

from sklearn.model_selection import train_test_split

location_path = r'Dataset/locations.xml'
tree = ET.parse(location_path)
root = tree.getroot()


def get_gt_bboxes(location_path):
    """Get all the ground-truth text bboxes in the dataset.

    Args:
        location_path (path): path to the locations XML file
    Returns:
        gt_imagepaths (list): path of each image
        gt_imagesizes (list): [height, width] of each image
        gt_locations (list): bboxes of each image
    """
    gt_imagepaths = []
    gt_imagesizes = []
    gt_locations = []

    for image in root:
        # get the path to the image
        image_name = image[0].text
        image_path = os.path.join('Dataset', image_name)
        gt_imagepaths.append(image_path)

        # get the image size ('x' holds the width, 'y' the height)
        w = image[1].get('x')
        h = image[1].get('y')
        gt_imagesizes.append([h, w])

        # bboxes in the image
        bbs = []
        for bbox in image[2]:
            x = np.int64(float(bbox.get('x')))
            y = np.int64(float(bbox.get('y')))
            width = np.int64(float(bbox.get('width')))
            height = np.int64(float(bbox.get('height')))
            bbs.append([x, y, width, height])

        gt_locations.append(bbs)

    return gt_imagepaths, gt_imagesizes, gt_locations

gt_imagepaths, gt_imagesizes, gt_locations = get_gt_bboxes(location_path)

def visualize_gt_bboxes(image_path, gt_locations):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    for gt_location in gt_locations:
        x, y, width, height = gt_location
        x, y, width, height = int(x), int(y), int(width), int(height)

        image = cv2.rectangle(image, (x, y), (x+width, y+height), color=(255, 0, 0), thickness=2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()

def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image_height, image_width = image.shape[:2]

    # Convert back from normalized YOLO format to pixel corners
    for data in gt_location_yolo:
        xc, yc, w, h = data[1:]
        xmin = int((xc - w/2) * image_width)
        ymin = int((yc - h/2) * image_height)
        xmax = int((xc + w/2) * image_width)
        ymax = int((yc + h/2) * image_height)

        image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()


def convert_yolo_format(gt_locations, gt_imagesizes):
    gt_locations_yolo = []

    for image, image_size in zip(gt_locations, gt_imagesizes):
        gt_location_yolo = []
        for gt_location in image:
            x, y, w, h = gt_location
            image_height, image_width = image_size

            xc = (x + w/2) / float(image_width)
            yc = (y + h/2) / float(image_height)
            width = w / float(image_width)
            height = h / float(image_height)

            # class_id = 0 means the box contains text
            class_id = 0
            gt_location_yolo.append([class_id, xc, yc, width, height])

        gt_locations_yolo.append(gt_location_yolo)

    return gt_locations_yolo

gt_locations_yolo = convert_yolo_format(gt_locations, gt_imagesizes)

def save_data_into_yolo_folder(data, src_img_dir, save_dir):
    # Create the output folder if it does not exist
    os.makedirs(save_dir, exist_ok=True)

    # Make the images and labels folders
    os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'labels'), exist_ok=True)

    # write the data into the yolo folder
    for dt in data:
        # copy the image
        image_path = dt[0]
        shutil.copy(image_path, os.path.join(save_dir, 'images'))

        # write the labels
        image_name = os.path.basename(image_path)
        image_name = os.path.splitext(image_name)[0]

        with open(os.path.join(save_dir, 'labels', f'{image_name}.txt'), "w") as f:
            for label in dt[1]:
                label_str = " ".join(map(str, label))
                f.write(f'{label_str}\n')


seed = 42
val_size = 0.2
test_size = 0.125
dataset = [[gt_imagepath, gt_location_yolo] for gt_imagepath, gt_location_yolo in zip(gt_imagepaths, gt_locations_yolo)]
train_data, val_data = train_test_split(dataset, test_size=val_size, random_state=seed, shuffle=True)
train_data, test_data = train_test_split(train_data, test_size=test_size, random_state=seed, shuffle=True)

save_yolo_data_dir = 'yolo_data'
os.makedirs(save_yolo_data_dir, exist_ok=True)
save_data_into_yolo_folder(
    data=train_data,
    src_img_dir=save_yolo_data_dir,
    save_dir=os.path.join(save_yolo_data_dir, 'train')
)
save_data_into_yolo_folder(
    data=val_data,
    src_img_dir=save_yolo_data_dir,
    save_dir=os.path.join(save_yolo_data_dir, 'val')
)
save_data_into_yolo_folder(
    data=test_data,
    src_img_dir=save_yolo_data_dir,
    save_dir=os.path.join(save_yolo_data_dir, 'test')
)

class_label = ['text']
# Create the data.yaml file
data_yaml = {
    'path': '../yolo_data',
    'train': 'train/images',
    'test': 'test/images',
    'val': 'val/images',
    'nc': 1,
    'names': class_label
}

yolo_yaml_path = os.path.join(save_yolo_data_dir, 'data.yaml')
with open(yolo_yaml_path, "w") as f:
    yaml.dump(data_yaml, f, default_flow_style=False)
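
As a quick sanity check of the convert_yolo_format arithmetic above: a box with top-left corner (50, 20) and size 100x40 inside a 640x480 image normalizes as follows (the numbers are illustrative, not taken from the dataset):

# (x, y, w, h) = (50, 20, 100, 40) in a 640x480 image
xc = (50 + 100/2) / 640.0     # 0.15625
yc = (20 + 40/2) / 480.0      # 0.08333...
width = 100 / 640.0           # 0.15625
height = 40 / 480.0           # 0.08333...
print([0, xc, yc, width, height])  # the YOLO label row written for this box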
src/Text_Localization/text_localization.py
DELETED
@@ -1,33 +0,0 @@
import ultralytics
import json
from ultralytics import YOLO

yolo_yaml_path = 'yolo_data/data.yaml'
config_path = 'src/config.json'

def load_json_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)

    return config

config = load_json_config(config_path)

# Load model
model = YOLO('yolo11m.pt')

# Train model
results = model.train(
    data=yolo_yaml_path,
    epochs=config['yolov11']['epochs'],
    imgsz=config['yolov11']['image_size'],
    cache=config['yolov11']['cache'],
    patience=config['yolov11']['patience'],
    plots=config['yolov11']['plots']
)

# Evaluate model
model_path = 'checkpoints/yolov11m.pt'
model = YOLO(model_path)

metrics = model.val()
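
For completeness, a minimal sketch of single-image inference with the trained detector; the checkpoint path matches the evaluation block above, but the test image path is an assumption:

from ultralytics import YOLO

model = YOLO('checkpoints/yolov11m.pt')
results = model('Dataset/example.jpg', verbose=False)[0]  # 'Dataset/example.jpg' is a placeholder

# xyxy pixel boxes and confidences for each detected text region
for box, conf in zip(results.boxes.xyxy.tolist(), results.boxes.conf.tolist()):
    print(box, conf)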
src/Text_Recognization/__pycache__/dataloader.cpython-312.pyc
DELETED
Binary file (4.38 kB)

src/Text_Recognization/__pycache__/prepare_dataset.cpython-312.pyc
DELETED
Binary file (11.1 kB)

src/Text_Recognization/__pycache__/text_recognization.cpython-312.pyc
DELETED
Binary file (3.54 kB)
src/Text_Recognization/dataloader.py
DELETED
@@ -1,96 +0,0 @@
import torch
import torchvision
import json

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from src.Text_Recognization.prepare_dataset import *

# data augmentation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((100, 400)),
            transforms.ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5
            ),
            transforms.GaussianBlur(3),
            transforms.RandomAffine(
                degrees=1,
                shear=1
            ),
            transforms.RandomPerspective(
                distortion_scale=0.3,
                p=0.5
            ),
            transforms.RandomRotation(degrees=15),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((100, 400)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
    )
}

def load_json_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)

    return config

# Dataloader
class STRDataset(Dataset):
    def __init__(self, image_paths, labels, char_to_idx, transforms=None):
        self.image_paths = image_paths
        self.labels = labels
        self.char_to_idx = char_to_idx
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            image = self.transforms(image)

        label_encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)

        return image, label_encoded, length

def get_dataloader():
    val_size = 0.1
    test_size = 0.1
    root_path = 'Dataset'
    config_path = 'src/config.json'

    # get image paths and labels
    image_paths, labels = get_imagepaths_and_labels(root_path)
    char_to_idx, idx_to_char = build_vocab(root_path)

    config = load_json_config(config_path)

    X_train, X_val, y_train, y_val = train_test_split(image_paths, labels, test_size=val_size, random_state=42, shuffle=True)
    X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=test_size, random_state=42, shuffle=True)
    train_dataset = STRDataset(X_train, y_train, char_to_idx, transforms=data_transforms['train'])
    train_loader = DataLoader(train_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)

    # no shuffling needed for the evaluation loaders
    val_dataset = STRDataset(X_val, y_val, char_to_idx, transforms=data_transforms['val'])
    val_loader = DataLoader(val_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)

    test_dataset = STRDataset(X_test, y_test, char_to_idx, transforms=data_transforms['val'])
    test_loader = DataLoader(test_dataset, batch_size=config['CRNN']['batch_size'], shuffle=False)

    return train_loader, val_loader, test_loader
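
A minimal sketch of consuming the loaders, assuming the dataset and src/config.json are in place:

from src.Text_Recognization.dataloader import get_dataloader

train_loader, val_loader, test_loader = get_dataloader()

# images: (batch, 3, 100, 400); labels are padded to the longest label in
# the dataset; lengths holds the true label lengths needed by the CTC loss
images, labels, lengths = next(iter(train_loader))
print(images.shape, labels.shape, lengths.shape)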
src/Text_Recognization/prepare_dataset.py
DELETED
@@ -1,221 +0,0 @@
import os
import random
import numpy as np
import xml.etree.ElementTree as ET
import cv2
import json
import matplotlib.pyplot as plt
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F


def extract_data_from_xml(root_path):
    words_path = os.path.join(root_path, 'words.xml')
    tree = ET.parse(words_path)
    root = tree.getroot()

    image_paths = []
    image_sizes = []
    image_labels = []
    bboxes = []

    for image in root:
        imagename = image[0].text
        image_path = os.path.join(root_path, imagename)
        image_paths.append(image_path)

        # the 'x'/'y' attributes hold width and height,
        # matching prepare_dataset.py in Text_Localization
        image_width = image[1].get('x')
        image_height = image[1].get('y')
        image_sizes.append([image_height, image_width])

        bboxes_in_image = []
        labels_in_bboxes = []
        for bbox in image[2]:
            x = float(bbox.get('x'))
            y = float(bbox.get('y'))
            width = float(bbox.get('width'))
            height = float(bbox.get('height'))
            bboxes_in_image.append([x, y, width, height])

            # get the text in this bbox
            labels = bbox.find('tag').text
            labels_in_bboxes.append(labels)

        bboxes.append(bboxes_in_image)
        image_labels.append(labels_in_bboxes)

    return image_paths, image_sizes, bboxes, image_labels

def visualize_gt_bboxes(image_path, gt_locations, gt_labels):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    for gt_location, gt_label in zip(gt_locations, gt_labels):
        x, y, width, height = gt_location
        x, y, width, height = int(x), int(y), int(width), int(height)

        image = cv2.rectangle(image, (x, y), (x+width, y+height), color=(255, 0, 0), thickness=2)
        image = cv2.putText(image, gt_label, (x, y-10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=3, color=(255, 0, 0), thickness=2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()


def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir):
    """Create a new dataset containing the cropped bboxes and their labels.

    Args:
        image_paths: paths of the source images
        image_labels: text label of each bbox
        bboxes: bbox coordinates per image
        save_dir: output directory
    Returns:
        None
    """
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs('unvalid_images', exist_ok=True)

    bboxes_idx = 0
    unvalid_bboxes = 0
    new_labels = []  # list to store the labels
    for image_path, bbox, label in zip(image_paths, bboxes, image_labels):
        image = cv2.imread(image_path)
        if image is None:  # check before converting; cv2.cvtColor fails on None
            print(image_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for bb, lb in zip(bbox, label):
            x, y, width, height = bb
            x, y, width, height = int(x), int(y), int(width), int(height)

            # Filter out invalid coordinates before cropping
            if x < 0 or y < 0 or width < 0 or height < 0:
                continue

            cropped_text = image[y:y+height, x:x+width]

            # Filter out text containing special characters
            if any(c in lb.lower() for c in ('é', 'ñ', '£')):
                continue

            # Filter out crops that are too light or too dark
            if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230:
                cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
                unvalid_bboxes += 1
                continue

            # Filter out crops that are too small
            if width < 10 or height < 10:
                cv2.imwrite(os.path.join('unvalid_images', f'unvalid_image{unvalid_bboxes}_{lb}.jpg'), cropped_text)
                unvalid_bboxes += 1
                continue

            new_image_path = os.path.join(save_dir, f'cropped_image{bboxes_idx}.jpg')
            cv2.imwrite(new_image_path, cropped_text)
            new_label = new_image_path + '\t' + lb
            new_labels.append(new_label)
            bboxes_idx += 1

    # Write the labels into a text file
    with open(os.path.join(save_dir, 'labels.txt'), "w") as f:
        for new_label in new_labels:
            f.write(f'{new_label}\n')


def build_vocab(root_dir):
    img_paths = []
    labels = []

    # Read the labels from the text file
    with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f:
        for label in f:
            labels.append(label.strip().split("\t")[1])
            img_paths.append(label.strip().split("\t")[0])

    # build the vocab
    vocab = set()
    for label in labels:
        for i in range(len(label)):
            vocab.add(label[i])

    # append 'z' (to extend the vocab) and the "blank" character
    vocab = list(sorted(vocab))
    vocab = "".join(vocab)
    blank_char = '@'
    vocab = vocab + 'z'
    vocab = vocab + blank_char

    # build dictionaries mapping char to idx and idx to char (idx 0 is reserved for padding)
    char_to_idx = {
        char: idx + 1 for idx, char in enumerate(vocab)
    }
    idx_to_char = {
        idx: char for char, idx in char_to_idx.items()
    }

    return char_to_idx, idx_to_char

def get_imagepaths_and_labels(root_path):
    img_paths = []
    labels = []

    # Read the labels from the text file
    with open(os.path.join(root_path, 'ocr_dataset', 'labels.txt'), "r") as f:
        for label in f:
            labels.append(label.strip().split("\t")[1])
            img_paths.append(label.strip().split("\t")[0])

    return img_paths, labels

def encode(label, char_to_idx, labels):
    max_length_label = np.max([len(lb) for lb in labels])

    # encode the label
    encoded_label = torch.tensor(
        [char_to_idx[char.lower()] for char in label],
        dtype=torch.int32
    )
    label_len = len(encoded_label)
    length = torch.tensor(
        label_len,
        dtype=torch.int32
    )
    # pad with 0 up to the longest label in the dataset
    padded_label = F.pad(
        encoded_label,
        (0, max_length_label - label_len),
        value=0
    )
    return padded_label, length

def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'):
    label = []
    encoded_label = encoded_label.detach().numpy()
    for i in range(len(encoded_label)):
        if encoded_label[i] == 0:
            break
        # CTC greedy decoding: collapse repeats and drop blanks
        elif (i == 0 or encoded_label[i] != encoded_label[i-1]) and encoded_label[i] != char_to_idx[blank_char]:
            label.append(idx_to_char[encoded_label[i]])

    label = "".join(label)
    return label

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default=os.getcwd(), help="Path to the root directory")
    args = parser.parse_args()

    root_path = os.path.join(args.path, 'Dataset')

    image_paths, image_sizes, bboxes, image_labels = extract_data_from_xml(root_path)
    save_dir = 'Dataset/ocr_dataset'
    split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir)
    char_to_idx, idx_to_char = build_vocab(root_path)

if __name__ == '__main__':
    main()
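
To make the CTC encoding scheme above concrete, here is a round-trip sketch with a toy vocabulary (the indices and strings are made up for illustration; in the pipeline char_to_idx comes from build_vocab):

# Toy vocab: real indices start at 1, '@' is the CTC blank, 0 is padding
char_to_idx = {'a': 1, 'b': 2, '@': 3}
idx_to_char = {v: k for k, v in char_to_idx.items()}

# A raw model output like [1, 1, 3, 2, 3, 2, 0] decodes to "abb":
# repeats collapse, blanks are dropped, and the first 0 stops decoding
raw = [1, 1, 3, 2, 3, 2, 0]
out, prev = [], None
for v in raw:
    if v == 0:
        break
    if v != prev and v != char_to_idx['@']:
        out.append(idx_to_char[v])
    prev = v
print("".join(out))  # -> "abb"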
src/Text_Recognization/text_recognization.py
DELETED
@@ -1,64 +0,0 @@
import torch
import torch.nn as nn
import torchvision

from torchvision.models import resnet101

class BackBone(nn.Module):
    def __init__(self, num_unfreeze_layers=3):
        super(BackBone, self).__init__()
        model = resnet101(weights='IMAGENET1K_V2', progress=True)
        feature_maps = list(model.children())[:8]

        # Adding an AdaptiveAvgPool2d: (batch_size, 2048, 8, 8) -> (batch_size, 2048, 1, 8)
        feature_maps.append(nn.AdaptiveAvgPool2d((1, None)))
        self.backbone = nn.Sequential(*feature_maps)

        # Freeze the backbone, then unfreeze only the last few parameter tensors
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in list(self.backbone.parameters())[-(num_unfreeze_layers+1):]:
            param.requires_grad = True

    def forward(self, image):
        return self.backbone(image)

class CRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.2, num_unfreeze_layers=3):
        super(CRNN, self).__init__()
        self.backbone = BackBone(num_unfreeze_layers=num_unfreeze_layers)

        self.mapSeq = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout)
        )

        self.gru = nn.GRU(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=n_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0
        )

        self.layer_norm = nn.LayerNorm(hidden_size * 2)

        # Dense layers
        self.out = nn.Sequential(
            nn.Linear(hidden_size * 2, vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self, x):
        x = self.backbone(x)
        # (batch_size, 2048, 1, 8) -> (batch_size, 8, 2048, 1)
        x = x.permute(0, 3, 1, 2)
        # flatten -> (batch_size, 8, 2048)
        x = x.view(x.size(0), x.size(1), -1)
        x = self.mapSeq(x)
        x, _ = self.gru(x)
        x = self.layer_norm(x)
        x = self.out(x)
        # (batch_size, 8, vocab_size) -> (8, batch_size, vocab_size)
        x = x.permute(1, 0, 2)

        return x
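
A shape-trace sketch for the model above, assuming the (100, 400) crop size used by the dataloader and the vocab_size of 73 from src/config.json:

import torch

model = CRNN(vocab_size=73, hidden_size=256, n_layers=3)
model.eval()

x = torch.randn(2, 3, 100, 400)  # a batch of 2 normalized RGB crops
with torch.no_grad():
    out = model(x)
print(out.shape)  # (T, batch, vocab_size): the layout nn.CTCLoss expects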
src/Text_Recognization/trainer.py
DELETED
@@ -1,162 +0,0 @@
import json
import sys
import os
import argparse
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

sys.path.append(os.getcwd())

from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
from src.Text_Recognization.dataloader import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_json_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)

    return config

def evaluate(model, dataloader, criterion, device):
    model.eval()
    losses = []

    with torch.no_grad():
        for images, labels, labels_len in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # every timestep is a valid emission, so each input length is T = outputs.size(0)
            logits_lens = torch.full(
                size=(outputs.size(1), ),
                fill_value=outputs.size(0),
                dtype=torch.long
            ).to(device)

            loss = criterion(outputs, labels, logits_lens, labels_len)
            losses.append(loss.item())

    eval_loss = sum(losses) / len(losses)
    return eval_loss


def training_loop(model, train_loader, val_loader, learning_rate, epochs, optimizer, criterion, scheduler, device):
    model.to(device)

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()

        batch_losses = []
        for images, labels, labels_len in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            logits_lens = torch.full(
                size=(outputs.size(1), ),
                fill_value=outputs.size(0),
                dtype=torch.long
            ).to(device)

            loss = criterion(outputs, labels, logits_lens, labels_len)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            batch_losses.append(loss.item())

        train_loss = sum(batch_losses) / len(batch_losses)
        train_losses.append(train_loss)

        val_loss = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f"epoch: {epoch+1}/{epochs}\ttrain_loss:{train_loss}\tval_loss:{val_loss}")

        scheduler.step()

    return train_losses, val_losses

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str, default=os.getcwd(), help='Path to the root directory')
    parser.add_argument('--checkpoints_path', type=str, default=os.path.join(os.getcwd(), 'checkpoints'), help='Path to the checkpoint directory')

    args = parser.parse_args()
    config_path = 'src/config.json'
    dataset_path = os.path.join(args.root_path, 'Dataset')
    config = load_json_config(config_path)

    # char/idx dictionaries
    char_to_idx, idx_to_char = build_vocab(dataset_path)

    # model
    model = CRNN(vocab_size=config['vocab_size'], hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])

    # dataloader
    train_loader, val_loader, test_loader = get_dataloader()

    # define hyperparameters
    criterion = nn.CTCLoss(
        blank=char_to_idx[config['blank_char']],
        zero_infinity=True,
        reduction='mean'
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['CRNN']['learning_rate'],
        weight_decay=config['CRNN']['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        step_size=config['CRNN']['scheduler_step_size'],
        gamma=0.1
    )

    # training loop
    train_losses, val_losses = training_loop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=config['CRNN']['learning_rate'],
        epochs=config['CRNN']['epochs'],
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device
    )

    # save model
    os.makedirs(args.checkpoints_path, exist_ok=True)
    os.makedirs(os.path.join(args.checkpoints_path, 'losses'), exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.checkpoints_path, 'crnn.pt'))

    # draw losses
    fig, axis = plt.subplots(1, 2, figsize=(8, 8))
    axis[0].plot(train_losses, label='train_loss')
    axis[0].set_xlabel('Epochs')
    axis[0].set_ylabel('Loss')
    axis[0].legend()

    axis[1].plot(val_losses, label='val_loss')
    axis[1].set_xlabel('Epochs')
    axis[1].set_ylabel('Loss')
    axis[1].legend()

    plt.savefig(os.path.join(args.checkpoints_path, 'losses', 'losses.png'))

if __name__ == '__main__':
    main()
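
For reference, a minimal sketch of the shape contract behind the nn.CTCLoss call above; all sizes here are illustrative, and the blank index stands in for char_to_idx['@']:

import torch
import torch.nn as nn

T, batch, vocab = 13, 4, 73
criterion = nn.CTCLoss(blank=72, zero_infinity=True, reduction='mean')

log_probs = torch.randn(T, batch, vocab).log_softmax(2)         # (T, batch, vocab)
targets = torch.randint(1, 72, (batch, 10), dtype=torch.int32)  # zero-padded labels
input_lengths = torch.full((batch,), T, dtype=torch.long)       # every timestep is used
target_lengths = torch.tensor([10, 7, 9, 5], dtype=torch.int32) # true label lengths

loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())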
src/__pycache__/pipeline_end2end.cpython-312.pyc
DELETED
Binary file (6.49 kB)

src/__pycache__/predict.cpython-312.pyc
DELETED
Binary file (7.9 kB)
src/app.py
DELETED
@@ -1,46 +0,0 @@
import gradio as gr
import numpy as np
import os
import json
import cv2
import sys
import torch
import torch.nn as nn
import torchvision

sys.path.append(os.getcwd())
from predict import *

def visualize_image(image, detections):
    for bbox, detected_class, conf, text, _ in detections:
        x1, y1, x2, y2 = bbox
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    return image

def pipeline(image):
    image = np.array(image)

    predictions = prediction(image)

    # Filter out low-confidence boxes
    filter_predictions = []
    for bbox, cls, conf, text, encoded_text in predictions:
        if conf > 0.7:
            filter_predictions.append([bbox, cls, conf, text, encoded_text])

    image = visualize_image(image, filter_predictions)
    return image

demo = gr.Interface(
    fn=pipeline,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs="image",
    title="Scene Text Recognization",
    description="Recognize text in scene images"
)

demo.launch()
src/config.json
DELETED
@@ -1,24 +0,0 @@
{
    "yolov11":
    {
        "epochs": 100,
        "image_size": 640,
        "cache": true,
        "patience": 20,
        "plots": true
    },
    "CRNN":
    {
        "batch_size": 64,
        "epochs": 100,
        "hidden_size": 256,
        "n_layers": 3,
        "dropout": 0.2,
        "unfreeze_layers": 3,
        "learning_rate": 5e-4,
        "weight_decay": 1e-5,
        "scheduler_step_size": 30
    },
    "blank_char": "@",
    "vocab_size": 73
}
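
Note that vocab_size is 73 here, while predict.py below builds the CRNN with vocab_size=74 to fit the extended-vocab checkpoint. A small guard like this hypothetical helper can surface such mismatches before load_state_dict fails:

import json
import torch

def check_vocab_size(config_path, checkpoint_path):
    # hypothetical helper, not part of the deleted code
    with open(config_path) as f:
        config = json.load(f)
    state = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    ckpt_vocab = state['out.0.weight'].shape[0]  # out.0 is the CRNN's final Linear layer
    if ckpt_vocab != config['vocab_size']:
        print(f"config vocab_size={config['vocab_size']} but checkpoint expects {ckpt_vocab}")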
src/predict.py
DELETED
@@ -1,137 +0,0 @@
import os
import sys
import json
import cv2
import argparse
import matplotlib.pyplot as plt

import ultralytics
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet101
from torchvision import transforms

sys.path.append(os.getcwd())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from ultralytics import YOLO
from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *

# config
def load_json_config(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)

    return config

config = load_json_config('src/config.json')

# char/idx dictionaries
char_to_idx, idx_to_char = build_vocab('Dataset')

# text detection model
text_det_model_path = 'checkpoints/yolov11m.pt'
yolo = YOLO(text_det_model_path)

# text recognition model (CRNN with the extended vocab, hence vocab_size=74)
text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'

rcnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
rcnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))

def text_detection(img_path, text_det_model):
    text_det_results = text_det_model(img_path, verbose=False)[0]

    bboxes = text_det_results.boxes.xyxy.tolist()
    classes = text_det_results.boxes.cls.tolist()
    names = text_det_results.names
    confs = text_det_results.boxes.conf.tolist()

    return bboxes, classes, names, confs

def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Draw the xyxy boxes
    for data in gt_location_yolo:
        xmin, ymin, xmax, ymax = data
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)

        image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=(255, 0, 0), thickness=2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()

def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device):
    transformed_image = data_transforms(image)
    transformed_image = transformed_image.unsqueeze(0).to(device)
    text_reg_model.to(device)
    text_reg_model.eval()

    with torch.no_grad():
        preds = text_reg_model(transformed_image)
        _, idx = torch.max(preds, dim=2)
        idx = idx.view(-1)
        text = decode(idx, idx_to_char, char_to_idx)

    return text, idx

def visualize_detection(image, detections):
    plt.figure(figsize=(10, 8))

    for bbox, detected_classes, conf, text, _ in detections:
        x1, y1, x2, y2 = bbox
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 400)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def prediction(image, text_det_model=yolo, text_reg_model=rcnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
    # detection
    bboxes, classes, names, confs = text_detection(image, text_det_model)

    predictions = []
    for bbox, cls, conf in zip(bboxes, classes, confs):
        x1, y1, x2, y2 = bbox
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        # crop the detected region and recognize its text
        detected_text = image[y1:y2, x1:x2]
        text, encoded_text = text_recognization(detected_text, data_transforms, text_reg_model, idx_to_char, device)
        predictions.append((bbox, cls, conf, text, encoded_text))
        print(bbox, cls, conf, text)

    return predictions

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, help='Path to the image')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
    args = parser.parse_args()
    image_path = args.image_path
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    detections = prediction(image)
    image = visualize_detection(image, detections)

    if args.save_path:
        print(f"Saving the image to {os.path.join(args.save_path, 'predicted_image.jpg')}")
        cv2.imwrite(os.path.join(args.save_path, 'predicted_image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

if __name__ == '__main__':
    main()
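
A usage sketch for the script; the image path is an assumption, and importing src.predict triggers the module-level model loading, so the checkpoints must exist:

import cv2
from src.predict import prediction, visualize_detection

# Run from the repository root; equivalent to:
#   python src/predict.py --image_path Dataset/example.jpg --save_path outputs
image = cv2.cvtColor(cv2.imread('Dataset/example.jpg'), cv2.COLOR_BGR2RGB)
detections = prediction(image)
annotated = visualize_detection(image, detections)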