Spaces:
Build error
Build error
| import os | |
| import json | |
| import random | |
| from PIL import Image, ImageDraw | |
| from pdf_extract_kit.registry.registry import TASK_REGISTRY | |
| from pdf_extract_kit.utils.data_preprocess import load_pdf | |
| from pdf_extract_kit.tasks.base_task import BaseTask | |
class OCRTask(BaseTask):
    """OCR task: text detection and recognition on images and PDF files.

    Wraps a model exposing ``predict`` and adds batch processing over a file
    or a directory, with optional JSON export and bbox visualization.

    NOTE(review): ``predict_image`` reads ``self.model``, so ``BaseTask`` is
    presumed to store the constructor argument under that name -- confirm.
    """

    def __init__(self, model):
        """Initialize the task with the given model.

        Args:
            model: task model; must provide a ``predict`` function.
        """
        super().__init__(model)

    def predict_image(self, image):
        """Predict on one image and return text detection/recognition results.

        Args:
            image: PIL.Image.Image (if ``model.predict`` supports other input
                types, remember to add the format conversion inside
                ``model.predict``).

        Returns:
            List[dict]: list of text bboxes with their content.

        Return example:
            [
                {
                    "category_type": "text",
                    "poly": [
                        380.6792698635707,
                        159.85058512958923,
                        765.1419999999998,
                        159.85058512958923,
                        765.1419999999998,
                        192.51073013642917,
                        380.6792698635707,
                        192.51073013642917
                    ],
                    "text": "this is an example text",
                    "score": 0.97
                },
                ...
            ]
        """
        return self.model.predict(image)

    def prepare_input_files(self, input_path):
        """Collect the files to process.

        Args:
            input_path: path to a single file, or to a directory whose direct
                children are processed (no recursion).

        Returns:
            List[str]: for a directory, the paths of its regular files in
            sorted order (sub-directories are skipped); otherwise a
            one-element list holding ``input_path`` itself.
        """
        if not os.path.isdir(input_path):
            return [input_path]
        # os.listdir order is arbitrary; sort so result order is reproducible.
        candidates = (
            os.path.join(input_path, fname)
            for fname in sorted(os.listdir(input_path))
        )
        # Sub-directories can be opened neither as an image nor as a PDF.
        return [path for path in candidates if os.path.isfile(path)]

    def process(self, input_path, save_dir=None, visualize=False):
        """Run OCR on one file or on every file of a directory.

        Files whose path ends with ``.pdf`` (case-insensitive) are processed
        page by page; every other file is opened as an image.

        Args:
            input_path: file or directory path, see ``prepare_input_files``.
            save_dir: if set, results are written below this directory.
                A PDF gets a sub-directory named after the file with one
                ``page_<N>.json`` per page (N starts at 1); an image gets
                ``<name>.json`` directly in ``save_dir``.
            visualize: if True and ``save_dir`` is set, also save the image
                with the detected boxes drawn on it (``page_<N>.jpg`` for PDF
                pages, ``<name>.png`` for images). Ignored without
                ``save_dir`` because there is nowhere to write to.

        Returns:
            list: one entry per input file, in the order returned by
            ``prepare_input_files``. For a PDF the entry is a list of per-page
            results; for an image it is the result of ``predict_image``.
        """
        res_list = []
        for fpath in self.prepare_input_files(input_path):
            # splitext handles extensions of any length (".jpeg", ".tiff")
            # and files without one, unlike slicing off the last 4 chars.
            basename = os.path.splitext(os.path.basename(fpath))[0]
            if fpath.lower().endswith(".pdf"):
                file_res = self._process_pdf(fpath, basename, save_dir, visualize)
            else:
                file_res = self._process_image(fpath, basename, save_dir, visualize)
            res_list.append(file_res)
        return res_list

    def _process_pdf(self, fpath, basename, save_dir, visualize):
        """OCR every page of one PDF, optionally saving per-page outputs."""
        images = load_pdf(fpath)
        out_dir = None
        if save_dir:
            # Created once per PDF rather than once per page.
            out_dir = os.path.join(save_dir, basename)
            os.makedirs(out_dir, exist_ok=True)

        pdf_res = []
        for page_no, img in enumerate(images, start=1):
            page_res = self.predict_image(img)
            pdf_res.append(page_res)
            if out_dir is None:
                continue
            self.save_json_result(page_res, os.path.join(out_dir, f"page_{page_no}.json"))
            if visualize:
                self.visualize_image(img, page_res, os.path.join(out_dir, f"page_{page_no}.jpg"))
        return pdf_res

    def _process_image(self, fpath, basename, save_dir, visualize):
        """OCR one image file, optionally saving its outputs."""
        # The context manager releases the file handle held by Image.open.
        with Image.open(fpath) as image:
            img_res = self.predict_image(image)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                self.save_json_result(img_res, os.path.join(save_dir, f"{basename}.json"))
                if visualize:
                    self.visualize_image(image, img_res, os.path.join(save_dir, f"{basename}.png"))
        return img_res

    def visualize_image(self, image, ocr_res, save_path="", cate2color=None):
        """Plot each result's bbox and category on the image.

        The boxes are drawn on ``image`` itself (in place), so the caller sees
        the annotations even when ``save_path`` is empty.

        Args:
            image: PIL.Image.Image
            ocr_res: list of OCR detection/recognition results, in the format
                returned by ``self.predict_image``.
            save_path: path to save the visualized image; nothing is saved
                when empty.
            cate2color: optional mapping from ``category_type`` to an outline
                color; categories not listed are drawn in green.
        """
        if cate2color is None:
            cate2color = {}
        draw = ImageDraw.Draw(image)
        for res in ocr_res:
            category = res['category_type']
            poly = res['poly']
            box_color = cate2color.get(category, (0, 255, 0))
            # poly is a flat [x0, y0, x1, y1, ...] list. Taking min/max over
            # all points gives the enclosing box whatever the point order,
            # so the rectangle is never inverted.
            xs, ys = poly[0::2], poly[1::2]
            x_min, y_min = int(min(xs)), int(min(ys))
            x_max, y_max = int(max(xs)), int(max(ys))
            draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1)
            draw.text((x_min, y_min), category, (255, 0, 0))
        if save_path:
            image.save(save_path)

    def save_json_result(self, ocr_res, save_path):
        """Save results to a JSON file.

        Args:
            ocr_res: list of OCR detection/recognition results, in the format
                returned by ``self.predict_image``.
            save_path: path of the JSON file to write (UTF-8, non-ASCII
                characters are kept as-is, 2-space indent).
        """
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(ocr_res, f, indent=2, ensure_ascii=False)