Spaces:
Sleeping
Sleeping
| from utils import string_utils, error_rates | |
| import numpy as np | |
| from . import nms | |
| import copy | |
| def get_trimmed_polygons(out): | |
| all_polygons = [] | |
| for j in range(out['lf'][0].shape[0]): | |
| begin = out['beginning'][j] | |
| end = out['ending'][j] | |
| last_xy = None | |
| begin_f = int(np.floor(begin)) | |
| end_f = int(np.ceil(end)) | |
| points = [] | |
| for i in range(begin_f, end_f+1): | |
| if i == begin_f: | |
| p0 = out['lf'][i][j] | |
| p1 = out['lf'][i+1][j] | |
| t = begin - np.floor(begin) | |
| p = p0 * (1 - t) + p1 * t | |
| elif i == end_f: | |
| p0 = out['lf'][i-1][j] | |
| if i != len(out['lf']): | |
| p1 = out['lf'][i][j] | |
| t = end - np.floor(end) | |
| p = p0 * (1 - t) + p1 * t | |
| else: | |
| p = p0 | |
| else: | |
| p = out['lf'][i][j] | |
| points.append(p) | |
| points = np.array(points) | |
| all_polygons.append(points) | |
| return all_polygons | |
| def trim_ends(out): | |
| lf_length = len(out['lf']) | |
| hw = out['hw'] | |
| # Mehreen: hw is (14, 361, 197) a 14x361 matrix for each character. selected is (14, 361) | |
| selected = hw.argmax(axis=-1) | |
| beginning = np.argmax(selected != 0, axis=1) | |
| ending = selected.shape[1] - 1 - np.argmax(selected[:,::-1] != 0, axis=1) | |
| beginning_percent = (beginning+0.5) / float(selected.shape[1]) | |
| ending_percent = (ending+0.5) / float(selected.shape[1]) | |
| lf_beginning = lf_length * beginning_percent | |
| lf_ending = lf_length * ending_percent | |
| out['beginning'] = lf_beginning | |
| out['ending'] = lf_ending | |
| return out | |
| def filter_on_pick(out, pick): | |
| out['sol'] = out['sol'][pick] | |
| out['lf'] = [l[pick] for l in out['lf']] | |
| out['hw'] = out['hw'][pick] | |
| if 'idx' in out: | |
| out['idx'] = out['idx'][pick] | |
| if 'beginning' in out: | |
| out['beginning'] = out['beginning'][pick] | |
| if 'ending' in out: | |
| out['ending'] = out['ending'][pick] | |
| def filter_on_pick_no_copy(out, pick): | |
| output = {} | |
| output['sol'] = out['sol'][pick] | |
| output['lf'] = [l[pick] for l in out['lf']] | |
| output['hw'] = out['hw'][pick] | |
| ##Mehreen | |
| #print(pick) | |
| #out['line_imgs'] = out['line_imgs'][pick] | |
| ## End mehreen | |
| if 'idx' in out: | |
| output['idx'] = out['idx'][pick] | |
| if 'beginning' in out: | |
| output['beginning'] = out['beginning'][pick] | |
| if 'ending' in out: | |
| output['ending'] = out['ending'][pick] | |
| return output | |
| def select_non_empty_string(out): | |
| selected = out['hw'].argmax(axis=-1) | |
| return np.where(selected.sum(axis=1) != 0) | |
| def postprocess(out, **kwargs): | |
| out = copy.copy(out) | |
| # postprocessing should be done with numpy data | |
| sol_threshold = kwargs.get("sol_threshold", None) | |
| sol_nms_threshold = kwargs.get("sol_nms_threshold", None) | |
| lf_nms_params = kwargs.get('lf_nms_params', None) | |
| lf_nms_2_params = kwargs.get('lf_nms_2_params', None) | |
| if sol_threshold is not None: | |
| pick = np.where(out['sol'][:,-1] > sol_threshold) | |
| filter_on_pick(out, pick) | |
| #Mehreen: this is passed as None from run_hwr from decode_one_img_with_info | |
| if sol_nms_threshold is not None: | |
| raise Exception("This is not correct") | |
| pick = nms.sol_nms_single(out['sol'], sol_nms_threshold) | |
| out['sol'] = out['sol'][pick] | |
| #Mehreen: When post-processing this part is done. sample_config lf_nms_range: [0,6] lf_nms_threshold: 0.5 | |
| if lf_nms_params is not None: | |
| confidences = out['sol'][:,-1] | |
| overlap_range = lf_nms_params['overlap_range'] | |
| overlap_thresh = lf_nms_params['overlap_threshold'] | |
| lf_setup = np.concatenate([l[None,...] for l in out['lf']]) | |
| lf_setup = [lf_setup[:,i] for i in range(lf_setup.shape[1])] | |
| pick = nms.lf_non_max_suppression_area(lf_setup, confidences, overlap_range, overlap_thresh) | |
| filter_on_pick(out, pick) | |
| #Mehreen: When post-processing this part is None from decode_one_img_with_info | |
| if lf_nms_2_params is not None: | |
| confidences = out['sol'][:,-1] | |
| overlap_thresh = lf_nms_2_params['overlap_threshold'] | |
| refined_lf = get_trimmed_polygons(out) | |
| pick = nms.lf_non_max_suppression_area(refined_lf, confidences, None, overlap_thresh) | |
| filter_on_pick(out, pick) | |
| return out | |
| def read_order(out): | |
| first_pt = out['lf'][0][:,:2,0] | |
| first_pt = first_pt[:,::-1] | |
| first_pt = np.concatenate([first_pt, np.arange(first_pt.shape[0])[:,None]], axis=1) | |
| first_pt = first_pt.tolist() | |
| first_pt.sort() | |
| return [int(p[2]) for p in first_pt] | |
| def decode_handwriting(out, idx_to_char): | |
| hw_out = out['hw'] | |
| list_of_pred = [] | |
| list_of_raw_pred = [] | |
| for i in range(hw_out.shape[0]): | |
| logits = hw_out[i,...] | |
| pred, raw_pred = string_utils.naive_decode(logits) | |
| pred_str = string_utils.label2str_single(pred, idx_to_char, False) | |
| raw_pred_str = string_utils.label2str_single(raw_pred, idx_to_char, True) | |
| list_of_pred.append(pred_str) | |
| list_of_raw_pred.append(raw_pred_str) | |
| return list_of_pred, list_of_raw_pred | |
| def results_to_numpy(out): | |
| return { | |
| "sol": out['sol'].data.cpu().numpy()[:,0,:], | |
| "lf": [l.data.cpu().numpy() for l in out['lf']] if out['lf'] is not None else None, | |
| "hw": out['hw'].data.cpu().numpy(), | |
| "results_scale": out['results_scale'], | |
| "line_imgs": out['line_imgs'], | |
| } | |
| def align_to_gt_lines(decoded_hw, gt_lines): | |
| costs = [] | |
| for i in range(len(decoded_hw)): | |
| costs.append([]) | |
| for j in range(len(gt_lines)): | |
| pred = decoded_hw[i] | |
| gt = gt_lines[j] | |
| cer = error_rates.cer(gt, pred) | |
| costs[i].append(cer) | |
| costs = np.array(costs) | |
| min_idx = costs.argmin(axis=0) | |
| min_val = costs.min(axis=0) | |
| return min_idx, min_val | |