# Uploaded by akhaliq (HF Staff): "Upload 157 files", commit 939bf35 (verified).
"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.
1. Disable DataParallel.
"""
from collections import OrderedDict
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable
from .craft import CRAFT
from .craft_utils import adjustResultCoordinates, getDetBoxes
from .imgproc import normalizeMeanVariance, resize_aspect_ratio
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module.`` prefix.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model carry a
    ``module.`` prefix on every key. Whether to strip is decided from the
    first key only; when it has the prefix, the leading dot-separated
    component is removed from every key.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        OrderedDict: The same values, in the same order, under the adjusted
        names. An empty mapping yields an empty ``OrderedDict``.
    """
    # next(iter(...), "") avoids the IndexError an empty mapping would cause.
    first_key = next(iter(state_dict), "")
    # Require the full "module." component so names that merely begin with
    # "module" (e.g. "modules.0.weight") are left untouched.
    start_idx = 1 if first_key.startswith("module.") else 0

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_state_dict[".".join(key.split(".")[start_idx:])] = value
    return new_state_dict
def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
    """Run the detection network on an image (or batch) and post-process it.

    Args:
        canvas_size: Forwarded to ``resize_aspect_ratio``.
        mag_ratio: Forwarded to ``resize_aspect_ratio`` as ``mag_ratio``.
        net: Callable network; ``net(x)`` must return a 2-tuple whose first
            item is iterated per image.
        image: A 4-D ``np.ndarray`` is treated as a batch of images; any
            other value is treated as a single image.
        text_threshold, link_threshold, low_text, poly: Forwarded to
            ``getDetBoxes``.
        device: Device the input tensor is moved to.
        estimate_num_chars: Forwarded to ``getDetBoxes``; when true, every
            box is paired with the matching ``mapper`` entry.

    Returns:
        tuple: ``(boxes_list, polys_list)`` with one entry per network
        output. Polygons that are ``None`` are replaced by their box.
    """
    # A 4-D ndarray is a batch; everything else is wrapped as one image.
    is_batch = isinstance(image, np.ndarray) and len(image.shape) == 4
    image_arrs = image if is_batch else [image]

    # Resize each image. The ratio of the last resize is used for all
    # outputs below.
    img_resized_list = []
    for img in image_arrs:
        img_resized, target_ratio, _ = resize_aspect_ratio(
            img, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        img_resized_list.append(img_resized)
    ratio_w = 1 / target_ratio
    ratio_h = ratio_w

    # Normalise, move the last axis to the front, and stack into one tensor.
    batch = np.array([
        np.transpose(normalizeMeanVariance(resized), (2, 0, 1))
        for resized in img_resized_list
    ])
    x = torch.from_numpy(batch).to(device)

    # Forward pass without gradient tracking; the second output is unused.
    with torch.no_grad():
        y, _ = net(x)

    boxes_list = []
    polys_list = []
    for out in y:
        # Index 0 of the last axis is the text score map, index 1 the link map.
        score_text = out[:, :, 0].cpu().data.numpy()
        score_link = out[:, :, 1].cpu().data.numpy()

        boxes, polys, mapper = getDetBoxes(
            score_text, score_link, text_threshold, link_threshold,
            low_text, poly, estimate_num_chars)

        # Scale coordinates using the resize ratio.
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)

        if estimate_num_chars:
            boxes, polys = list(boxes), list(polys)

        for idx in range(len(polys)):
            if estimate_num_chars:
                boxes[idx] = (boxes[idx], mapper[idx])
            # Fall back to the box when no polygon was produced.
            if polys[idx] is None:
                polys[idx] = boxes[idx]

        boxes_list.append(boxes)
        polys_list.append(polys)

    return boxes_list, polys_list
def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Build a CRAFT network, load its weights and switch it to eval mode.

    Args:
        trained_model: Checkpoint passed to ``torch.load``.
        device: Device the weights are mapped to. The string ``'cpu'``
            selects the CPU path (optional quantization); any other value
            moves the network to that device and sets ``cudnn.benchmark``.
        quantize: On the CPU path, attempt dynamic int8 quantization. This
            is best effort: a failure is reported as a warning and the
            unquantized network is returned.
        cudnn_benchmark: Value assigned to ``cudnn.benchmark`` on the
            non-CPU path.

    Returns:
        The CRAFT network in evaluation mode.
    """
    import warnings

    net = CRAFT()
    # NOTE(review): torch.load unpickles the checkpoint, so only load files
    # from a trusted source (consider weights_only=True where available).
    net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))

    if device == 'cpu':
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
            except Exception as exc:
                # Quantization is optional; keep going, but do not hide the
                # failure (and let KeyboardInterrupt/SystemExit propagate).
                warnings.warn(
                    f"Dynamic quantization failed, using the unquantized model: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
    else:
        # The network is deliberately not wrapped in torch.nn.DataParallel.
        net = net.to(device)
        cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net
def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
    """Detect text regions and return them as flat int32 coordinate arrays.

    Args:
        detector: Network forwarded to ``test_net``.
        image: Single image or batch, forwarded to ``test_net``.
        canvas_size, mag_ratio, text_threshold, link_threshold, low_text,
        poly, device: Forwarded unchanged to ``test_net``.
        optimal_num_chars: When not ``None``, each image's entries are sorted
            by the absolute difference between this value and the second
            element of each entry (presumably the estimated character
            count; confirm against ``getDetBoxes``).
        **kwargs: Accepted and ignored.

    Returns:
        list: One list per image, each holding 1-D ``np.int32`` arrays.
    """
    estimate_num_chars = optimal_num_chars is not None
    _bboxes_list, polys_list = test_net(
        canvas_size, mag_ratio, detector, image, text_threshold,
        link_threshold, low_text, poly, device, estimate_num_chars)

    if estimate_num_chars:
        def distance_to_optimum(entry):
            return abs(optimal_num_chars - entry[1])

        # Keep only the geometry, ordered by closeness to the optimum.
        polys_list = [
            [region for region, _ in sorted(polys, key=distance_to_optimum)]
            for polys in polys_list
        ]

    # Flatten every region to a 1-D int32 array.
    return [
        [np.array(box).astype(np.int32).reshape((-1)) for box in polys]
        for polys in polys_list
    ]