# -*- NOTE: HuggingFace Spaces page banner ("Spaces: Sleeping") removed; it was scrape residue, not code. -*-
| # -*- coding: utf-8 -*- | |
| import torch | |
| from torchvision import transforms | |
| import os | |
| import cv2 | |
| import time | |
| import numpy as np | |
| import pyclipper | |
| import gradio as gr | |
| from models import get_model | |
| from utils.util import show_img, draw_bbox | |
class Pytorch_model:
    """DBNet text detector restored from a training checkpoint.

    The checkpoint file must contain both the training config (key
    'config') and the weights (key 'state_dict'); the network
    architecture is rebuilt from the stored config via get_model().
    """

    def __init__(self, model_path, gpu_id=None):
        '''
        Initialize the pytorch model.
        :param model_path: path to the checkpoint (weights only, or weights plus graph)
        :param gpu_id: index of the GPU to run on, or None to run on CPU
        '''
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        # map_location covers both branches (CUDA target or CPU fallback).
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)
        print('device:', self.device)
        config = checkpoint['config']
        # Weights come from the checkpoint, so never download pretrained ones.
        config['arch']['args']['pretrained'] = False
        self.net = get_model(config)
        self.img_channel = config['data_loader']['args']['dataset']['img_channel']
        self.net.load_state_dict(checkpoint['state_dict'])  # load weights
        self.net.to(self.device)
        self.net.eval()

    def predict(self, img, short_size: int = 736, min_area: int = 100):
        '''
        Run text detection on a single image (np array as read by OpenCV).
        :param img: the image as an np array
        :param short_size: the image's shorter side is scaled to this length before inference
        :param min_area: bounding boxes with area below this (in original-image pixels) are ignored
        :return: (dilated_polys, boxes_list, t) -- detected polygons, their
                 min-area rectangle corner points, and elapsed time in seconds
        '''
        if self.img_channel == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        scale = short_size / min(h, w)
        img = cv2.resize(img, None, fx=scale, fy=scale)
        tensor = transforms.ToTensor()(img)
        tensor = tensor.unsqueeze_(0)
        tensor = tensor.to(self.device)
        with torch.no_grad():
            # torch.cuda.synchronize(self.device)
            start = time.time()
            preds = self.net(tensor)[0]
            # torch.cuda.synchronize(self.device)
            # Per-axis factors to map prediction coordinates back to the original image.
            scale = (preds.shape[2] / w, preds.shape[1] / h)
            t = time.time() - start
            '''inference'''
            start = time.time()
            prob_map, thres_map = preds[0], preds[1]
            ## Step 1: threshold the probability map into a binary map
            thr = 0.2
            out = (prob_map > thr).float() * 255
            out = out.data.cpu().numpy().astype(np.uint8)
            # cv2.imwrite('c_bin_map.png', out)
            ## Step 2: connected components via findContours.
            # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
            # (contours, hierarchy); taking the last two items works on both.
            contours, hierarchy = cv2.findContours(out, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2:]
            contours = [(c / scale).astype(int) for c in contours if len(c) >= 4]
            # Step 3: dilate the shrunk region (not strictly necessary)
            ratio_prime = 1.5
            dilated_polys = []
            for poly in contours:
                poly = poly[:, 0, :]
                perimeter = cv2.arcLength(poly, True)
                if perimeter == 0:
                    # Degenerate contour: the offset distance below would divide by zero.
                    continue
                D_prime = cv2.contourArea(poly) * ratio_prime / perimeter  # formula (10) in the thesis
                pco = pyclipper.PyclipperOffset()
                pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
                dilated_poly = np.array(pco.Execute(D_prime))
                # Keep only clean results: exactly one integer-typed polygon.
                # (np.issubdtype is portable, unlike comparing dtype against the
                # platform-dependent meaning of plain `int`.)
                if dilated_poly.size == 0 or not np.issubdtype(dilated_poly.dtype, np.integer) or len(dilated_poly) != 1:
                    continue
                dilated_polys.append(dilated_poly)
            boxes_list = []
            for cnt in dilated_polys:
                if cv2.contourArea(cnt) < min_area:
                    continue
                rect = cv2.minAreaRect(cnt)
                box = (cv2.boxPoints(rect)).astype(int)
                boxes_list.append(box)
            t = time.time() - start + t
            boxes_list = np.array(boxes_list)
            return dilated_polys, boxes_list, t
# Checkpoint path -> loaded Pytorch_model; loading weights from disk is
# expensive, so reuse the model across Gradio requests instead of
# re-initializing it on every call.
_model_cache = {}


def run(img):
    """Detect text regions in *img* and return a copy with contours drawn.

    :param img: image as an np array (as supplied by the gradio "image" input)
    :return: copy of the image with detected text contours drawn in green
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    model_path = 'MobileNetv2_best_loss.pth'
    # init model (once); set a GPU id, or None if you only have a CPU
    if model_path not in _model_cache:
        _model_cache[model_path] = Pytorch_model(model_path, gpu_id=None)
    model = _model_cache[model_path]
    contours, boxes_list, t = model.predict(img)
    print('Time: %.4f' % t)
    # Reverse the channel order for OpenCV drawing.
    # NOTE(review): the reversed (BGR) image is what gets returned to gradio --
    # confirm the display expects this, otherwise colors are swapped.
    img = img[:, :, ::-1]
    imgc = img.copy()
    cv2.drawContours(imgc, contours, -1, (22, 222, 22), 2, cv2.LINE_AA)
    return imgc
if __name__ == '__main__':
    # Build the gradio demo around run() and start the local server.
    demo = gr.Interface(
        fn=run,
        inputs="image",
        outputs="image",
        title="CJK Text Detection Using DBNet",
        description="Gives an image containing CJK texts, we will mark the texts out!",
    )
    demo.launch()