Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import numpy as np | |
| import xml.etree.ElementTree as ET | |
| import cv2 | |
| import json | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| import argparse | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| def extract_data_from_xml(root_path): | |
| words_path = os.path.join(root_path, 'words.xml') | |
| tree = ET.parse(words_path) | |
| root = tree.getroot() | |
| image_paths = [] | |
| image_sizes = [] | |
| image_labels = [] | |
| bboxes = [] | |
| for image in root: | |
| imagename = image[0].text | |
| image_path = os.path.join(root_path, imagename) | |
| image_paths.append(image_path) | |
| image_height = image[1].get('x') | |
| image_width = image[1].get('y') | |
| image_sizes.append([image_height, image_width]) | |
| bboxes_in_image = [] | |
| labels_in_bboxes = [] | |
| for bbox in image[2]: | |
| x = float(bbox.get('x')) | |
| y = float(bbox.get('y')) | |
| width = float(bbox.get('width')) | |
| height = float(bbox.get('height')) | |
| bboxes_in_image.append([x, y, width, height]) | |
| # get text in this bbox | |
| labels = bbox.find('tag').text | |
| labels_in_bboxes.append(labels) | |
| bboxes.append(bboxes_in_image) | |
| image_labels.append(labels_in_bboxes) | |
| return image_paths, image_sizes, bboxes, image_labels | |
| def visualize_gt_bboxes(image_path, gt_locations, gt_labels): | |
| image = cv2.imread(image_path) | |
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| for gt_location, gt_label in zip(gt_locations, gt_labels): | |
| x, y, width, height = gt_location | |
| x, y, width, height = int(x), int(y), int(width), int(height) | |
| image = cv2.rectangle(image, (x, y), (x+width, y+height), color=(255, 0, 0), thickness=2) | |
| image = cv2.putText(image, gt_label, (x, y-10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale = 3, color=(255, 0, 0), thickness=2) | |
| plt.imshow(image) | |
| plt.axis('off') | |
| plt.show() | |
| def split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir): | |
| """create a new dataset contains bboxes and corresponding labels | |
| Args: | |
| image_paths | |
| image_labels | |
| bboxes | |
| save_dir | |
| Return: | |
| non-return | |
| """ | |
| os.makedirs(save_dir, exist_ok=True) | |
| os.makedirs('unvalid_images', exist_ok=True) | |
| bboxes_idx = 0 | |
| unvalid_bboxes = 0 | |
| new_labels = [] # List to store labels | |
| for image_path, bbox, label in zip(image_paths, bboxes, image_labels): | |
| image = cv2.imread(image_path) | |
| image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| if image is None: | |
| print(image_path) | |
| continue | |
| for bb, lb in zip(bbox, label): | |
| x, y, width, height = bb | |
| x, y, width, height = int(x), int(y), int(width), int(height) | |
| cropped_text = image[y:y+height, x:x+width] | |
| # Filter if x, y, width, height is invalid cordinates | |
| if x < 0 or y < 0 or width < 0 or height < 0: | |
| continue | |
| # Filter text contain special characters | |
| if 'é' in [lb[i].lower() for i in range(len(lb))] or 'ñ' in [lb[i].lower() for i in range(len(lb))] or '£' in [lb[i].lower() for i in range(len(lb))]: | |
| continue | |
| # Filter out if text is too light or too dark | |
| if np.mean(cropped_text) < 30 or np.mean(cropped_text) > 230: | |
| cv2.imwrite(f'unvalid_images\\unvalid_image{unvalid_bboxes}_{lb}.jpg', cropped_text) | |
| unvalid_bboxes += 1 | |
| continue | |
| # Filter out if image is too small | |
| if width < 10 or height < 10: | |
| cv2.imwrite(f'unvalid_images\\unvalid_image{unvalid_bboxes}_{lb}.jpg', cropped_text) | |
| unvalid_bboxes += 1 | |
| continue | |
| new_image_path = os.path.join(save_dir, f'cropped_image{bboxes_idx}.jpg') | |
| cv2.imwrite(new_image_path, cropped_text) | |
| new_label = new_image_path + '\t' + lb | |
| new_labels.append(new_label) | |
| bboxes_idx += 1 | |
| # Write labels into a text file | |
| with open(os.path.join(save_dir, 'labels.txt'), "w") as f: | |
| for new_label in new_labels: | |
| f.write(f'{new_label}\n') | |
| def build_vocab(root_dir): | |
| img_paths = [] | |
| labels = [] | |
| # Read labels from text file | |
| with open(os.path.join(root_dir, 'ocr_dataset', 'labels.txt'), "r") as f: | |
| for label in f: | |
| labels.append(label.strip().split("\t")[1]) | |
| img_paths.append(label.strip().split("\t")[0]) | |
| # build the vocab | |
| vocab = set() | |
| for label in labels: | |
| for i in range(len(label)): | |
| vocab.add(label[i]) | |
| # "blank" character | |
| vocab = list(sorted(vocab)) | |
| vocab = "".join(vocab) | |
| blank_char = '@' | |
| vocab = vocab + 'z' | |
| vocab = vocab + blank_char | |
| # build a dictionary convert from vocab to idx and idx to vocab | |
| char_to_idx = { | |
| char: idx + 1 for idx, char in enumerate(vocab) | |
| } | |
| idx_to_char = { | |
| idx: char for char, idx in char_to_idx.items() | |
| } | |
| # save | |
| with open('src/encode.pkl', "wb") as file: | |
| pickle.dump(char_to_idx, file) | |
| with open('src/decode.pkl', "wb") as file: | |
| pickle.dump(idx_to_char, file) | |
| return char_to_idx, idx_to_char | |
| def get_imagepaths_and_labels(root_path): | |
| img_paths = [] | |
| labels = [] | |
| # Read labels from text file | |
| with open(os.path.join(root_path, 'ocr_dataset', 'labels.txt'), "r") as f: | |
| for label in f: | |
| labels.append(label.strip().split("\t")[1]) | |
| img_paths.append(label.strip().split("\t")[0]) | |
| return img_paths, labels | |
| def encode(label, char_to_idx, labels): | |
| max_length_label = np.max([len(lb) for lb in labels]) | |
| # encode label | |
| encoded_label = torch.tensor( | |
| [char_to_idx[char] for char in label], | |
| dtype=torch.int32 | |
| ) | |
| label_len = len(encoded_label) | |
| length = torch.tensor( | |
| label_len, | |
| dtype=torch.int32 | |
| ) | |
| padded_label = F.pad( | |
| encoded_label, | |
| (0, max_length_label-label_len), | |
| value=0 | |
| ) | |
| return padded_label, length | |
| def decode(encoded_label, idx_to_char, char_to_idx, blank_char='@'): | |
| label = [] | |
| encoded_label = encoded_label.detach().numpy() | |
| for i in range(len(encoded_label)): | |
| if encoded_label[i] == 0: | |
| break | |
| elif (i == 0 or encoded_label[i] != encoded_label[i-1]) and encoded_label[i] != char_to_idx[blank_char]: | |
| label.append(idx_to_char[encoded_label[i]]) | |
| label = "".join(label) | |
| return label | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--path", type=str, default=os.getcwd(), help="Path to the root directory") | |
| args = parser.parse_args() | |
| root_path = os.path.join(args.path, 'Dataset') | |
| image_paths, image_sizes, bboxes, image_labels = extract_data_from_xml(root_path) | |
| save_dir = 'Dataset/ocr_dataset' | |
| split_bboxes_from_image(image_paths, image_labels, bboxes, save_dir) | |
| char_to_idx, idx_to_char = build_vocab(root_path) | |
| if __name__ == '__main__': | |
| main() |