import os
import random
import json
import pickle
import argparse
import xml.etree.ElementTree as ET

import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract_data_from_xml(root_path):
    """Parse ``words.xml`` under *root_path* into parallel lists.

    Args:
        root_path: directory containing ``words.xml`` and the images it
            references.

    Returns:
        Tuple ``(image_paths, image_sizes, bboxes, image_labels)`` where
        entry i describes one <image> element: its file path, its
        [height, width] pair, its list of [x, y, width, height] boxes,
        and the text tag of each box.
    """
    words_path = os.path.join(root_path, 'words.xml')
    tree = ET.parse(words_path)
    root = tree.getroot()

    image_paths = []
    image_sizes = []
    image_labels = []
    bboxes = []
    for image in root:
        # Child 0 holds the file name, child 1 the size node,
        # child 2 the list of bounding boxes.
        imagename = image[0].text
        image_paths.append(os.path.join(root_path, imagename))

        # NOTE(review): attribute 'x' is stored as the height and 'y' as
        # the width — the names suggest the opposite mapping; confirm
        # against the actual XML schema.
        image_height = image[1].get('x')
        image_width = image[1].get('y')
        image_sizes.append([image_height, image_width])

        bboxes_in_image = []
        labels_in_bboxes = []
        for bbox in image[2]:
            x = float(bbox.get('x'))
            y = float(bbox.get('y'))
            width = float(bbox.get('width'))
            height = float(bbox.get('height'))
            bboxes_in_image.append([x, y, width, height])
            # The transcription of this box lives in its <tag> child.
            labels_in_bboxes.append(bbox.find('tag').text)
        bboxes.append(bboxes_in_image)
        image_labels.append(labels_in_bboxes)

    return image_paths, image_sizes, bboxes, image_labels


def visualize_gt_bboxes(image_path, gt_locations, gt_labels):
    """Draw ground-truth boxes and their labels on an image and show it."""
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    for gt_location, gt_label in zip(gt_locations, gt_labels):
        x, y, width, height = (int(v) for v in gt_location)
        image = cv2.rectangle(image, (x, y), (x + width, y + height),
                              color=(255, 0, 0), thickness=2)
        image = cv2.putText(image, gt_label, (x, y - 10),
                            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                            fontScale=3, color=(255, 0, 0), thickness=2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()


def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir):
    """Crop every ground-truth bbox into its own image file.

    Valid crops are written to *save_dir* together with a ``labels.txt``
    index (one ``path<TAB>label`` line per crop); rejected crops are
    dumped into ``unvalid_images`` for inspection.

    Args:
        image_paths: list of source image paths.
        image_labels: per-image lists of text labels, aligned with *bboxes*.
        bboxes: per-image lists of [x, y, width, height] boxes.
        save_dir: output directory for the cropped dataset.

    Returns:
        None.
    """
    os.makedirs(save_dir, exist_ok=True)
    unvalid_dir = 'unvalid_images'
    os.makedirs(unvalid_dir, exist_ok=True)

    bboxes_idx = 0
    unvalid_bboxes = 0
    new_labels = []  # accumulated "path\tlabel" lines
    for image_path, bbox, label in zip(image_paths, bboxes, image_labels):
        image = cv2.imread(image_path)
        # BUGFIX: check for a failed read *before* converting the colour
        # space — cv2.cvtColor(None, ...) raises, so the original guard
        # could never fire.
        if image is None:
            print(image_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for bb, lb in zip(bbox, label):
            x, y, width, height = (int(v) for v in bb)

            # BUGFIX: reject invalid coordinates before slicing —
            # negative values would silently wrap around via Python
            # slice semantics and produce a wrong crop.
            if x < 0 or y < 0 or width < 0 or height < 0:
                continue

            # Skip labels containing characters outside the target alphabet.
            if any(ch in lb.lower() for ch in ('é', 'ñ', '£')):
                continue

            cropped_text = image[y:y + height, x:x + width]
            # Guard: np.mean of an empty array is NaN (with a warning).
            if cropped_text.size == 0:
                continue

            # Filter out crops that are nearly all-black or all-white.
            if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230:
                cv2.imwrite(
                    os.path.join(unvalid_dir,
                                 f'unvalid_image{unvalid_bboxes}_{lb}.jpg'),
                    cropped_text)
                unvalid_bboxes += 1
                continue

            # Filter out crops that are too small to be legible.
            if width < 10 or height < 10:
                cv2.imwrite(
                    os.path.join(unvalid_dir,
                                 f'unvalid_image{unvalid_bboxes}_{lb}.jpg'),
                    cropped_text)
                unvalid_bboxes += 1
                continue

            new_image_path = os.path.join(save_dir,
                                          f'cropped_image{bboxes_idx}.jpg')
            cv2.imwrite(new_image_path, cropped_text)
            new_labels.append(new_image_path + '\t' + lb)
            bboxes_idx += 1

    # Persist the path<TAB>label index for the cropped dataset.
    with open(os.path.join(save_dir, 'labels.txt'), "w") as f:
        for new_label in new_labels:
            f.write(f'{new_label}\n')


def build_vocab(root_dir):
    """Build char<->index mappings from the cropped-dataset labels.

    Reads ``<root_dir>/ocr_dataset/labels.txt``, collects every character
    used in the labels plus the blank character ``'@'``, and pickles the
    two mappings to ``src/encode.pkl`` and ``src/decode.pkl``.

    Returns:
        ``(char_to_idx, idx_to_char)``; indices start at 1 so that 0 is
        free to act as the padding value (see ``encode``).
    """
    img_paths = []
    labels = []
    with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f:
        for line in f:
            path, text = line.strip().split("\t")
            img_paths.append(path)
            labels.append(text)

    # Collect every distinct character appearing in the labels.
    vocab = set()
    for label in labels:
        vocab.update(label)
    vocab = "".join(sorted(vocab))

    blank_char = '@'
    # NOTE(review): 'z' is appended unconditionally; if 'z' already occurs
    # in the labels it becomes a duplicate whose first index is shadowed in
    # char_to_idx — confirm this is intentional before changing it (saved
    # pickles depend on the resulting index layout).
    vocab = vocab + 'z' + blank_char

    # Index from 1: 0 is reserved as the padding value.
    char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}
    idx_to_char = {idx: char for char, idx in char_to_idx.items()}

    # BUGFIX: ensure the output directory exists before pickling.
    os.makedirs('src', exist_ok=True)
    with open('src/encode.pkl', "wb") as file:
        pickle.dump(char_to_idx, file)
    with open('src/decode.pkl', "wb") as file:
        pickle.dump(idx_to_char, file)

    return char_to_idx, idx_to_char


def get_imagepaths_and_labels(root_path):
    """Read ``<root_path>/ocr_dataset/labels.txt`` into parallel lists.

    Returns:
        ``(img_paths, labels)`` — one entry per ``path<TAB>label`` line.
    """
    img_paths = []
    labels = []
    with open(os.path.join(root_path, 'ocr_dataset', 'labels.txt'), "r") as f:
        for line in f:
            path, text = line.strip().split("\t")
            img_paths.append(path)
            labels.append(text)
    return img_paths, labels


def encode(label, char_to_idx, labels):
    """Encode *label* into a zero-padded tensor of character indices.

    Args:
        label: the text to encode.
        char_to_idx: mapping built by ``build_vocab`` (indices start at 1).
        labels: the full label list — used only to derive the pad length.

    Returns:
        ``(padded_label, length)``: the int32 index tensor right-padded
        with 0 up to the longest label length, and a scalar int32 tensor
        holding the true (unpadded) length.
    """
    max_length_label = np.max([len(lb) for lb in labels])
    encoded_label = torch.tensor(
        [char_to_idx[char] for char in label], dtype=torch.int32
    )
    label_len = len(encoded_label)
    length = torch.tensor(label_len, dtype=torch.int32)
    # Right-pad with 0, the reserved padding index, to a fixed width.
    padded_label = F.pad(encoded_label,
                         (0, max_length_label - label_len), value=0)
    return padded_label, length


def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'):
    """Decode an index tensor back to text.

    Stops at the first 0 (padding), collapses consecutive repeats and
    drops the blank character.
    """
    label = []
    encoded_label = encoded_label.detach().numpy()
    for i in range(len(encoded_label)):
        if encoded_label[i] == 0:
            # 0 is the padding value — everything after it is padding too.
            break
        if (i == 0 or encoded_label[i] != encoded_label[i - 1]) \
                and encoded_label[i] != char_to_idx[blank_char]:
            label.append(idx_to_char[encoded_label[i]])
    return "".join(label)


def main():
    """Build the cropped OCR dataset and its vocabulary from Dataset/words.xml."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default=os.getcwd(),
                        help="Path to the root directory")
    args = parser.parse_args()

    root_path = os.path.join(args.path, 'Dataset')
    image_paths, image_sizes, bboxes, image_labels = \
        extract_data_from_xml(root_path)

    # BUGFIX: derive save_dir from root_path so that --path values other
    # than the CWD still put the crops where build_vocab expects to read
    # them (previously hard-coded relative 'Dataset/ocr_dataset').
    save_dir = os.path.join(root_path, 'ocr_dataset')
    split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir)
    char_to_idx, idx_to_char = build_vocab(root_path)


if __name__ == '__main__':
    main()