Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import logging | |
| import os | |
| import imghdr | |
| import cv2 | |
| import random | |
| import numpy as np | |
| import paddle | |
| import importlib.util | |
| import sys | |
| import subprocess | |
def print_dict(d, logger, delimiter=0):
    """Log the contents of a (possibly nested) dict, one key per line.

    Keys are visited in sorted order of ``d.items()``. When a value is itself a
    dict, or a non-empty list whose first element is a dict, a header line
    ``"<key> : "`` is logged. Its contents are then logged recursively,
    indented four more spaces. Every other value is logged as
    ``"<key> : <value>"``.

    Args:
        d: Mapping to log. Its keys must be mutually comparable for sorting.
        logger: Object with an ``info(str)`` method, e.g. a ``logging.Logger``.
        delimiter: Number of leading spaces used at this nesting level.
    """
    indent = delimiter * " "
    for key, value in sorted(d.items()):
        if isinstance(value, dict):
            children = [value]
        elif isinstance(value, list) and len(value) >= 1 and isinstance(value[0], dict):
            # Treat a list of dicts as a sequence of nested sections.
            children = value
        else:
            logger.info("{}{} : {}".format(indent, key, value))
            continue

        logger.info("{}{} : ".format(indent, str(key)))
        for child in children:
            print_dict(child, logger, delimiter + 4)
def get_check_global_params(mode):
    """Return the names of global config params that must be present for ``mode``.

    Args:
        mode: ``"train_eval"`` adds both per-card batch sizes. ``"test"`` adds
            only the test batch size. Any other value returns just the common
            parameters.

    Returns:
        A new list of parameter-name strings, each listed once.
    """
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
        'loss_type'
    ]
    if mode == "train_eval":
        check_params += [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    elif mode == "test":
        check_params.append('test_batch_size_per_card')
    return check_params
| def _check_image_file(path): | |
| img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'} | |
| return any([path.lower().endswith(e) for e in img_end]) | |
def get_image_file_list(img_file):
    """Collect the image files referred to by ``img_file``.

    ``img_file`` may name a single file or a directory:

    * A file is accepted when ``_check_image_file`` recognises its name.
    * For a directory, only its direct (non-recursive) entries are collected,
      and only those that are regular files with a recognised name.

    Args:
        img_file: Path to a file or directory.

    Returns:
        Sorted list of image paths.

    Raises:
        Exception: if ``img_file`` is None, does not exist, or yields no images.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    if os.path.isfile(img_file) and _check_image_file(img_file):
        found = [img_file]
    elif os.path.isdir(img_file):
        entries = (os.path.join(img_file, name) for name in os.listdir(img_file))
        found = [p for p in entries if os.path.isfile(p) and _check_image_file(p)]
    else:
        found = []

    if not found:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(found)
def binarize_img(img):
    """Binarize a 3-channel image using Otsu thresholding.

    A 3-D image with exactly 3 channels (treated as BGR by the
    ``COLOR_BGR2GRAY`` conversion) goes through three steps:

    1. It is converted to grayscale.
    2. It is thresholded with ``THRESH_BINARY + THRESH_OTSU`` (max value 255).
    3. It is expanded back to a 3-channel BGR image, so the output keeps a
       3-channel layout.

    Any other input is returned unchanged (the same object).
    """
    is_three_channel = len(img.shape) == 3 and img.shape[2] == 3
    if not is_three_channel:
        return img
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU, OpenCV chooses the threshold itself (see cv2.threshold docs).
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
def alpha_to_color(img, alpha_color=(255, 255, 255)):
    """Flatten a 4-channel image onto a solid background colour.

    Args:
        img: Image array. Only a 3-D array with 4 channels is processed. Its
            channels are unpacked in B, G, R, A order via ``cv2.split``.
        alpha_color: Background colour. Index 0 is blended into the R channel,
            index 1 into G and index 2 into B, i.e. an (R, G, B) triple.

    Returns:
        A 3-channel (B, G, R) uint8 image. Each channel is computed as
        ``background * (1 - A/255) + channel * A/255``, then cast to uint8,
        which truncates the fraction. If ``img`` is not 4-channel, it is
        returned unchanged.
    """
    if not (len(img.shape) == 3 and img.shape[2] == 4):
        return img

    blue, green, red, a = cv2.split(img)
    opacity = a / 255
    red = (alpha_color[0] * (1 - opacity) + red * opacity).astype(np.uint8)
    green = (alpha_color[1] * (1 - opacity) + green * opacity).astype(np.uint8)
    blue = (alpha_color[2] * (1 - opacity) + blue * opacity).astype(np.uint8)
    return cv2.merge((blue, green, red))
def _read_gif_first_frame(img_path):
    """Decode the first frame of the GIF at ``img_path``.

    Returns:
        The frame with its channel axis reversed (presumably BGR -> RGB; TODO
        confirm against callers), or None if no frame could be read.
    """
    gif = cv2.VideoCapture(img_path)
    try:
        ret, frame = gif.read()
    finally:
        # Free the decoder handle on every path, including read failures.
        gif.release()
    if not ret:
        return None
    # Promote single-channel frames to 3 channels before flipping channel order.
    if len(frame.shape) == 2 or frame.shape[-1] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    return frame[:, :, ::-1]


def _render_pdf_pages(img_path):
    """Rasterize every page of the PDF at ``img_path`` into a BGR uint8 array.

    Each page is rendered at 2x zoom. If the 2x render is wider or taller than
    2000 pixels, the page is re-rendered at 1x instead.
    """
    # Imported lazily: only needed when a PDF is actually read.
    import fitz
    from PIL import Image

    zoom_2x = fitz.Matrix(2, 2)
    zoom_1x = fitz.Matrix(1, 1)
    imgs = []
    with fitz.open(img_path) as pdf:
        for pg in range(pdf.page_count):
            page = pdf[pg]
            pm = page.get_pixmap(matrix=zoom_2x, alpha=False)
            # If width or height > 2000 pixels, don't enlarge the image.
            if pm.width > 2000 or pm.height > 2000:
                pm = page.get_pixmap(matrix=zoom_1x, alpha=False)
            img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
            imgs.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
    return imgs


def check_and_read(img_path):
    """Load a GIF (first frame) or a PDF (all pages), chosen by file name.

    The format is decided by the last three characters of the base name,
    compared case-insensitively.

    Args:
        img_path: Path to the input file.

    Returns:
        Always a 3-tuple ``(img, flag_gif, flag_pdf)``:

        * ``...gif``: ``(first_frame, True, False)``. If no frame can be
          decoded, a message is logged and ``(None, False, False)`` is returned.
        * ``...pdf``: ``(list_of_page_images, False, True)``.
        * anything else: ``(None, False, False)``.
    """
    ext = os.path.basename(img_path)[-3:].lower()
    if ext == 'gif':
        frame = _read_gif_first_frame(img_path)
        if frame is None:
            logger = logging.getLogger('ppocr')
            logger.info("Cannot read %s. This gif image maybe corrupted.",
                        img_path)
            # Return three values like every other branch, so callers that
            # unpack (img, flag_gif, flag_pdf) do not raise ValueError.
            return None, False, False
        return frame, True, False
    if ext == 'pdf':
        return _render_pdf_pages(img_path), False, True
    return None, False, False
def load_vqa_bio_label_maps(label_map_path):
    """Build BIO-tagging label maps from a file of entity class names.

    The file (UTF-8) lists one class name per line. Label id 0 is always
    ``"O"``. Some lines do not produce labels:

    * Lines reading OTHER / OTHERS / IGNORE (case-insensitive) are folded
      into ``"O"``.
    * Blank lines are skipped.

    Every remaining class ``c`` gets two consecutive ids, for ``"B-c"`` and
    ``"I-c"``, in file order. All labels are upper-cased in both maps.

    Args:
        label_map_path: Path to the class-name file.

    Returns:
        ``(label2id_map, id2label_map)``: dicts label -> id and id -> label.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        class_names = [line.strip() for line in fin]

    labels = ["O"]
    for name in class_names:
        # Skip blank lines (they would otherwise create bogus "B-"/"I-"
        # labels). "O" already covers the other/ignore classes.
        if not name or name.upper() in ("OTHER", "OTHERS", "IGNORE"):
            continue
        labels.append("B-" + name)
        labels.append("I-" + name)

    label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
    return label2id_map, id2label_map
def set_seed(seed=1024):
    """Seed Python's ``random`` module, NumPy's global RNG and Paddle with ``seed``.

    The generators are seeded in that order.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
def check_install(module_name, install_name):
    """Ensure ``module_name`` is importable, pip-installing ``install_name`` if not.

    Presence is checked with ``importlib.util.find_spec``. When the module is
    missing, ``pip install <install_name>`` is run in a subprocess using the
    current interpreter, with its stdout discarded.

    Args:
        module_name: Import name to look up.
        install_name: Distribution name passed to ``pip install``.

    Raises:
        Exception: if the ``pip install`` subprocess exits with a non-zero
            status (the ``CalledProcessError`` is chained as the cause).
    """
    spec = importlib.util.find_spec(module_name)
    if spec is not None:
        print(f"{module_name} has been installed.")
        return

    print(f'Warnning! The {module_name} module is NOT installed')
    print(
        f'Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}.'
    )
    python = sys.executable
    try:
        subprocess.check_call(
            [python, '-m', 'pip', 'install', install_name],
            stdout=subprocess.DEVNULL)
    except subprocess.CalledProcessError as exc:
        raise Exception(
            f"Install {module_name} failed, please install manually") from exc
    # Path-based finders cache directory contents; drop those caches so a
    # later import in this same process can see the freshly installed package.
    importlib.invalidate_caches()
    print(f'The {module_name} module is now installed')
class AverageMeter:
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean.

        Raises:
            ZeroDivisionError: if the accumulated count is 0 after the update,
                e.g. when the first call uses ``n=0``.
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count