from utils import string_utils, error_rates
import numpy as np
from . import nms
import copy
def get_trimmed_polygons(out):
    """Build one polyline per detected text line, trimmed to the
    fractional [beginning, ending] range along the line-follower path.

    Args:
        out: dict with keys
            'lf': list of per-step arrays; out['lf'][i][j] is the point
                  data of line j at follower step i (numeric numpy array;
                  exact point layout is whatever the follower emits --
                  TODO confirm against the caller).
            'beginning', 'ending': per-line fractional step positions
                  (floats), as produced by trim_ends().

    Returns:
        list of numpy arrays, one per line. The first and last points are
        linearly interpolated when beginning/ending fall between steps.
    """
    # Fix: removed the unused local `last_xy` from the original.
    all_polygons = []
    # One polygon per line j (out['lf'][0] is indexed by line on axis 0).
    for j in range(out['lf'][0].shape[0]):
        begin = out['beginning'][j]
        end = out['ending'][j]
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))
        points = []
        for i in range(begin_f, end_f + 1):
            if i == begin_f:
                # First point: interpolate between steps begin_f and
                # begin_f+1 by the fractional part of `begin`.
                p0 = out['lf'][i][j]
                p1 = out['lf'][i + 1][j]
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t
            elif i == end_f:
                # Last point: interpolate between the previous step and
                # step i, unless i runs one past the end of the path (an
                # integral `end` ceil'd to len(lf)) -- then reuse the
                # previous step's point as-is.
                p0 = out['lf'][i - 1][j]
                if i != len(out['lf']):
                    p1 = out['lf'][i][j]
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                # Interior points pass through untouched.
                p = out['lf'][i][j]
            points.append(p)
        all_polygons.append(np.array(points))
    return all_polygons
def trim_ends(out):
    """Locate each line's first/last non-blank character column and map
    those columns onto fractional line-follower step positions.

    Stores the results as out['beginning'] / out['ending'] (float arrays,
    one entry per line) and returns `out`.
    """
    # Mehreen: hw is (14, 361, 197) a 14x361 matrix for each character;
    # the argmax over classes gives one label per column. Class 0 is
    # treated as blank here.
    predictions = out['hw'].argmax(axis=-1)
    width = predictions.shape[1]
    num_steps = len(out['lf'])

    non_blank = predictions != 0
    first_col = np.argmax(non_blank, axis=1)
    last_col = width - 1 - np.argmax(non_blank[:, ::-1], axis=1)

    # Treat each column as its center (+0.5), normalize by the width,
    # then rescale into follower-step units.
    out['beginning'] = num_steps * ((first_col + 0.5) / float(width))
    out['ending'] = num_steps * ((last_col + 0.5) / float(width))
    return out
def filter_on_pick(out, pick):
    """Keep only the entries selected by `pick`, mutating `out` in place.

    `pick` is any numpy-style index (e.g. the tuple returned by np.where,
    or an integer array). Always filters 'sol', 'hw', and every step of
    'lf'; filters 'idx', 'beginning', and 'ending' only when present.
    Returns None.
    """
    out['sol'] = out['sol'][pick]
    out['hw'] = out['hw'][pick]
    out['lf'] = [step[pick] for step in out['lf']]
    # Optional keys: filter each one only if it exists.
    for key in ('idx', 'beginning', 'ending'):
        if key in out:
            out[key] = out[key][pick]
def filter_on_pick_no_copy(out, pick):
    """Like filter_on_pick, but leaves `out` untouched and returns a new
    dict containing only the selected entries.

    'sol', 'lf', and 'hw' are always carried over; 'idx', 'beginning',
    and 'ending' are copied only when present in `out`.
    """
    filtered = {
        'sol': out['sol'][pick],
        'lf': [step[pick] for step in out['lf']],
        'hw': out['hw'][pick],
    }
    for key in ('idx', 'beginning', 'ending'):
        if key in out:
            filtered[key] = out[key][pick]
    return filtered
def select_non_empty_string(out):
    """Return the indices (as an np.where tuple) of lines whose argmax
    decoding is not entirely class 0, i.e. lines that decode to a
    non-empty string (class 0 is treated as blank here)."""
    labels = out['hw'].argmax(axis=-1)
    nonempty_mask = labels.sum(axis=1) != 0
    return np.where(nonempty_mask)
def postprocess(out, **kwargs):
    """Filter raw detections: threshold start-of-line confidences, then
    run the configured line-follower non-max-suppression passes.

    Postprocessing should be done with numpy data; a shallow copy of
    `out` is filtered and returned. Recognized kwargs: sol_threshold,
    sol_nms_threshold (disabled path), lf_nms_params
    ({'overlap_range', 'overlap_threshold'}), and lf_nms_2_params
    ({'overlap_threshold'}).
    """
    out = copy.copy(out)

    # Stage 1: drop SOL detections whose confidence (last column) is
    # at or below the threshold.
    threshold = kwargs.get("sol_threshold", None)
    if threshold is not None:
        filter_on_pick(out, np.where(out['sol'][:, -1] > threshold))

    # Stage 2: legacy SOL-only NMS -- deliberately disabled upstream
    # (callers pass None; the raise guards against accidental use).
    nms_threshold = kwargs.get("sol_nms_threshold", None)
    if nms_threshold is not None:
        raise Exception("This is not correct")
        pick = nms.sol_nms_single(out['sol'], nms_threshold)
        out['sol'] = out['sol'][pick]

    # Stage 3: area NMS over a window of the raw follower paths
    # (e.g. sample_config lf_nms_range: [0,6], lf_nms_threshold: 0.5).
    nms_params = kwargs.get('lf_nms_params', None)
    if nms_params is not None:
        confidences = out['sol'][:, -1]
        # Stack steps into (steps, lines, ...) and re-split per line.
        stacked = np.concatenate([step[None, ...] for step in out['lf']])
        per_line = [stacked[:, i] for i in range(stacked.shape[1])]
        pick = nms.lf_non_max_suppression_area(
            per_line, confidences,
            nms_params['overlap_range'], nms_params['overlap_threshold'])
        filter_on_pick(out, pick)

    # Stage 4: second NMS pass over the trimmed polygons (whole path,
    # no window). Callers may pass None to skip it.
    nms_2_params = kwargs.get('lf_nms_2_params', None)
    if nms_2_params is not None:
        confidences = out['sol'][:, -1]
        pick = nms.lf_non_max_suppression_area(
            get_trimmed_polygons(out), confidences,
            None, nms_2_params['overlap_threshold'])
        filter_on_pick(out, pick)

    return out
def read_order(out):
    """Return line indices sorted into reading order.

    Takes each line's first follower point (out['lf'][0][line, :2, 0]),
    reverses its two coordinates so the lexicographic sort is on the
    second coordinate first (presumably (x, y) -> sort by y then x --
    TODO confirm the point layout against the follower), and returns the
    original line indices in that order.
    """
    first_points = out['lf'][0][:, :2, 0]
    swapped = first_points[:, ::-1]
    line_ids = np.arange(swapped.shape[0])[:, None]
    rows = np.concatenate([swapped, line_ids], axis=1).tolist()
    rows.sort()
    return [int(row[2]) for row in rows]
def decode_handwriting(out, idx_to_char):
    """Decode every line's logit matrix in out['hw'] to text.

    For each line, string_utils.naive_decode produces the collapsed
    label sequence and the raw per-step sequence; both are converted to
    strings via idx_to_char.

    Returns:
        (predictions, raw_predictions): two parallel lists of strings.
    """
    predictions = []
    raw_predictions = []
    for logits in out['hw']:
        pred, raw_pred = string_utils.naive_decode(logits)
        predictions.append(
            string_utils.label2str_single(pred, idx_to_char, False))
        raw_predictions.append(
            string_utils.label2str_single(raw_pred, idx_to_char, True))
    return predictions, raw_predictions
def results_to_numpy(out):
    """Detach the torch results in `out` to numpy for postprocessing.

    'sol' is squeezed from (n, 1, k) to (n, k); 'lf' stays a per-step
    list (or None when absent); 'results_scale' and 'line_imgs' pass
    through unchanged.
    """
    lf = out['lf']
    return {
        "sol": out['sol'].data.cpu().numpy()[:, 0, :],
        "lf": None if lf is None else [step.data.cpu().numpy() for step in lf],
        "hw": out['hw'].data.cpu().numpy(),
        "results_scale": out['results_scale'],
        "line_imgs": out['line_imgs'],
    }
def align_to_gt_lines(decoded_hw, gt_lines):
    """Match each ground-truth line to its closest prediction by CER.

    Builds the full character-error-rate matrix (rows: predictions,
    columns: ground-truth lines) and reduces it column-wise.

    Returns:
        (min_idx, min_val): per ground-truth line, the index of the
        best-matching prediction and the corresponding CER.
    """
    costs = np.array([
        [error_rates.cer(gt, pred) for gt in gt_lines]
        for pred in decoded_hw
    ])
    return costs.argmin(axis=0), costs.min(axis=0)