Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import xml.etree.ElementTree as ET | |
| import shutil | |
| import yaml | |
| from sklearn.model_selection import train_test_split | |
# Annotation XML shipped with the dataset; parsed once when the script runs.
location_path = 'Dataset/locations.xml'
tree = ET.parse(source=location_path)
# Top-level element; each child is one image record.
root = tree.getroot()
def get_gt_bboxes(location_path):
    """Collect ground-truth text bounding boxes from the annotation XML.

    Each child of the XML root is treated as one image record whose first
    three sub-elements are, in order: an element whose text is the image file
    name, an element carrying the resolution in its ``x``/``y`` attributes, and
    a container of box elements with ``x``/``y``/``width``/``height``
    attributes.

    Args:
        location_path (str): path to the annotation XML file. Image names found
            in the XML are resolved relative to the directory containing it.

    Returns:
        tuple: ``(gt_imagepaths, gt_imagesizes, gt_locations)`` where
            gt_imagepaths (list[str]): path of each image,
            gt_imagesizes (list[list[str]]): ``[height, width]`` per image, as
                the raw attribute strings read from the XML,
            gt_locations (list[list[list[np.int64]]]): per image, a list of
                ``[x, y, width, height]`` boxes truncated to integers.
    """
    # Parse the file we were given instead of reading a module-level tree, so
    # the argument is actually honoured.
    xml_root = ET.parse(location_path).getroot()
    # Image names are relative to the folder holding the annotation file
    # (for 'Dataset/locations.xml' this is 'Dataset').
    dataset_dir = os.path.dirname(location_path)

    gt_imagepaths = []
    gt_imagesizes = []
    gt_locations = []
    for image in xml_root:
        name_elem, size_elem, boxes_elem = image[0], image[1], image[2]
        gt_imagepaths.append(os.path.join(dataset_dir, name_elem.text))
        # Stored as [height, width] to match (rows, cols) image-shape order.
        gt_imagesizes.append([size_elem.get('y'), size_elem.get('x')])
        bbs = []
        for bbox in boxes_elem:
            # Attributes may be fractional strings; go through float first.
            bbs.append([np.int64(float(bbox.get(key)))
                        for key in ('x', 'y', 'width', 'height')])
        gt_locations.append(bbs)
    return gt_imagepaths, gt_imagesizes, gt_locations
# Read every image path, resolution and ground-truth box from the annotations.
gt_imagepaths, gt_imagesizes, gt_locations = get_gt_bboxes(location_path=location_path)
def visualize_gt_bboxes(image_path, gt_locations):
    """Show an image with its ground-truth boxes drawn in red.

    Args:
        image_path (str): path of the image to load with OpenCV.
        gt_locations (iterable): boxes as ``[x, y, width, height]`` in pixels,
            with ``(x, y)`` the top-left corner.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising; fail
    # clearly here instead of with an opaque error inside cvtColor.
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    # OpenCV loads BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    for gt_location in gt_locations:
        # Coordinates may be numpy integers; OpenCV wants plain ints.
        x, y, width, height = (int(v) for v in gt_location)
        image = cv2.rectangle(image, (x, y), (x + width, y + height),
                              color=(255, 0, 0), thickness=2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Show an image with YOLO-format boxes drawn in red.

    Args:
        image_path (str): path of the image to load with OpenCV.
        gt_location_yolo (iterable): rows ``[class_id, xc, yc, w, h]`` where the
            box centre and size are fractions of the image width/height; the
            class id is ignored.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')
    # OpenCV loads BGR; matplotlib expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_height, image_width = image.shape[:2]
    for data in gt_location_yolo:
        xc, yc, w, h = data[1:]
        # Convert normalized centre/size back to pixel corner coordinates.
        xmin = int((xc - w / 2) * image_width)
        ymin = int((yc - h / 2) * image_height)
        xmax = int((xc + w / 2) * image_width)
        ymax = int((yc + h / 2) * image_height)
        image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax),
                              color=(255, 0, 0), thickness=2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
def convert_yolo_format(gt_locations, gt_imagesizes):
    """Convert pixel boxes into normalized YOLO annotation rows.

    Args:
        gt_locations (list): per image, a list of ``[x, y, w, h]`` pixel boxes
            with ``(x, y)`` the top-left corner.
        gt_imagesizes (list): per image, ``[height, width]``; each value may be
            anything ``float()`` accepts (e.g. raw XML attribute strings).

    Returns:
        list: per image, a list of ``[class_id, xc, yc, width, height]`` rows in
        which the centre and size are divided by the image width/height and
        ``class_id`` is always 0 (the single text class).
    """
    def _box_to_yolo(box, size):
        left, top, box_w, box_h = box
        img_h, img_w = size
        centre_x = (left + box_w / 2) / float(img_w)
        centre_y = (top + box_h / 2) / float(img_h)
        # Class 0 means "contains text".
        return [0, centre_x, centre_y, box_w / float(img_w), box_h / float(img_h)]

    return [
        [_box_to_yolo(box, size) for box in boxes]
        for boxes, size in zip(gt_locations, gt_imagesizes)
    ]
# Normalize every ground-truth box into YOLO's [class, xc, yc, w, h] layout.
gt_locations_yolo = convert_yolo_format(gt_locations=gt_locations, gt_imagesizes=gt_imagesizes)
def save_data_into_yolo_folder(data, src_img_dir, save_dir):
    """Copy images and write label files into a YOLO directory layout.

    Produces ``save_dir/images/<name>`` (a copy of each image) and
    ``save_dir/labels/<stem>.txt`` holding one space-separated label row per
    line.

    Args:
        data (iterable): items whose element 0 is an image path and element 1
            is a list of label rows (e.g. ``[class_id, xc, yc, w, h]``).
        src_img_dir: accepted for interface compatibility; not used.
        save_dir (str): destination split directory, created if missing.

    NOTE: images are copied into one flat folder by base name, so two inputs
    sharing a file name overwrite each other's image and label file.
    """
    images_dir = os.path.join(save_dir, 'images')
    labels_dir = os.path.join(save_dir, 'labels')
    for target_dir in (save_dir, images_dir, labels_dir):
        os.makedirs(target_dir, exist_ok=True)

    for entry in data:
        img_path, rows = entry[0], entry[1]
        shutil.copy(img_path, images_dir)
        stem = os.path.splitext(os.path.basename(img_path))[0]
        with open(os.path.join(labels_dir, f'{stem}.txt'), "w") as out_file:
            out_file.writelines(" ".join(map(str, row)) + '\n' for row in rows)
# RNG seed shared by both splits so the partition is reproducible.
seed = 42
# 20% of all samples go to validation; then 12.5% of the remaining 80%
# (i.e. 10% overall) goes to test, leaving 70% for training.
val_size = 0.2
test_size = 0.125
dataset = [[gt_imagepath, gt_location_yolo]
           for gt_imagepath, gt_location_yolo in zip(gt_imagepaths, gt_locations_yolo)]
train_data, val_data = train_test_split(dataset, test_size=val_size,
                                        random_state=seed, shuffle=True)
train_data, test_data = train_test_split(train_data, test_size=test_size,
                                         random_state=seed, shuffle=True)
# Write each split to yolo_data/<split>/{images,labels}, in train/val/test order.
save_yolo_data_dir = 'yolo_data'
os.makedirs(save_yolo_data_dir, exist_ok=True)
for split_name, split_data in (('train', train_data),
                               ('val', val_data),
                               ('test', test_data)):
    save_data_into_yolo_folder(
        data=split_data,
        src_img_dir=save_yolo_data_dir,
        save_dir=os.path.join(save_yolo_data_dir, split_name),
    )
# Dataset config for the YOLO trainer, written to yolo_data/data.yaml.
class_label = ['text']
data_yaml = {
    # Absolute dataset root. How a relative root would be resolved depends on
    # the training tool's settings, not on this script's working directory.
    'path': os.path.abspath(save_yolo_data_dir),
    'train': 'train/images',
    'test': 'test/images',
    'val': 'val/images',
    # Derived from the label list so the two can never disagree.
    'nc': len(class_label),
    'names': class_label,
}
yolo_yaml_path = os.path.join(save_yolo_data_dir, 'data.yaml')
with open(yolo_yaml_path, "w") as f:
    yaml.dump(data_yaml, f, default_flow_style=False)