"""Comic text detection + Japanese OCR inference utilities.

Runs a text detector (``TextDetector``, torch or OpenCV-DNN/ONNX backend),
OCRs each detected block with MangaOcr, and writes annotation files
(YOLO labels, line polygons, optional JSON, refined mask).
"""
import json
from basemodel import TextDetBase, TextDetBaseDNN
import os.path as osp
from tqdm import tqdm
import numpy as np
import cv2
import torch
from pathlib import Path
from utils.yolov5_utils import non_max_suppression
from utils.db_utils import SegDetectorRepresenter
from utils.io_utils import imread, imwrite, find_all_imgs, NumpyEncoder
from utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings
from utils.textblock import TextBlock, group_output, visualize_textblocks
from utils.textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
from typing import Union
from manga_ocr import MangaOcr
from PIL import Image


def init_model(model_path, device=None):
    """Build a ``TextDetector`` for ``model_path``.

    Args:
        model_path: path to the detector weights; a ``.onnx`` suffix selects
            the OpenCV-DNN backend, anything else the torch backend.
        device: device to run on (e.g. ``'cpu'``, ``'cuda'``). ``None`` picks
            CUDA when available, otherwise CPU. A CUDA device requested on a
            machine without CUDA falls back to the auto-detected device (the
            previous behavior was to always auto-detect).

    Returns:
        A ``TextDetector`` with a 1024px input size and leaky activation.
    """
    cuda_available = torch.cuda.is_available()
    if device is None or (str(device).startswith('cuda') and not cuda_available):
        device = 'cuda' if cuda_available else 'cpu'
    return TextDetector(model_path=model_path, input_size=1024, device=device, act='leaky')


def model2annotations(img_dir_list, save_dir, save_json=False, model=None, mocr=None):
    """Detect and OCR text in every image under ``img_dir_list``, writing annotations.

    For each image ``<name>.<ext>`` found, the following are written to
    ``save_dir`` (created if missing):

    * ``<name>.txt``       - YOLO-format block labels (class 1 for every block)
    * ``line-<name>.txt``  - text-line polygons, one row of 8 ints per line
                             (only written when at least one line was found)
    * ``<name>.json``      - block dicts incl. OCR ``'text'`` (if ``save_json``)
    * ``<name>.<ext>``     - a copy of the input image
    * ``mask-<name>.png``  - the refined text mask

    Args:
        img_dir_list: one directory path or a list of directory paths.
        save_dir: output directory.
        save_json: also write the per-image JSON block list.
        model: a callable detector, e.g. from ``init_model()``. Required.
        mocr: optional ``MangaOcr`` instance to reuse across calls; a new one
            is constructed when omitted (loading it is expensive).

    Returns:
        A JSON string: a list with one entry per image, each entry being the
        list of block dicts (``TextBlock.to_dict()`` plus ``'text'``).

    Raises:
        TypeError: if ``model`` is not supplied.
    """
    if model is None:
        raise TypeError('model2annotations() requires a detector via `model=`; build one with init_model().')
    if mocr is None:
        mocr = MangaOcr()
    if isinstance(img_dir_list, str):
        img_dir_list = [img_dir_list]
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    imglist = []
    for img_dir in img_dir_list:
        imglist += find_all_imgs(img_dir, abs_path=True)

    result = []
    for img_path in tqdm(imglist):
        imgname = osp.basename(img_path)
        # Path.stem only strips the final suffix; str.replace(suffix, '') would
        # also remove the suffix text wherever else it occurs in the name.
        imname = Path(imgname).stem
        img = imread(img_path)
        im_h, im_w = img.shape[:2]
        maskname = 'mask-' + imname + '.png'
        poly_save_path = osp.join(save_dir, 'line-' + imname + '.txt')

        mask, mask_refined, blk_list = model(img, refine_mode=REFINEMASK_ANNOTATION, keep_undetected_mask=True)

        polys = []
        blk_xyxy = []
        blk_dict_list = []
        # Open the image once per file (not once per block) and close it afterwards.
        with Image.open(img_path) as pil_img:
            for blk in blk_list:
                polys += blk.lines
                blk_xyxy.append(blk.xyxy)
                ocr_text = mocr(pil_img.crop(blk.xyxy))
                blk_dict = blk.to_dict()
                blk_dict['text'] = ocr_text
                blk_dict_list.append(blk_dict)

        blk_xyxy = xyxy2yolo(blk_xyxy, im_w, im_h)
        if blk_xyxy is not None:
            cls_list = [1] * len(blk_xyxy)
            yolo_label = get_yololabel_strings(cls_list, blk_xyxy)
        else:
            yolo_label = ''
        with open(osp.join(save_dir, imname + '.txt'), 'w', encoding='utf8') as f:
            f.write(yolo_label)

        if len(polys) != 0:
            # Each line presumably holds 4 (x, y) corner points -> 8 values per row.
            polys = np.array(polys).reshape(-1, 8)
            np.savetxt(poly_save_path, polys, fmt='%d')
        if save_json:
            with open(osp.join(save_dir, imname + '.json'), 'w', encoding='utf8') as f:
                f.write(json.dumps(blk_dict_list, ensure_ascii=False, cls=NumpyEncoder))
        imwrite(osp.join(save_dir, imgname), img)
        imwrite(osp.join(save_dir, maskname), mask_refined)
        result.append(blk_dict_list)
    return json.dumps(result, ensure_ascii=False, cls=NumpyEncoder)


def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
    """Letterbox ``img`` to ``input_size`` and optionally convert it to a model tensor.

    Args:
        img: HxWx3 image array (BGR when ``bgr2rgb`` is True).
        input_size: target letterbox size.
        device: torch device for the returned tensor (``to_tensor`` only).
        bgr2rgb: convert BGR -> RGB before letterboxing.
        half: cast the tensor to float16 (``to_tensor`` only).
        to_tensor: when True, return a 1xCxHxW float tensor scaled to [0, 1];
            when False, return the letterboxed HxWxC array unchanged in dtype
            (used by the OpenCV-DNN backend).

    Returns:
        ``(img_in, ratio, dw, dh)`` where ``ratio`` comes from ``letterbox``
        and ``dw``/``dh`` are its padding values cast to int.
    """
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
    if to_tensor:
        # HWC -> CHW, then reverse the channel axis.
        # NOTE(review): with bgr2rgb=True the cvtColor above already produced
        # RGB, so this [::-1] flips it back to BGR. Kept as-is because the
        # trained weights may depend on it — confirm against training code.
        img_in = img_in.transpose((2, 0, 1))[::-1]
        img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
        img_in = torch.from_numpy(img_in).to(device)
        if half:
            img_in = img_in.half()
    return img_in, ratio, int(dw), int(dh)


def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
    """Squeeze a mask to a numpy array and scale it to uint8.

    Args:
        img: mask tensor/array; singleton dimensions are squeezed away.
            Presumably holds values in [0, 1] — values above 1 would wrap
            when cast to uint8.
        thresh: if given, binarize with ``img > thresh`` before scaling.

    Returns:
        ``uint8`` numpy array equal to ``img * 255``.
    """
    if isinstance(img, torch.Tensor):
        # Non in-place ops so the caller's tensor is not mutated; .cpu() is a
        # no-op for tensors already on the CPU.
        img = img.detach().squeeze().cpu().numpy()
    else:
        img = img.squeeze()
    if thresh is not None:
        img = img > thresh
    img = img * 255
    return img.astype(np.uint8)


def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
    """Run NMS on raw YOLO output and rescale boxes to the original image.

    Args:
        det: raw detector output passed to ``non_max_suppression``.
        conf_thresh: confidence threshold for NMS.
        nms_thresh: IoU threshold for NMS.
        resize_ratio: ``(x_scale, y_scale)`` multiplied into box x/y coords.
        sort_func: optional callable that reorders the detections array.

    Returns:
        ``(blines, cls, confs)``: int32 xyxy boxes, int32 class ids, and
        confidences rounded to 3 decimals.
    """
    det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
    # Comparing torch.device to a str is not a reliable device test; convert
    # any tensor (CPU or GPU) to numpy explicitly.
    if isinstance(det, torch.Tensor):
        det = det.detach().cpu().numpy()
    det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
    det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
    if sort_func is not None:
        det = sort_func(det)

    blines = det[..., 0:4].astype(np.int32)
    confs = np.round(det[..., 4], 3)
    cls = det[..., 5].astype(np.int32)
    return blines, cls, confs


class TextDetector:
    """Text block / text line / text mask detector with torch or OpenCV-DNN backend."""

    lang_list = ['eng', 'ja', 'unknown']
    langcls2idx = {'eng': 0, 'ja': 1, 'unknown': 2}

    def __init__(self, model_path, input_size=1024, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4, mask_thresh=0.3, act='leaky'):
        """Load the detector.

        Args:
            model_path: ``.onnx`` -> OpenCV-DNN backend, otherwise torch backend.
            input_size: int (square) or ``(w, h)`` network input size.
            device: torch device (torch backend / input tensor placement).
            half: feed float16 input tensors (torch backend).
            nms_thresh: IoU threshold for block NMS.
            conf_thresh: confidence threshold for block NMS.
            mask_thresh: stored on the instance; not used by ``__call__``.
            act: activation name forwarded to ``TextDetBase``.
        """
        super().__init__()
        if Path(model_path).suffix == '.onnx':
            # NOTE(review): this loads the ONNX file a second time; self.model
            # is not used in this module — kept in case external code reads it.
            self.model = cv2.dnn.readNetFromONNX(model_path)
            self.net = TextDetBaseDNN(input_size, model_path)
            self.backend = 'opencv'
        else:
            self.net = TextDetBase(model_path, device=device, act=act)
            self.backend = 'torch'

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size
        self.device = device
        self.half = half
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.mask_thresh = mask_thresh
        self.seg_rep = SegDetectorRepresenter(thresh=0.3)

    @torch.no_grad()
    def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False):
        """Detect text in ``img``.

        Args:
            img: HxWx3 image array.
            refine_mode: mode forwarded to ``refine_mask``.
            keep_undetected_mask: also run ``refine_undetected_mask``.

        Returns:
            ``(mask, mask_refined, blk_list)``: the raw mask resized to the
            input image size, the refined mask, and the grouped text blocks.
        """
        img_in, ratio, dw, dh = preprocess_img(img, input_size=self.input_size, device=self.device,
                                               half=self.half, to_tensor=self.backend == 'torch')
        im_h, im_w = img.shape[:2]

        blks, mask, lines_map = self.net(img_in)

        # Scale from the un-padded letterbox region back to the original image.
        resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
        blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)

        if self.backend == 'opencv' and mask.shape[1] == 2:
            # Some OpenCV versions return the mask and line map swapped.
            mask, lines_map = lines_map, mask
        mask = postprocess_mask(mask)
        lines, scores = self.seg_rep(self.input_size, lines_map)
        box_thresh = 0.6
        idx = np.where(scores[0] > box_thresh)
        lines, scores = lines[0][idx], scores[0][idx]

        # Crop away letterbox padding, then map the mask back to image size.
        mask = mask[: mask.shape[0] - dh, : mask.shape[1] - dw]
        mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
        if lines.size == 0:
            lines = []
        else:
            # Last axis is (x, y): scale each coordinate separately.
            lines = lines.astype(np.float64)
            lines[..., 0] *= resize_ratio[0]
            lines[..., 1] *= resize_ratio[1]
            lines = lines.astype(np.int32)
        blk_list = group_output(blks, lines, im_w, im_h, mask)
        mask_refined = refine_mask(img, mask, blk_list, refine_mode=refine_mode)
        if keep_undetected_mask:
            mask_refined = refine_undetected_mask(img, mask, mask_refined, blk_list, refine_mode=refine_mode)

        return mask, mask_refined, blk_list


def traverse_by_dict(img_dir_list, dict_dir):
    """Render saved annotations from ``dict_dir`` onto the images in ``img_dir_list``.

    For each image, reads ``<name>.json`` and ``mask-<name>.png`` from
    ``dict_dir``, draws the blocks, and writes ``<dict_dir>/labeled.png``.
    NOTE(review): the output filename is fixed, so with several images each
    one overwrites the previous result.

    Returns:
        The number of text blocks of the last processed image (0 if none).
    """
    if isinstance(img_dir_list, str):
        img_dir_list = [img_dir_list]
    imglist = []
    for img_dir in img_dir_list:
        imglist += find_all_imgs(img_dir, abs_path=True)

    blk_list = []  # keeps the return value defined when no images are found
    for img_path in tqdm(imglist):
        imname = Path(img_path).stem
        mask_path = osp.join(dict_dir, 'mask-' + imname + '.png')
        with open(osp.join(dict_dir, imname + '.json'), 'r', encoding='utf8') as f:
            blk_dict_list = json.load(f)
        blk_list = [TextBlock(**blk_dict) for blk_dict in blk_dict_list]
        img = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # NOTE(review): the refined mask is currently not written or returned.
        mask = refine_mask(img, mask, blk_list)
        visualize_textblocks(img, blk_list, path=dict_dir)
        cv2.imwrite(f'{dict_dir}/labeled.png', img)
    return len(blk_list)


if __name__ == '__main__':
    device = 'cpu'
    # model_path = 'data/comictextdetector.pt'
    model_path = 'data/comictextdetector.pt.onnx'
    img_dir = r'../input'
    save_dir = r'../output'
    model = init_model(model_path, device)
    model2annotations(img_dir, save_dir, save_json=True, model=model)
    traverse_by_dict(img_dir, save_dir)