Spaces:
Build error
Build error
msaeed3 commited on
Commit ·
e295beb
1
Parent(s): 6f01ce4
version 1.0
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- arabic/decode_one_image.py +225 -0
- arabic/page_htr.py +283 -0
- arabic/post_process_routines.py +301 -0
- arabic/test_hw_helper_routines.py +147 -0
- arabic/warp_routines.py +368 -0
- coords/__init__.py +0 -0
- coords/points.py +342 -0
- coords/poly_routines.py +195 -0
- coords/text_cleaning_routines.py +82 -0
- coords/text_gt.py +101 -0
- model/trial_26_A/muharaf_charset.json +320 -0
- model/trial_26_A/set0/config_2600.yaml +25 -0
- model/trial_26_A/set0/pretrain/hw.pt +3 -0
- model/trial_26_A/set0/pretrain/lf.pt +3 -0
- model/trial_26_A/set0/pretrain/sol.pt +3 -0
- py3/e2e/__init__.py +0 -0
- py3/e2e/alignment_dataset.py +69 -0
- py3/e2e/e2e_model.py +207 -0
- py3/e2e/e2e_postprocessing.py +182 -0
- py3/e2e/forward_pass.py +86 -0
- py3/e2e/handwriting_alignment_loss.py +125 -0
- py3/e2e/nms.py +162 -0
- py3/e2e/validation_utils.py +137 -0
- py3/e2e/visualization.py +176 -0
- py3/hw/__init__.py +0 -0
- py3/hw/cnn_lstm.py +117 -0
- py3/lf/__init__.py +0 -0
- py3/lf/fast_patch_view.py +96 -0
- py3/lf/lf_cnn.py +45 -0
- py3/lf/line_follower.py +181 -0
- py3/lf/models/__init__.py +36 -0
- py3/lf/models/res_unet.py +147 -0
- py3/lf/models/resnet.py +335 -0
- py3/lf/models/tools.py +144 -0
- py3/lf/stn/__init__.py +0 -0
- py3/lf/stn/gridgen.py +126 -0
- py3/sol/__init__.py +0 -0
- py3/sol/crop_transform.py +35 -0
- py3/sol/crop_utils.py +48 -0
- py3/sol/start_of_line_finder.py +42 -0
- py3/sol/vgg.py +157 -0
- py3/utils/__init__.py +0 -0
- py3/utils/character_set.ipynb +539 -0
- py3/utils/character_set.py +61 -0
- py3/utils/continuous_state.py +87 -0
- py3/utils/dataset_parse.py +17 -0
- py3/utils/dataset_wrapper.py +27 -0
- py3/utils/error_rates.py +21 -0
- py3/utils/fast_inverse.py +58 -0
- py3/utils/safe_load.py +30 -0
arabic/decode_one_image.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from utils.continuous_state import init_model
|
| 6 |
+
from e2e import e2e_model, e2e_postprocessing, visualization
|
| 7 |
+
from e2e.e2e_model import E2EModel
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.autograd import Variable
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import cv2
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
import codecs
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
import operator
|
| 23 |
+
import pandas as pd
|
| 24 |
+
from utils import error_rates
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import argparse
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Network output on one image
# Will read from file if org_img is none
def network_output(config_file, image_path, model_mode = "best_overall",
                   flip=False, use_unet=False, org_img=None, device="cuda"):
    """Run the SOL/LF/HW end-to-end network on a single page image.

    Args:
        config_file: path to the YAML experiment configuration.
        image_path: page image path (ignored when ``org_img`` is supplied).
        model_mode: snapshot sub-directory name passed to ``init_model``.
        flip: mirror the image horizontally before processing (used for RTL text).
        use_unet: force the line follower's U-Net variant on in the config.
        org_img: optional pre-loaded BGR image array; read from disk when None.
        device: torch device string; CUDA tensors are used when it contains "cuda".

    Returns:
        The network output dict converted to numpy, with the padding offset
        removed from SOL/LF coordinates and ``image_path`` recorded, or None
        (implicitly) when the network produced no results.
    """
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if use_unet:
        config['network']['lf']['u_net'] = True
        #print('config changed')

    char_set_path = config['network']['hw']['char_set_path']
    ### Change hw's num_of_outputs in config
    with open(char_set_path) as f:
        char_set = json.load(f)

    # +1 output for the CTC blank symbol.
    config["network"]["hw"]["num_of_outputs"] = len(char_set['idx_to_char']) + 1

    dtype = torch.FloatTensor
    if 'cuda' in device:
        dtype = torch.cuda.FloatTensor

    sol, lf, hw = init_model(config, sol_dir=model_mode, lf_dir=model_mode, hw_dir=model_mode,
                             device=device)

    e2e = E2EModel(sol, lf, hw, dtype=dtype, device=device)

    e2e.eval()

    if org_img is None:
        org_img = cv2.imread(image_path)
    if flip:
        org_img = cv2.flip(org_img, 1)

    # NOTE(review): the resize scale is computed from the UNPADDED width, then
    # the image is padded — presumably intentional so scale matches the source
    # page; confirm against training-time preprocessing.
    target_dim1 = 512
    s = target_dim1 / float(org_img.shape[1])

    # Pad with white so lines touching the border can still be followed.
    pad_amount = 128
    org_img = np.pad(org_img, ((pad_amount,pad_amount),(pad_amount,pad_amount), (0,0)), 'constant', constant_values=255)
    before_padding = org_img

    target_dim0 = int(org_img.shape[0] * s)
    target_dim1 = int(org_img.shape[1] * s)

    # Full-resolution tensor: (1, C, W, H) layout, normalized to roughly [-1, 1].
    full_img = org_img.astype(np.float32)
    full_img = full_img.transpose([2,1,0])[None,...]
    full_img = torch.from_numpy(full_img)
    full_img = full_img / 128 - 1

    # Downscaled copy used by the start-of-line detector.
    img = cv2.resize(org_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)
    img = img.astype(np.float32)
    img = img.transpose([2,1,0])[None,...]
    img = torch.from_numpy(img)
    img = img / 128 - 1

    out = e2e.forward({
        "resized_img": img,
        "full_img": full_img,
        "resize_scale": 1.0/s
    }, use_full_img=True, device=device)

    out = e2e_postprocessing.results_to_numpy(out)

    if out is None:
        print("No Results")
        return

    # take into account the padding
    out['sol'][:,:2] = out['sol'][:,:2] - pad_amount
    for l in out['lf']:
        l[:,:2,:2] = l[:,:2,:2] - pad_amount

    out['image_path'] = image_path

    return out
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def decode_one_img_with_info(config_path, out, visualize=False, flip=False, org_img=None, device="cuda"):
    """Post-process raw network output into ordered line geometry and text.

    Args:
        config_path: YAML config path (post-processing thresholds are read from it).
        out: raw output dict produced by ``network_output``.
        visualize: accepted for interface compatibility; currently unused.
        flip: mirror the loaded image (matches how the network was run).
        org_img: optional pre-loaded image; loaded from ``out['image_path']`` when None.
        device: accepted for interface compatibility; currently unused.

    Returns:
        (out, output_strings): the filtered/ordered output dict and one decoded
        string per surviving line.
    """
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    char_set_path = config['network']['hw']['char_set_path']
    with open(char_set_path) as f:
        char_set = json.load(f)

    # JSON keys are strings; the decoder needs integer indices.
    idx_to_char = {int(k): v for k, v in char_set['idx_to_char'].items()}

    # Shallow copy so the caller's dict is not filtered in place.
    out = dict(out)

    image_path = str(out['image_path'])
    # NOTE(review): the loaded image is not used below — kept for interface
    # stability / future visualization, but it is dead work.
    if org_img is None:
        org_img = cv2.imread(image_path)
    if flip:
        org_img = cv2.flip(org_img, 1)

    # Postprocessing steps: trim line ends, drop empty strings, threshold SOLs,
    # NMS overlapping line-followers, then establish reading order.
    out['idx'] = np.arange(out['sol'].shape[0])
    out = e2e_postprocessing.trim_ends(out)
    e2e_postprocessing.filter_on_pick(out, e2e_postprocessing.select_non_empty_string(out))
    out = e2e_postprocessing.postprocess(
        out,
        sol_threshold=config['post_processing']['sol_threshold'],
        lf_nms_params={
            "overlap_range": config['post_processing']['lf_nms_range'],
            "overlap_threshold": config['post_processing']['lf_nms_threshold']
        })
    order = e2e_postprocessing.read_order(out)
    e2e_postprocessing.filter_on_pick(out, order)

    # Decode handwriting for every surviving line.
    # (Removed a dead "output_strings = []" that was immediately overwritten.)
    output_strings, _decoded_raw_hw = e2e_postprocessing.decode_handwriting(out, idx_to_char)
    return out, output_strings
|
| 151 |
+
|
| 152 |
+
def write_line_images(images, parent_img_fullpath, result_dir='Result', flip=True):
    """Save per-line crops of a page under <page_dir>/<result_dir>/<page_stem>/.

    Args:
        images: iterable of line image arrays (BGR, as produced upstream).
        parent_img_fullpath: full path of the source page image; its directory
            and basename determine where the crops are written.
        result_dir: name of the results folder created next to the page image.
        flip: mirror each crop back before saving (lines are processed flipped).
    """
    directory = os.path.dirname(parent_img_fullpath)
    parent_basename = os.path.basename(parent_img_fullpath)
    # Everything before the last '_' groups pages of the same document.
    dir_basename = parent_basename[0:parent_basename.rfind('_')]
    result_dir = os.path.join(directory, result_dir)
    # BUG FIX: the original joined `directory` a second time
    # (os.path.join(directory, result_dir, dir_basename)) even though
    # result_dir already starts with `directory`, yielding dir/dir/Result/...
    # whose parent does not exist and makes mkdir fail.
    new_directory = os.path.join(result_dir, dir_basename)
    # makedirs creates result_dir and the page sub-directory in one idempotent call.
    os.makedirs(new_directory, exist_ok=True)
    # Get rid of file extension
    parent_basename = parent_basename[0:parent_basename.rfind('.')]
    for ind, img in enumerate(images):
        filename = os.path.join(new_directory,
                                parent_basename + '_' + str(ind) + '.png')
        if flip:
            img = cv2.flip(img, 1)
        cv2.imwrite(filename, img)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# Write a one time header if needed
def write_empty(result_file):
    """Write an empty results CSV containing only the header row."""
    header = ['image_file', 'ground_truth', 'prediction',
              'region_type',
              # 'gt_poly',
              'lf_points', 'beginning', 'ending', 'SOL']
    pd.DataFrame(columns=header).to_csv(result_file, index=False)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# TODO: Verify that (x,y) are the first two of SOL output
def add_offset_to_sol(sol, offset):
    """Translate every start-of-line entry in place by (offset[0], offset[1]).

    Only the first two components of each entry are shifted; the rest
    (confidence, etc.) are untouched.  Returns the same ``sol`` object.
    """
    dx = offset[0]
    dy = offset[1]
    for entry in sol:
        entry[0] += dx
        entry[1] += dy
    return sol
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def add_offset_to_lf(lf, offset):
    """Translate both endpoints of every line-follower step in place.

    ``lf[step][line]`` holds a 2x2 block: row 0 are the two x coordinates,
    row 1 the two y coordinates.  x's get offset[0], y's get offset[1].
    Returns the same ``lf`` object.
    """
    dx = offset[0]
    dy = offset[1]
    for step in lf:
        # NOTE: line count is taken from the first step, matching the
        # original behavior (steps are assumed rectangular).
        for line_ind in range(len(lf[0])):
            block = step[line_ind]
            block[0][0] = block[0][0] + dx
            block[1][0] = block[1][0] + dy
            block[0][1] = block[0][1] + dx
            block[1][1] = block[1][1] + dy
    return lf
|
| 199 |
+
|
| 200 |
+
# The merged output will be in out1
def merge_out(out1, out2, offset):
    """Merge ``out2`` into ``out1`` in place, shifting out2 by ``offset``.

    Used after splitting a page: out2's LF points and SOLs are translated
    into out1's coordinate frame, then the 'lf', 'beginning', 'ending' and
    'sol' fields are appended/concatenated onto out1.

    Returns ``out1`` (mutated).  NOTE: out2's 'lf' and 'sol' are mutated too
    by the offset helpers.
    """
    # (Removed a dead local `lf1 = out1['lf']` that was never used.)
    lf2 = add_offset_to_lf(out2['lf'], offset)
    out1['lf'].extend(lf2)

    out1['beginning'] = np.concatenate((out1['beginning'], out2['beginning']))
    out1['ending'] = np.concatenate((out1['ending'], out2['ending']))

    sol2 = add_offset_to_sol(out2['sol'], offset)
    out1['sol'] = np.vstack((out1['sol'], sol2))
    return out1
|
| 212 |
+
|
| 213 |
+
def split_image_horizontal(img_file):
    """Split an image into top/bottom halves saved next to the original.

    Writes ``<stem>_1<ext>`` (top) and ``<stem>_2<ext>`` (bottom) and returns
    (top_path, bottom_path, split_height) so callers can undo the offset.
    """
    # FIX: use splitext instead of `img_file[:-4]`, which assumed a 4-character
    # extension and mangled names like "page.jpeg".  Behavior is unchanged for
    # the .jpg files this pipeline feeds it.
    root, ext = os.path.splitext(img_file)
    f1 = root + '_1' + ext
    f2 = root + '_2' + ext
    img = cv2.imread(img_file)
    ht = img.shape[0]
    ht1 = ht // 2
    img1 = img[:ht1, :, :]
    img2 = img[ht1:, :, :]
    cv2.imwrite(f1, img1)
    cv2.imwrite(f2, img2)
    return f1, f2, ht1
|
| 224 |
+
|
| 225 |
+
|
arabic/page_htr.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('py3/')
|
| 3 |
+
sys.path.append('coords')
|
| 4 |
+
import test_hw_helper_routines as test_hw
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import cv2
|
| 10 |
+
import sys
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import numpy as np
|
| 13 |
+
import decode_one_image as decode
|
| 14 |
+
import post_process_routines as post
|
| 15 |
+
import points
|
| 16 |
+
import warp_routines as warp
|
| 17 |
+
from datetime import datetime, timezone
|
| 18 |
+
from utils import error_rates
|
| 19 |
+
import text_cleaning_routines as clean
|
| 20 |
+
import time
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
def add_meta(json_obj):
    """Attach creation timestamp and empty annotator fields to a page JSON.

    Mutates and returns ``json_obj``; 'created' is the current UTC time in
    ISO-8601 form (no timezone suffix), all other fields start empty.
    """
    # Current UTC time, formatted like 2024-01-31T12:00:00.
    created = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S')

    json_obj['timeStamps'] = {
        "created": created,
        "lastEdited": "",
        "submitted": "",
        "checked": "",
    }

    json_obj["annotators"] = {
        "creator": "SFR",
        "lastEditor": "",
        "transcriber": "",
        "transcription_QA": "",
        "transcription_tagging": "",
        "transcription_tagging_QA": "",
    }

    return json_obj
|
| 45 |
+
|
| 46 |
+
def reset_time(json_obj):
    """Zero the page-level time and every line's timing/edit fields in place."""
    json_obj["time"] = 0
    for key in json_obj:
        if not key.startswith("line_"):
            continue
        entry = json_obj[key]
        entry["transcribeTime"] = 0
        entry["annotateTime"] = 0
        entry["edited"] = "0"
    return json_obj
|
| 54 |
+
|
| 55 |
+
def get_hw(config_file, device="cuda"):
    """Load the handwriting-recognition model described by a config file.

    Returns (HW, idx_to_char): the model in eval mode on ``device``, and the
    index-to-character decoding table.
    """
    config = test_hw.get_config(config_file)
    idx_to_char = test_hw.load_char_set(config['network']['hw']['char_set_path'])

    # The snapshot filename may be overridden in the config; default to hw.pt.
    pretraining = config['pretraining']
    snapshot_name = pretraining.get('hw_to_save', 'hw.pt')
    pt_filename = os.path.join(pretraining['snapshot_path'], snapshot_name)

    # +1 output for the CTC blank symbol.
    config["network"]["hw"]["num_of_outputs"] = len(idx_to_char) + 1

    print('...Using snapshot', pt_filename)
    HW = test_hw.load_HW(config['network']['hw'], pt_filename)
    HW.to(torch.device(device))
    HW.eval()
    return HW, idx_to_char
|
| 73 |
+
|
| 74 |
+
def sort_lines(json_obj, copy_lines_with_text_only=False):
    """Renumber 'line_*' entries top-to-bottom, ties broken right-to-left.

    The sort anchor per line is (min y, max x) of its polygon — rightmost
    first suits right-to-left scripts.  Non-line keys are copied unchanged.
    When ``copy_lines_with_text_only`` is set, lines with empty text are
    dropped (leaving gaps in the numbering, as the original did).
    """
    keys = []
    anchors = []
    for key, value in json_obj.items():
        if not key.startswith('line_'):
            continue
        keys.append(key)
        poly = np.array(points.list_to_xy(value['coord']))
        anchors.append([np.max(poly, 0)[0], np.min(poly, 0)[1]])

    order = sorted(range(len(anchors)),
                   key=lambda i: (anchors[i][1], -anchors[i][0]))

    # Non-line keys first, preserving their original order.
    sorted_json = {k: v for k, v in json_obj.items() if not k.startswith('line_')}
    # Then the lines, renumbered by sorted position.
    for i, ind in enumerate(order):
        if copy_lines_with_text_only and len(json_obj[keys[ind]]['text']) == 0:
            continue
        sorted_json[f'line_{i + 1}'] = json_obj[keys[ind]]

    return sorted_json
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# This will not do line annotations...only transcriptions
def complete_annotations_for_directory(input_dir, config_file,
                                       annotator, model_mode="pretrain",
                                       do_all=False):
    """Fill in missing line transcriptions for every annotated page in a directory.

    For each .jpg with a matching ``*_annotate_<annotator>.json``, runs the HW
    recognizer on each line polygon whose 'text' is empty (or on every line
    when ``do_all``), writes the cleaned text back, re-sorts the lines, and
    rewrites the JSON.

    Returns the number of pages processed.
    NOTE(review): ``model_mode`` is accepted but currently unused.
    """
    total_done = 0
    files = os.listdir(input_dir)
    files.sort()
    HW, idx_to_char = get_hw(config_file)

    for f in files:
        if not f.lower().endswith('.jpg'):
            continue
        img_file = os.path.join(input_dir, f)
        json_file = img_file[:-4] + '_annotate_' + annotator + '.json'
        if not os.path.exists(json_file):
            print('No Json for', img_file)
            continue
        print('doing', img_file)

        with open(json_file) as fin:
            json_obj = json.load(fin)
        # Add meta information and reset the timings in json
        #json_obj = add_meta(json_obj)
        #json_obj = reset_time(json_obj)

        # PERF FIX: read the page image once per file; the original re-read it
        # inside the per-line loop.
        img = cv2.imread(img_file)

        for line, values in json_obj.items():
            if not line.startswith('line_'):
                continue
            if not do_all and len(values['text']) > 0:
                continue
            line_img = warp.get_line_image(values['coord'], img)
            line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
                                                  img=line_img, read_image=False)
            json_obj[line]['text'] = clean.get_clean_visual_order(line_text)

        json_obj = sort_lines(json_obj)
        with open(json_file, 'w') as fout:
            json.dump(json_obj, fout, indent=2)

        total_done += 1
    return total_done
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# This will do bulk annotation in the whole dir
def predict_annotations_for_directory(input_dir, config_file, annotator, model_mode="pretrain",
                                      skip_if_json_exists=False, device="cuda"):
    """Run full-page HTR on every .jpg in ``input_dir`` and write prediction JSONs.

    For each image, produces ``<image>_annotate_<annotator>.json`` containing
    predicted line polygons and transcriptions.  Existing JSONs are skipped
    when ``skip_if_json_exists``.  Returns the number of images processed.
    """
    print('skip_if_json_exists', skip_if_json_exists)
    files = os.listdir(input_dir)
    files.sort()
    done = 0
    for f in files:
        if not f.lower().endswith('.jpg'):
            continue
        img_file = os.path.join(input_dir, f)
        # PERF FIX: removed an unused cv2.imread of the full page that the
        # original performed here (its result was never referenced).

        json_file = img_file[:-4] + '_annotate_' + annotator + '.json'
        if os.path.exists(json_file) and skip_if_json_exists:
            print('already done', json_file)
            continue
        print('doing', img_file)
        # The page is processed mirrored (flip=True) because the models were
        # trained on flipped right-to-left text.
        out = decode.network_output(config_file, img_file, flip=True, model_mode=model_mode, device=device)
        out, predicted_text = decode.decode_one_img_with_info(config_file, out, flip=True, device=device)

        poly_list = post.get_polygon_list_tuples(out)

        # Get rid of degenerate (< 3 point) polygons and their texts.
        to_del_ind = [ind for ind, p in enumerate(poly_list) if len(p) < 3]
        if to_del_ind:
            print('Deleting poly at index', to_del_ind)
            poly_list = [poly_list[i] for i in range(len(poly_list)) if i not in to_del_ind]
            predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in to_del_ind]

        # Remove heavily-overlapping duplicate lines (> 70% area overlap).
        del_list, poly_list = post.get_poly_no_overlap(img_file, poly_list, 0.7)
        if len(del_list) > 0:
            print('polygons deleted', len(del_list), del_list)
            print(len(poly_list))

        predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in del_list]
        # Mirror polygons back into the unflipped page's coordinates.
        poly_list = post.flip_polygon(img_file, poly_list)
        #post.draw_image_with_poly("", img_file, poly_list, convert=False)
        page_json = post.create_annotations_json(predicted_text, poly_list)

        with open(json_file, 'w') as fout:
            json.dump(page_json, fout)
        done += 1

    return done
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def page_htr_one_file(img_file, config_file, model_mode="pretrain", device="cuda"):
    """Run the full page-HTR pipeline on one image and return its annotation JSON.

    Detects lines, decodes handwriting, removes degenerate and heavily
    overlapping polygons, cleans text ordering, and mirrors polygons back to
    the unflipped page's coordinates.
    """
    # PERF FIX: removed an unused cv2.imread of the page that the original
    # performed here (its result was never referenced).
    out = decode.network_output(config_file, img_file, flip=True, model_mode=model_mode, device=device)
    out, predicted_text = decode.decode_one_img_with_info(config_file, out, flip=True, device=device)

    poly_list = post.get_polygon_list_tuples(out)

    # Get rid of degenerate (< 3 point) polygons and their texts.
    to_del_ind = [ind for ind, p in enumerate(poly_list) if len(p) < 3]
    if to_del_ind:
        poly_list = [poly_list[i] for i in range(len(poly_list)) if i not in to_del_ind]
        predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in to_del_ind]

    # Remove heavily-overlapping duplicate lines (> 70% area overlap).
    del_list, poly_list = post.get_poly_no_overlap(img_file, poly_list, 0.7)
    predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in del_list]
    predicted_text = [clean.get_clean_visual_order(txt) for txt in predicted_text]
    # Mirror polygons back into the unflipped page's coordinates.
    poly_list = post.flip_polygon(img_file, poly_list)
    #post.draw_image_with_poly("", img_file, poly_list, convert=False)
    page_json = post.create_annotations_json(predicted_text, poly_list)

    torch.cuda.empty_cache()
    return page_json
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None):
    """Re-run handwriting recognition on a page's existing line polygons.

    When ``line_key`` names a specific 'line_*' entry, only that line is
    re-recognized; otherwise every line is.  Writes cleaned text back into
    ``json_obj`` and returns it re-sorted.

    NOTE(review): ``model_mode`` is accepted but currently unused.
    """
    HW, idx_to_char = get_hw(config_file)

    # PERF FIX: read the image once; the original re-read it per line.
    img = cv2.imread(img_file)

    for line, values in json_obj.items():
        if not line.startswith('line_'):
            continue
        # If line_key is specified then modify only that line.
        # (Fixed a typo: the comparison used a float literal `0.`.)
        if line_key is not None and len(line_key) > 0 and line != line_key:
            continue

        line_img = warp.get_line_image(values['coord'], img)
        line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
                                              img=line_img, read_image=False)
        # NOTE(review): make_manual_text_correction is neither defined nor
        # imported in this module — this call raises NameError at runtime
        # unless the name is provided elsewhere.  Verify before shipping.
        line_text = make_manual_text_correction(line_text)
        # Convert visual decoding order to logical order.
        json_obj[line]['text'] = clean.get_clean_visual_order(line_text)

    json_obj = sort_lines(json_obj)
    torch.cuda.empty_cache()
    return json_obj
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == "__main__":
    # CLI entry point: run either single-line HTR (--line_htr 1) or whole-page
    # HTR on one image, then emit the resulting JSON after a BEGIN_OUT marker
    # so a calling process can locate the payload in stdout.

    parser = argparse.ArgumentParser(description="Run HTR module")

    parser.add_argument("--line_htr", type=int, required=True, help="If 1, do line_htr else do page_htr")
    parser.add_argument("--img_path", type=str, required=True, help="Image path")
    parser.add_argument("--config_file", type=str, required=True, help="SFR_Arabic config file")
    parser.add_argument("--original_json", type=str, required=True, help="Original JSON")
    parser.add_argument("--line_key", type=str, required=True, help="line key")

    args = parser.parse_args()
    json_obj = {}
    if args.line_htr == 1:
        # Line mode: re-recognize one line (or all lines) inside the caller's JSON.
        json_obj = json.loads(args.original_json)
        json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
                               model_mode="pretrain", line_key=args.line_key)
    else:
        # Page mode: original_json is parsed (validating it) but the result is
        # then replaced wholesale by the page-level prediction.
        json_obj = json.loads(args.original_json)
        json_obj = page_htr_one_file(args.img_path, args.config_file, device="cuda")

    # Marker line lets the parent process find the JSON payload in stdout,
    # after any incidental prints from the pipeline.
    print('BEGIN_OUT')
    print(json.dumps(json_obj))

# python3 arabic/page_htr.py --line_htr 0 --img_path ../../datasets/kclds/KEllis/bk_on_server/bk/KEllis2018-150a.jpg --config_file model/trial_26_A/set0/config_2600.yaml --original_json {} --line_key 0
|
arabic/post_process_routines.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('coords/')
|
| 6 |
+
import points
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
import os
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
from utils import error_rates
|
| 12 |
+
from arabic_reshaper import ArabicReshaper
|
| 13 |
+
|
| 14 |
+
# This file has routines for working with CERs and polygons
|
| 15 |
+
|
| 16 |
+
def correct_pt(value, max_value):
|
| 17 |
+
boundary = False
|
| 18 |
+
if value < 0:
|
| 19 |
+
value = 0
|
| 20 |
+
boundary = True
|
| 21 |
+
if value >= max_value:
|
| 22 |
+
value = max_value - 1
|
| 23 |
+
boundary = True
|
| 24 |
+
return [value, boundary]
|
| 25 |
+
|
| 26 |
+
# Each polygon is a list of (x,y) tuples
def get_polygon_list_tuples(out):
    """Build one boundary polygon per text line from the line-follower output.

    Each polygon walks the top edge forward between the line's trimmed
    beginning/ending step indices, then the bottom edge backward, clamping
    every point to the image bounds and skipping consecutive duplicates.
    Degenerate results (< 3 points) are still appended, with a warning.
    """
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    polygon_list = []
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        # BUG FIX: reset the duplicate tracker per polygon.  It previously
        # persisted across lines (unlike get_polygon_list_without_trim), which
        # could silently drop the first point of the next polygon.
        prev = [-1, -1]
        begin_ind = int(np.floor(out['beginning'][line_ind]))
        end_ind = int(np.ceil(out['ending'][line_ind]))
        end_ind = min(end_ind, len(out['lf']) - 1)
        # Top edge, forward along the follower steps.
        for pt_ind in range(begin_ind, end_ind + 1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][0])
            pt_y = float(out['lf'][pt_ind][line_ind][1][0])
            pt_x, _boundary_x = correct_pt(pt_x, img_width)
            pt_y, _boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]
        # Bottom edge, walked backward to close the polygon.
        for pt_ind in range(end_ind, begin_ind - 1, -1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][1])
            pt_y = float(out['lf'][pt_ind][line_ind][1][1])
            pt_x, _boundary_x = correct_pt(pt_x, img_width)
            pt_y, _boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        polygon_list.append(polygon)
        if len(polygon) < 3:
            print('WARNING: DEGENERATE POLYGON AT INDEX', len(polygon_list))
    return polygon_list
|
| 60 |
+
|
| 61 |
+
# Each polygon is a list of (x,y) tuples
def get_polygon_list_without_trim(out):
    """Like ``get_polygon_list_tuples`` but uses every follower step (no trimming).

    The top edge is traversed forward and the bottom edge backward; points are
    clamped to the image bounds and consecutive duplicates removed.  Polygons
    left with fewer than 3 points are dropped entirely.
    """
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    n_steps = len(out['lf'])

    polygon_list = []
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        prev = [-1, -1]
        # (step index, side): side 0 = top edge forward, side 1 = bottom edge backward.
        walk = [(ind, 0) for ind in range(n_steps)] + \
               [(ind, 1) for ind in range(n_steps - 1, -1, -1)]
        for pt_ind, side in walk:
            pt_x, _ = correct_pt(float(out['lf'][pt_ind][line_ind][0][side]), img_width)
            pt_y, _ = correct_pt(float(out['lf'][pt_ind][line_ind][1][side]), img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        if len(polygon) >= 3:
            polygon_list.append(polygon)
    return polygon_list
|
| 92 |
+
|
| 93 |
+
# Each polygon passed as input is a list of (x,y) tuples
# Same for output
def percent_intersection(size, poly1, poly2):
    """Rasterize two polygons onto a canvas of ``size`` and measure overlap.

    Returns ``(intersection_pixels, fraction_of_poly1, fraction_of_poly2)``
    where each fraction is the overlap divided by that polygon's own area.
    """
    def _rasterize(poly):
        # 1-bit canvas; filled polygon becomes the boolean mask.
        canvas = Image.new(mode="1", size=size)
        ImageDraw.Draw(canvas).polygon(poly, fill=1)
        return np.asarray(canvas, dtype=bool)

    mask1 = _rasterize(poly1)
    mask2 = _rasterize(poly2)
    intersection_area = (mask1 & mask2).sum()
    return (intersection_area,
            intersection_area / mask1.sum(),
            intersection_area / mask2.sum())
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_poly_no_overlap(img_name, poly_list, threshold=0.6):
    """Greedily delete polygons that overlap a neighbor by more than ``threshold``.

    Adjacent polygons (in list order) are compared pairwise; when either loses
    more than ``threshold`` of its own area to the overlap, the one with the
    larger losing fraction is deleted.  Degenerate (< 3 point) polygons are
    always deleted.  ``img_name`` only supplies the rasterization canvas size.

    Returns ``(deleted_indices, surviving_polygons)``; survivors keep their
    original relative order so callers can filter parallel lists by index.
    """
    img = Image.open(img_name)
    size = img.size
    polygons = poly_list
    del_list = []
    current = 0
    next_ind = current + 1
    last_deleted = -1
    while next_ind < len(polygons):
        # Check these are not degenerate polygons
        if len(polygons[current]) < 3:
            del_list.append(current)
            current, next_ind = (current + 1, next_ind + 1)
            continue
        if len(polygons[next_ind]) < 3:
            del_list.append(next_ind)
            next_ind += 1
            continue
        # End check
        overlap_area, percent1, percent2 = percent_intersection(size,
                                                                polygons[current],
                                                                polygons[next_ind])

        if percent1 > threshold or percent2 > threshold:
            # Delete whichever polygon has the larger fraction of its own
            # area covered by the overlap; keep comparing from the survivor.
            to_del = current if percent1 > percent2 else next_ind
            current, next_ind = (current, next_ind + 1) if percent1 < percent2 \
                else (next_ind, next_ind + 1)
            del_list.append(to_del)
            last_deleted = to_del
        else:  # when no overlap is found
            current, next_ind = (current + 1, next_ind + 1)
            if current <= last_deleted:
                current = last_deleted + 1
                next_ind = current + 1
    # BUG FIX: iterate surviving indices in sorted order.  The original
    # iterated a set, whose order is not guaranteed, risking survivors being
    # returned out of order relative to the caller's parallel text list.
    keep = sorted(set(range(len(poly_list))) - set(del_list))
    poly_non_overlapping = [poly_list[i] for i in keep]
    return del_list, poly_non_overlapping
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def dump_polygons_json(out, polygons=None, filename=None):
    """Write detected line polygons as JSON keyed 'line_1', 'line_2', ...

    *filename* defaults to the image path with its extension replaced by
    'json'; *polygons* defaults to the polygons extracted from *out*.
    """
    if filename is None:
        filename = out["image_path"][:-3] + "json"
    if polygons is None:
        polygons = get_polygon_list(out)
    lf_dict = {f'line_{idx + 1}': points.xy_to_list(poly)
               for idx, poly in enumerate(polygons)}
    with open(filename, 'w') as fout:
        # Trailing newline matches the original print()-based output.
        fout.write(json.dumps(lf_dict, indent=2) + '\n')
|
| 170 |
+
|
| 171 |
+
def write_json_file(out, poly_list = None, json_file=None):
    """Thin wrapper: derive polygons from *out* if not given, then dump
    them to JSON via dump_polygons_json."""
    if poly_list is None:
        poly_list = get_polygon_list(out)
    dump_polygons_json(out, poly_list, json_file)
|
| 175 |
+
|
| 176 |
+
def write_text_file(out, predicted_text, filename=None):
    """Save recognized lines to a plain-text file, one line per entry.

    *filename* defaults to the image path with its extension replaced by 'txt'.
    """
    if filename is None:
        filename = out["image_path"][:-3] + "txt"
    with open(filename, 'w') as sink:
        sink.write('\n'.join(predicted_text))
|
| 182 |
+
|
| 183 |
+
# Won't flip the polygons... only the image.
def draw_image_with_poly(directory, image, poly, convert=True, flip=False):
    """Show *image* with its polygons overlaid on the current matplotlib axes.

    Polygons cycle through red/green/blue and are labeled with their index.
    """
    canvas = cv2.imread(os.path.join(directory, image))
    if flip:
        canvas = cv2.flip(canvas, 1)
    plt.imshow(canvas)
    palette = ['red', 'green', 'blue']

    for idx, shape in enumerate(poly):
        if convert:
            shape = points.list_to_xy(shape)
        points.draw_poly(plt, shape, palette[idx % 3])
        plt.text(shape[-1][0], shape[-1][1], str(idx))
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# Configuration for ArabicReshaper: strip diacritics (harakat) but keep
# ligature shaping, so reshaped strings remain comparable for CER scoring.
arabic_reshaper_configuration = {
    'delete_harakat': True,
    'support_ligatures': True
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def remove_diacritics(txt):
    """Return *txt* reshaped with Arabic diacritics (harakat) removed."""
    reshaper = ArabicReshaper(configuration=arabic_reshaper_configuration)
    return reshaper.reshape(txt)
|
| 208 |
+
|
| 209 |
+
def get_cer_matrix(gt_list, prediction_list):
    """Pairwise CER matrix between ground-truth and predicted lines.

    Entry [g, p] is the character error rate of prediction p against
    ground-truth g, clipped to at most 1.0.
    """
    cer_matrix = np.zeros((len(gt_list), len(prediction_list)))
    for ind_g, gt_line in enumerate(gt_list):
        for ind_p, pred_line in enumerate(prediction_list):
            cer_matrix[ind_g, ind_p] = min(error_rates.cer(gt_line, pred_line), 1)
    return cer_matrix
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# https://en.wikipedia.org/wiki/Dynamic_time_warping
def get_dtw(dist_matrix):
    """Dynamic-time-warping alignment cost for the given pairwise distances.

    Standard DP over an (r+1) x (c+1) table initialized to infinity with a
    zero origin; returns the bottom-right accumulated cost.
    """
    rows, cols = dist_matrix.shape
    dtw = np.full((rows + 1, cols + 1), np.inf)
    dtw[0, 0] = 0

    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            best_prev = min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
            dtw[i, j] = dist_matrix[i - 1, j - 1] + best_prev
    return dtw[-1, -1]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def get_cers(gt_list, prediction_list, no_diacritics=False):
    """Compute several CER summaries between GT lines and predicted lines.

    Returns (DTW-aligned CER cost,
             sum of each GT line's best CER,
             sum of each prediction's best CER,
             CER of the joined paragraphs).
    """
    if no_diacritics:
        gt_list = [remove_diacritics(line) for line in gt_list]
        prediction_list = [remove_diacritics(line) for line in prediction_list]

    matrix = get_cer_matrix(gt_list, prediction_list)
    best_per_gt = matrix.min(axis=1).flatten()
    best_per_pred = matrix.min(axis=0).flatten()
    dtw_cost = get_dtw(matrix)

    paragraph_cer = error_rates.cer('\n'.join(gt_list),
                                    '\n'.join(prediction_list))
    return dtw_cost, np.sum(best_per_gt), np.sum(best_per_pred), paragraph_cer
|
| 261 |
+
|
| 262 |
+
def get_cers_wer(gt_list, prediction_list, no_diacritics=False):
    """Same summaries as get_cers, plus the paragraph-level WER.

    Returns (DTW CER cost, sum per-GT best CER, sum per-prediction best CER,
             paragraph CER, paragraph WER).
    """
    if no_diacritics:
        gt_list = [remove_diacritics(line) for line in gt_list]
        prediction_list = [remove_diacritics(line) for line in prediction_list]

    matrix = get_cer_matrix(gt_list, prediction_list)
    best_per_gt = matrix.min(axis=1).flatten()
    best_per_pred = matrix.min(axis=0).flatten()
    dtw_cost = get_dtw(matrix)

    gt_para = '\n'.join(gt_list)
    pred_para = '\n'.join(prediction_list)
    return (dtw_cost, np.sum(best_per_gt), np.sum(best_per_pred),
            error_rates.cer(gt_para, pred_para),
            error_rates.wer(gt_para, pred_para))
|
| 279 |
+
|
| 280 |
+
def flip_polygon(img_file, poly_list):
    """Mirror every polygon horizontally using the image's width."""
    h, w = cv2.imread(img_file).shape[:2]
    return [[(w - x, y) for (x, y) in poly] for poly in poly_list]
|
| 288 |
+
|
| 289 |
+
# This will create annotations that can be used in the scribeArabic annotation tool.
def create_annotations_json(predicted_text, poly_list):
    """Build a page-level annotation dict: 'line_N' -> {'coord', 'text'}.

    Args:
        predicted_text: list of OCR strings, one per line.
        poly_list: list of polygons ((x, y) tuples), one per line.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    if len(predicted_text) != len(poly_list):
        # BUG FIX: the original crashed via an undefined name ('slkfj');
        # raise a descriptive exception instead.
        raise ValueError(
            f"polygon list length ({len(poly_list)}) does not match "
            f"predicted text length ({len(predicted_text)})")
    page_json = dict()
    for ind, (ocr, poly) in enumerate(zip(predicted_text, poly_list)):
        pts = points.xy_to_list(poly)
        line_key = f'line_{ind+1}'
        page_json[line_key] = dict()
        page_json[line_key]['coord'] = pts
        page_json[line_key]['text'] = ocr
    return page_json
|
arabic/test_hw_helper_routines.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('py3/')
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import yaml
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import hw
|
| 10 |
+
from hw import cnn_lstm
|
| 11 |
+
from utils import string_utils, error_rates
|
| 12 |
+
import cv2
|
| 13 |
+
|
| 14 |
+
# Fixed input height (pixels) expected by the handwriting-recognition network.
HT = 60
|
| 15 |
+
|
| 16 |
+
def load_char_set(char_set_path):
    """Load a character-set JSON file and return its idx->char mapping.

    JSON keys are strings; they are converted to integer indices here.
    """
    with open(char_set_path) as f:
        char_set = json.load(f)
    return {int(idx): ch for idx, ch in char_set['idx_to_char'].items()}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_config(config_file):
    """Parse a YAML configuration file (safe loader) into a dict."""
    with open(config_file) as stream:
        return yaml.load(stream, Loader=yaml.loader.SafeLoader)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_HW(hw_network_config, pt_filename, device="cuda"):
    """Create the handwriting-recognition network and load pretrained weights.

    Args:
        hw_network_config: the 'network.hw' section of the YAML config.
        pt_filename: path to the saved state_dict (.pt file).
        device: torch device string. Defaults to "cuda" (the original
            hard-coded behavior) but now allows CPU-only machines.

    Returns:
        The model moved to *device* with weights loaded.
    """
    HW = cnn_lstm.create_model(hw_network_config)
    # map_location lets a GPU-saved checkpoint load on any device.
    hw_state = torch.load(pt_filename, map_location=torch.device(device))
    HW.load_state_dict(hw_state)

    HW.to(torch.device(device))
    return HW
|
| 42 |
+
|
| 43 |
+
def get_predicted_str(HW, img_file, idx_to_char, device="cuda", flip=False,
                      show=False, read_image=True, img=None, tokenizer=None):
    """Run the HW recognizer on one line image and decode the prediction.

    Args:
        HW: the trained recognition model.
        img_file: path to the line image (ignored when read_image=False).
        idx_to_char: index-to-character map used for CTC decoding.
        device: torch device string.
        flip: horizontally mirror the image before recognition.
        show: display the (resized) image before recognition.
        read_image: when False, use the pre-loaded *img* array instead.
        img: pre-loaded image array (used when read_image=False).
        tokenizer: optional tokenizer; when given, decoding goes through it
            instead of idx_to_char.

    Returns:
        The decoded prediction string.
    """
    device = torch.device(device)
    if read_image:
        img = cv2.imread(img_file)
    ht, width = img.shape[:2]

    if ht != HT:
        # Rescale to the fixed model input height, preserving aspect ratio.
        new_width = int(width / ht * HT)
        img = cv2.resize(img, (new_width, HT))
    if show:
        # BUG FIX: matplotlib was never imported at module level, so
        # show=True previously raised NameError. Import lazily here.
        import matplotlib.pyplot as plt
        plt.imshow(img)
        plt.show()
    if flip:
        img = np.flip(img, axis=1)
    img = img.astype(np.float32)
    img = img / 128.0 - 1.0            # normalize to roughly [-1, 1]
    img = np.expand_dims(img, 0)       # add batch dimension
    img = img.transpose([0, 3, 1, 2])  # NHWC -> NCHW
    img = torch.from_numpy(img)

    IMG = img.to(device)
    #print('img size is', img.size())

    preds = HW(IMG).cpu()

    out = preds.permute(1, 0, 2)
    out = out.data.numpy()
    logits = out[0, ...]
    pred, raw_pred = string_utils.naive_decode(logits)
    if tokenizer is None:
        pred_str = string_utils.label2str_single(pred, idx_to_char, False)
    else:
        pred_str = tokenizer.decode(pred)

    del IMG
    return pred_str
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def write_csv_all_predictions(config_file, suffix="", device="cuda", flip=False, pt_file='hw.pt',
                              test_file_to_use="", result_file="", tokenizer=None):
    """Run recognition over every line in the test set and save a results CSV.

    For each record in the test JSON (columns: image path, ground truth,
    prediction, CER, WER), runs get_predicted_str and accumulates a
    DataFrame, which is written to *result_file* (derived from the config
    filename when empty) and returned.

    NOTE(review): assumes the YAML config has 'network.hw', 'pretraining'
    (with 'snapshot_path') and 'testing.test_file' sections, and that each
    test-JSON entry is a list whose first element is a per-image JSON file
    of records with 'gt' and 'hw_path' keys — confirm against the config
    files shipped with the model.
    """
    result_df = pd.DataFrame(columns=["image", "ground_truth", "prediction", "CER", "WER"])
    config = get_config(config_file)
    idx_to_char = load_char_set(config['network']['hw']['char_set_path'])
    # Prefer the snapshot name recorded in the config over the argument.
    if 'hw_to_save' in config['pretraining'].keys():
        pt_file = config['pretraining']['hw_to_save']
    else:
        pt_file = pt_file
    pt_filename = os.path.join(config['pretraining']['snapshot_path'], pt_file)

    # Output layer size: charset + 1 (CTC blank), or tokenizer vocab size.
    config["network"]["hw"]["num_of_outputs"] = len(idx_to_char) + 1
    if tokenizer is not None:
        config["network"]["hw"]["num_of_outputs"] = tokenizer.get_vocab_size()
        pt_filename = os.path.join(config['pretraining']['snapshot_path'], f"hw_tokenizer_{tokenizer.get_vocab_size()}.pt")

    if len(suffix) > 0:
        pt_filename = pt_filename[:-3] + suffix + '.pt'
    print('...Using snapshot', pt_filename)
    HW = load_HW(config['network']['hw'], pt_filename)
    device = torch.device(device)
    HW.to(device)
    HW.eval()

    if test_file_to_use == "":
        test_json_file = config['testing']['test_file']
    else:
        test_json_file = test_file_to_use
    print('Using test file', test_json_file)

    with open(test_json_file) as f:
        json_obj = json.load(f)
    for ind, obj in enumerate(json_obj):
        # obj is a list of: [jsonfile imgfile]
        # open the json file and get a list of predictions
        with open(obj[0]) as f:
            image_list = json.load(f)

        for record in image_list:
            # Skip records without usable ground truth (missing/NaN/empty).
            if record['gt'] == 'None' or isinstance(record['gt'], float) or record['gt'] == 'nan' or len(record['gt']) == 0:
                print(type(record['gt']), record['gt'], record['hw_path'])
                continue
            #print(record['hw_path'])
            predicted_str = get_predicted_str(HW, record['hw_path'], idx_to_char, device=device,
                                              flip=flip, tokenizer=tokenizer)

            cer = error_rates.cer(record['gt'], predicted_str)
            wer = error_rates.wer(record['gt'], predicted_str)
            result_df.loc[len(result_df)] = [record['hw_path'], record['gt'],
                                             predicted_str, cer, wer]

    # Derive the result CSV name from the config filename when not given.
    if len(result_file) == 0:
        result_file = config_file.replace("config", "result")
        result_file = result_file.replace("yaml", "csv")
    if len(suffix) > 0:
        result_file = result_file[:-4] + suffix + '.csv'
    result_df.to_csv(result_file, index=False)
    return result_df
|
| 146 |
+
|
| 147 |
+
|
arabic/warp_routines.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from svgpathtools import Path, Line
|
| 2 |
+
from scipy.interpolate import griddata
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
import sys
|
| 6 |
+
sys.path.append('../../coords/')
|
| 7 |
+
import points
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
def generate_offset_mapping(img, ts, path, offset_1, offset_2, max_min = None, cube_size = None):
    """Build remap grids between a curved baseline region and a straight strip.

    For each parameter value t along *path*, offsets the path point by
    offset_1 and offset_2 along the local normal to form the two edges of
    the text-line region, then interpolates (scipy griddata, cubic) dense
    coordinate maps in both directions.

    Returns (rectified_to_warped_x, rectified_to_warped_y,
             warped_to_rectified_x, warped_to_rectified_y, max_min).
    max_min is passed through unchanged.
    """
    # cube_size = 80

    offset_1_pts = []
    offset_2_pts = []
    # for t in ts:
    for i in range(len(ts)):
        t = ts[i]
        pt = path.point(t)

        norm = None
        if i == 0:
            # Endpoint: use the single adjacent segment's normal.
            norm = normal(pt, path.point(ts[i+1]))
            norm = norm / dis(complex(0,0), norm)
        elif i == len(ts)-1:
            norm = normal(path.point(ts[i-1]), pt)
            norm = norm / dis(complex(0,0), norm)
        else:
            # Interior point: average the two adjacent segment normals.
            norm1 = normal(path.point(ts[i-1]), pt)
            norm1 = norm1 / dis(complex(0,0), norm1)
            norm2 = normal(pt, path.point(ts[i+1]))
            norm2 = norm2 / dis(complex(0,0), norm2)

            norm = (norm1 + norm2)/2
            norm = norm / dis(complex(0,0), norm)

        offset_vector1 = offset_1 * norm
        offset_vector2 = offset_2 * norm

        pt1 = pt + offset_vector1
        pt2 = pt + offset_vector2

        offset_1_pts.append(complexToNpPt(pt1))
        offset_2_pts.append(complexToNpPt(pt2))

    offset_1_pts = np.array(offset_1_pts)
    offset_2_pts = np.array(offset_2_pts)

    h,w = img.shape[:2]

    # Rectified-space anchors: two rows of evenly spaced points, one per edge.
    offset_source2 = np.array([(cube_size*i, 0) for i in range(len(offset_1_pts))], dtype=np.float32)
    offset_source1 = np.array([(cube_size*i, cube_size) for i in range(len(offset_2_pts))], dtype=np.float32)

    offset_source1 = offset_source1[::-1]
    offset_source2 = offset_source2[::-1]

    source = np.concatenate([offset_source1, offset_source2])
    destination = np.concatenate([offset_1_pts, offset_2_pts])

    # Swap columns to (row, col) ordering for the mgrid-based interpolation.
    source = source[:,::-1]
    destination = destination[:,::-1]

    n_w = int(offset_source2[:,0].max())
    n_h = int(cube_size)

    grid_x, grid_y = np.mgrid[0:n_h, 0:n_w]

    # Dense rectified -> warped coordinate maps over the output strip.
    grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
    map_x = np.append([], [ar[:,1] for ar in grid_z]).reshape(n_h,n_w)
    map_y = np.append([], [ar[:,0] for ar in grid_z]).reshape(n_h,n_w)
    map_x_32 = map_x.astype('float32')
    map_y_32 = map_y.astype('float32')

    rectified_to_warped_x = map_x_32
    rectified_to_warped_y = map_y_32

    # Inverse direction: interpolate over the full original image.
    grid_x, grid_y = np.mgrid[0:h, 0:w]
    grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
    map_x = np.append([], [ar[:,1] for ar in grid_z]).reshape(h,w)
    map_y = np.append([], [ar[:,0] for ar in grid_z]).reshape(h,w)
    map_x_32 = map_x.astype('float32')
    map_y_32 = map_y.astype('float32')

    warped_to_rectified_x = map_x_32
    warped_to_rectified_y = map_y_32

    return rectified_to_warped_x, rectified_to_warped_y, warped_to_rectified_x, warped_to_rectified_y, max_min
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def dis(pt1, pt2):
    """Euclidean distance between two points given as complex numbers."""
    return np.sqrt((pt1.real - pt2.real) ** 2 + (pt1.imag - pt2.imag) ** 2)
|
| 93 |
+
|
| 94 |
+
def complexToNpPt(pt):
    """Convert a complex point to a float32 numpy array [x, y]."""
    return np.array([pt.real, pt.imag], dtype=np.float32)
|
| 96 |
+
|
| 97 |
+
def normal(pt1, pt2):
    """Return a vector perpendicular to (pt1 - pt2), as a complex number.

    Equivalent to rotating the difference vector by 90 degrees
    counter-clockwise; the result is NOT normalized.
    """
    delta = pt1 - pt2
    return complex(-delta.imag, delta.real)
|
| 100 |
+
|
| 101 |
+
def find_t_spacing(path, cube_size):
    """Find parameter values t along *path* spaced ~cube_size apart.

    path.point(t) is not linear in t, so each step runs an adaptive search
    (forward/backward probes with step halving, up to 1000 iterations) until
    the chord length from the previous t is within a small error of
    cube_size.

    Returns the list of t values, starting at 0.
    """
    l = path.length()
    error = 0.01
    init_step_size = cube_size / l

    last_t = 0
    cur_t = 0
    ts = [0]
    for target in np.arange(cube_size, int(l), cube_size):
        step_size = init_step_size
        for i in range(1000):
            cur_length = dis(path.point(last_t), path.point(cur_t))
            if np.abs(cur_length - cube_size) < error:
                break

            # Probe a forward step (clamped to t <= 1).
            step_t = min(cur_t + step_size, 1.0)
            step_l = dis(path.point(last_t), path.point(step_t))

            if np.abs(step_l - cube_size) < np.abs(cur_length - cube_size):
                cur_t = step_t
                continue

            # Probe a backward step, clamped into [last_t, 1.0].
            step_t = max(cur_t - step_size, 0.0)
            step_t = max(step_t, last_t)
            # BUG FIX: was max(step_t, 1.0), which forced step_t to 1.0 and
            # made the backward probe useless; clamp to <= 1.0 instead.
            step_t = min(step_t, 1.0)

            step_l = dis(path.point(last_t), path.point(step_t))

            if np.abs(step_l - cube_size) < np.abs(cur_length - cube_size):
                cur_t = step_t
                continue

            # Neither direction improved; halve the step and retry.
            step_size = step_size / 2.0

        last_t = cur_t

        ts.append(cur_t)

    return ts
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def remap_with_grid_sample(input_image, map_x, map_y, padding, img_tensor=None, device="cuda"):
    """Torch-based equivalent of cv2.remap using F.grid_sample.

    map_x/map_y give, for each output pixel, the source x/y coordinate in
    *input_image*. *padding* is passed to grid_sample's padding_mode
    ('zeros', 'border', or 'reflection'). A pre-converted *img_tensor*
    (NCHW, [0, 1]) may be supplied to skip the conversion.

    Returns a uint8 numpy image; grayscale inputs come back 2-D.
    """
    reshaped = False
    H, W = input_image.shape[:2]

    if len(input_image.shape) == 2:
        #print('reshaping', input_image.shape)
        # Promote grayscale to a single-channel image so NCHW conversion works.
        input_image = input_image.reshape(H, W, 1)
        reshaped = True

    if img_tensor is None:
        # Convert input image to PyTorch tensor in NCHW format and normalize to [0, 1]
        img_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    img_tensor = img_tensor.to(device)

    #print(input_image.shape, map_x.shape)

    # Convert map_x and map_y to normalized coordinates in the range [-1, 1]
    norm_map_x = (torch.from_numpy(map_x.copy()).float() / (W - 1)) * 2 - 1
    norm_map_y = (torch.from_numpy(map_y.copy()).float() / (H - 1)) * 2 - 1

    # Stack normalized coordinates to create a grid of shape (1, H, W, 2)
    grid = torch.stack((norm_map_x, norm_map_y), dim=-1).unsqueeze(0)

    # Ensure grid is on the same device as the input tensor (e.g., GPU if available)
    grid = grid.to(img_tensor.device)

    # Apply grid_sample to perform the remap operation
    output_tensor = torch.nn.functional.grid_sample(img_tensor, grid, mode='bilinear',
                                                    padding_mode=padding, align_corners=True)

    # Convert back to NumPy and scale back to [0, 255]
    output_image = (output_tensor.squeeze(dim=0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    if reshaped:
        # Drop the channel axis added for grayscale inputs.
        output_image = output_image[:, :, 0]
    return output_image
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def generate_offset_mapping_1(img, ts, path, offset_1, offset_2,
                              cube_size = None):
    """Build the rectified->warped remap grids for one text line.

    Same construction as generate_offset_mapping, but computes (and returns)
    only the rectified-to-warped direction, which is all get_warped_images
    needs. Returns (rectified_to_warped_x, rectified_to_warped_y).
    """
    offset_1_pts = []
    offset_2_pts = []
    # for t in ts:
    for i in range(len(ts)):
        t = ts[i]
        pt = path.point(t)

        norm = None
        if i == 0:
            # Endpoint: use the single adjacent segment's normal.
            norm = normal(pt, path.point(ts[i+1]))
            norm = norm / dis(complex(0,0), norm)
        elif i == len(ts)-1:
            norm = normal(path.point(ts[i-1]), pt)
            norm = norm / dis(complex(0,0), norm)
        else:
            # Interior point: average the two adjacent segment normals.
            norm1 = normal(path.point(ts[i-1]), pt)
            norm1 = norm1 / dis(complex(0,0), norm1)
            norm2 = normal(pt, path.point(ts[i+1]))
            norm2 = norm2 / dis(complex(0,0), norm2)

            norm = (norm1 + norm2)/2
            norm = norm / dis(complex(0,0), norm)

        offset_vector1 = offset_1 * norm
        offset_vector2 = offset_2 * norm

        pt1 = pt + offset_vector1
        pt2 = pt + offset_vector2

        offset_1_pts.append(complexToNpPt(pt1))
        offset_2_pts.append(complexToNpPt(pt2))

    offset_1_pts = np.array(offset_1_pts)
    offset_2_pts = np.array(offset_2_pts)

    h,w = img.shape[:2]

    # Rectified-space anchors: two rows of evenly spaced points, one per edge.
    offset_source2 = np.array([(cube_size*i, 0) for i in range(len(offset_1_pts))], dtype=np.float32)
    offset_source1 = np.array([(cube_size*i, cube_size) for i in range(len(offset_2_pts))], dtype=np.float32)

    offset_source1 = offset_source1[::-1]
    offset_source2 = offset_source2[::-1]

    source = np.concatenate([offset_source1, offset_source2])
    destination = np.concatenate([offset_1_pts, offset_2_pts])

    # Swap columns to (row, col) ordering for the mgrid-based interpolation.
    source = source[:,::-1]
    destination = destination[:,::-1]

    n_w = int(offset_source2[:,0].max())
    n_h = int(cube_size)

    grid_x, grid_y = np.mgrid[0:n_h, 0:n_w]

    grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
    map_x = np.append([], [ar[:,1] for ar in grid_z]).reshape(n_h,n_w)
    map_y = np.append([], [ar[:,0] for ar in grid_z]).reshape(n_h,n_w)
    map_x_32 = map_x.astype('float32')
    map_y_32 = map_y.astype('float32')

    rectified_to_warped_x = map_x_32
    rectified_to_warped_y = map_y_32

    return rectified_to_warped_x, rectified_to_warped_y
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def get_warped_images(img, polygon_list, baseline_list, target_height=60, device="cuda"):
    """Dewarp each text-line polygon into a straight strip of *target_height*.

    For each (polygon, baseline) pair: estimate the line's thickness from
    the polygon mask, turn the baseline into an svgpathtools Path (extended
    slightly at the end), build a rectified->warped remap with
    generate_offset_mapping_1, and sample the strip with cv2.remap.

    Lines whose baseline has fewer than two points are skipped, so the
    returned list may be shorter than polygon_list. *device* is currently
    unused (the grid_sample path is commented out).
    """
    num_lines = len(polygon_list)
    all_lines = ""

    warped_list = []
    region_output_data = []

    # img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    # img_tensor = img_tensor.to(device)

    for ind in range(len(polygon_list)):

        line_mask = extract_region_mask(img, polygon_list[ind])

        # Median per-column mask height ~ line thickness; pad by 10%.
        summed_axis0 = (line_mask.astype(float) / 255).sum(axis=0)
        avg_height0 = np.median(summed_axis0[summed_axis0 != 0])

        target_step_size = avg_height0*1.1

        # Convert the baseline polyline into svgpathtools Line segments.
        paths = []
        for i in range(len(baseline_list[ind])-1):
            i_1 = i+1

            p1 = baseline_list[ind][i]
            p2 = baseline_list[ind][i_1]

            p1_c = complex(*p1)
            p2_c = complex(*p2)

            paths.append(Line(p1_c, p2_c))

        if len(paths) == 0:
            continue

        #try:
        if True:
            # Add a bit on the end
            tan = paths[-1].unit_tangent(1.0)
            p3_c = p2_c + target_step_size * tan
            paths.append(Line(p2_c, p3_c))

        path = Path(*paths)

        # Sample t values so the output strip has the right aspect ratio.
        n_w = target_height*path.length()/target_step_size
        ts = np.arange(0, 1, target_height/float(n_w))

        (rectified_to_warped_x,
         rectified_to_warped_y) = generate_offset_mapping_1(img, ts, path, target_step_size/2,
                                                            -target_step_size/2,
                                                            cube_size=target_height)

        # Flip both axes of the maps (right-to-left script orientation --
        # presumably to read the Arabic line in its natural direction; verify).
        rectified_to_warped_x = rectified_to_warped_x[::-1,::-1]
        rectified_to_warped_y = rectified_to_warped_y[::-1,::-1]

        #warped = remap_with_grid_sample(img, rectified_to_warped_x,
        #    rectified_to_warped_y, "border", img_tensor=img_tensor, device=device)

        warped = cv2.remap(img, rectified_to_warped_x, rectified_to_warped_y, cv2.INTER_CUBIC, borderValue=(255,255,255))
        warped_list.append(warped)

    return warped_list
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def extract_region_mask(img, bounding_poly):
    """Binary mask (uint8, 255 inside the polygon) at *img*'s size."""
    # http://stackoverflow.com/a/15343106/3479446
    corners = np.array([np.array(bounding_poly, np.int32)], dtype=np.int32)
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, corners, (255,), lineType=cv2.LINE_8)
    return mask
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def get_baseline(poly_points, right_to_left=True):
    """Baseline polyline for a text-line polygon.

    Delegates to points.get_baseline_chunks. The right_to_left flag is
    currently unused (the reordering step was disabled upstream).
    """
    return points.get_baseline_chunks(poly_points)
|
| 358 |
+
|
| 359 |
+
# First argument is ignored if polygon_pts is given.
def get_line_image(polygon_flat_list, img, polygon_pts=None, target_height=60):
    """Warp one text line out of *img* using its polygon.

    Returns the dewarped strip image, or None when warping produced nothing.
    """
    if polygon_pts is None:
        polygon_pts = points.list_to_xy(polygon_flat_list)
    baseline = get_baseline(polygon_pts)
    warped = get_warped_images(img, [polygon_pts], [baseline],
                               target_height=target_height)
    return warped[0] if warped else None
|
| 368 |
+
|
coords/__init__.py
ADDED
|
File without changes
|
coords/points.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
from matplotlib.patches import Rectangle
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image, ImageDraw
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# img_file is the filename with full path
# points is a string of coordinate points read from XML
def generate_cropped_region(img_file, points):
    """Crop the axis-aligned bounding box of a polygon out of an image file.

    Returns (PIL image of the crop, top-left (x, y), bottom-right (x, y)).
    """
    # Parse "x1,y1 x2,y2 ..." into an (N, 2) integer array.
    xy_pts = [(int(tok.split(",")[0]), int(tok.split(",")[1]))
              for tok in points.split(' ')]
    pts = np.array(xy_pts)
    img = np.array(Image.open(img_file))
    # Bounding box of the polygon (inclusive on both ends).
    min_x, min_y = np.min(pts, axis=0)
    max_x, max_y = np.max(pts, axis=0)
    crop = img[min_y:max_y+1, min_x:max_x+1, :]
    return Image.fromarray(crop), (min_x, min_y), (max_x, max_y)
|
| 23 |
+
|
| 24 |
+
# pts is numpy 2D points array
# img also numpy arrray/cv2 array
def generate_cropped_image(img, pts):
    """Mask *img* to the polygon *pts* (outside pixels left black), then
    crop the polygon's bounding box.  Returns the crop as a numpy array."""
    pil_img = Image.fromarray(img)
    # Rasterize the polygon into a boolean mask the size of the image.
    mask_img = Image.new('L', (pil_img.width, pil_img.height), color=0)
    ImageDraw.Draw(mask_img).polygon(list(pts.flatten()), fill=255)
    mask = np.array(mask_img).astype(bool)
    # Keep only the pixels inside the polygon.
    masked = np.zeros_like(img)
    masked[mask] = img[mask]
    # Crop to the polygon's bounding box (inclusive).
    min_x, min_y = np.min(pts, axis=0).astype(int)
    max_x, max_y = np.max(pts, axis=0).astype(int)
    return masked[min_y:max_y+1, min_x:max_x+1, :]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def draw_poly(plt, xy_pts, color='green'):
    """Draw polygon *xy_pts* on the current axes, marking the first
    vertex with a small yellow square."""
    plt.gca().add_patch(Rectangle(xy_pts[0], 10, 10, facecolor='yellow'))
    # Connect consecutive vertices, then close the polygon back to the start.
    for start, end in zip(xy_pts, xy_pts[1:]):
        draw_line(plt, start[0], start[1], end[0], end[1], color)
    draw_line(plt, xy_pts[0][0], xy_pts[0][1],
              xy_pts[-1][0], xy_pts[-1][1], color)
|
| 53 |
+
|
| 54 |
+
def draw_baseline(plt, xy_pts, color='red'):
    """Draw a polyline baseline on the current axes, marking its first
    point with a large blue square.  The path is left open."""
    plt.gca().add_patch(Rectangle(xy_pts[0], 100, 100, facecolor='blue'))
    # Connect consecutive baseline points; no closing segment.
    for start, end in zip(xy_pts, xy_pts[1:]):
        draw_line(plt, start[0], start[1], end[0], end[1], color=color)
|
| 60 |
+
|
| 61 |
+
def draw_line(plt_obj, x1, y1, x2, y2, color='g'):
    """Plot a single straight segment from (x1, y1) to (x2, y2)."""
    plt_obj.plot([x1, x2], [y1, y2], color=color, linewidth=1)
|
| 63 |
+
|
| 64 |
+
# The argument points is a string.
# Function returns a list of (x,y) tuples
def get_xy_pts(points):
    """Parse a space-separated "x,y x,y ..." string into [(x, y), ...]."""
    xy_pts = []
    for token in points.split(' '):
        fields = token.split(',')
        xy_pts.append((int(fields[0]), int(fields[1])))
    return xy_pts
|
| 71 |
+
|
| 72 |
+
# bbox is not necessarily a polygon or rectangle. Just a list of (x,y) tuples
# If apply_correction is True then all points are restricted to lie within
# top left and bottom right
def add_offset_to_polygon(bbox, offset, apply_correction=False,
                          top_left=[], bottom_right=[]):
    """Translate every vertex of *bbox* by *offset*.

    When *apply_correction* is True, each coordinate is clamped into the
    box spanned by *top_left* and *bottom_right*.
    """
    shifted = [(coord[0] + offset[0], coord[1] + offset[1]) for coord in bbox]
    if apply_correction:
        lo = np.array(top_left)
        hi = np.array(bottom_right)
        pts = np.array(shifted)
        # Clamp above, then below, per axis.
        for axis in (0, 1):
            over = np.where(pts[:, axis] > hi[axis])
            pts[over, axis] = hi[axis]
        for axis in (0, 1):
            under = np.where(pts[:, axis] < lo[axis])
            pts[under, axis] = lo[axis]
        shifted = list(map(tuple, pts))

    return shifted
|
| 93 |
+
|
| 94 |
+
def add_offset_to_polygon_list(polygon_list, offset):
    """Apply add_offset_to_polygon to every polygon in *polygon_list*."""
    return [add_offset_to_polygon(poly, offset) for poly in polygon_list]
|
| 100 |
+
|
| 101 |
+
def combine_poly(poly1, poly2):
    """Stitch two line polygons into one.

    Walks poly1's first edge, then poly2's, and back along the return
    edges, skipping junction points that coincide.  Indexing poly2[3]
    implies both inputs have at least 4 points.
    """
    main_poly = [poly1[0], poly1[1]]
    # Append poly2's first point only if it differs from poly1's second.
    if poly1[1][0] != poly2[0][0] or poly1[1][1] != poly2[0][1]:
        main_poly.append(poly2[0])
    main_poly.extend(poly2[1:3])
    # Same duplicate check at the second junction.
    if poly1[2][0] != poly2[3][0] or poly1[2][1] != poly2[3][1]:
        main_poly.append(poly2[3])
    main_poly.extend(poly1[2:])
    return main_poly
|
| 110 |
+
|
| 111 |
+
# Get left, upper, right, lower min_x, min_y, max_x, max_y
def get_max_min_polygon(polygon):
    """Return the bounding box (min_x, min_y, max_x, max_y) of *polygon*."""
    xs = [pt[0] for pt in polygon]
    ys = [pt[1] for pt in polygon]
    return min(xs), min(ys), max(xs), max(ys)
|
| 116 |
+
|
| 117 |
+
def add_offset_to_baseline(baseline, offset):
|
| 118 |
+
new_baseline = []
|
| 119 |
+
for pts in baseline:
|
| 120 |
+
new_baseline.append((pts[0]+offset[0], pts[1]+offset[1]))
|
| 121 |
+
return new_baseline
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def add_offset_to_baseline_list(baseline_list, offset):
    """Apply add_offset_to_baseline to every baseline in *baseline_list*."""
    return [add_offset_to_baseline(b, offset) for b in baseline_list]
|
| 130 |
+
|
| 131 |
+
def combine_baseline(base1, base2):
    """Join two 2-point baselines into a 4-point one.

    Note the deliberate ordering: base2's points come first and each
    pair is reversed (see the superseded variant kept below).
    """
    #combined = [base1[0], base1[1], base2[0], base2[1]]
    combined = [base2[1], base2[0], base1[1], base1[0]]
    return combined
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_x_y(polygon):
    """Split a point list into parallel x and y coordinate lists."""
    x_list = []
    y_list = []
    for pt in polygon:
        x_list.append(pt[0])
        y_list.append(pt[1])
    return x_list, y_list
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# num_pts is number of points on baseline to get
def get_baseline_regression(poly_pts, num_pts=10, deg=1):
    """Fit a degree-*deg* polynomial y(x) through the polygon points and
    sample it at up to *num_pts* evenly spaced integer x positions.

    Returns a list of (x, y) tuples spanning the polygon's x extent.
    """
    # Too few points for a higher-degree fit: fall back to a line.
    if len(poly_pts) <= 4:
        deg = 1
    x, y = get_x_y(poly_pts)
    model = (np.polyfit(x, y, deg))
    p = np.poly1d(model)
    # get the x against which we want y
    x1, y1, x2, y2 = get_max_min_polygon(poly_pts)
    # Never sample more points than there are integer x positions.
    num_pts = min(num_pts, x2-x1+1)
    num_pts = int(num_pts)
    x = np.linspace(x1, x2, num_pts, endpoint=True, dtype=int)
    y = p(x)
    baseline = [(a, b) for a,b in zip(x, y)]
    return baseline
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Given a coordinates list return xy tuples
def list_to_xy(coord_list):
    """Convert a flat [x0, y0, x1, y1, ...] list into [(x0, y0), ...]."""
    return [(coord_list[i], coord_list[i + 1])
            for i in range(0, len(coord_list), 2)]
|
| 167 |
+
|
| 168 |
+
# Given an (x,y) list of tuples, return a flat list
def xy_to_list(tuples_list):
    """Flatten [(x0, y0), (x1, y1), ...] into [x0, y0, x1, y1, ...]."""
    flat_list = []
    for pair in tuples_list:
        flat_list.extend(pair)
    return flat_list
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# x is a list of points
# y is a cooresponding list of points
def get_baseline_from_xy(x, y, num_pts=10, deg=1):
    """Fit a degree-*deg* polynomial y(x) to the coordinate lists and
    sample it at up to *num_pts* evenly spaced integer x positions.

    Returns a list of (x, y) tuples spanning [min(x), max(x)].
    """
    # Too few points for a higher-degree fit: fall back to a line.
    if len(x) <= 4:
        deg = 1

    p = np.poly1d(np.polyfit(x, y, deg))
    # BUG FIX: the original referenced the undefined name poly_pts here
    # (copy-paste from get_baseline_regression), raising NameError on
    # every call.  Derive the x range directly from the input list.
    x1, x2 = min(x), max(x)
    # Never sample more points than there are integer x positions.
    num_pts = int(min(num_pts, x2 - x1 + 1))
    xs = np.linspace(x1, x2, num_pts, endpoint=True, dtype=int)
    ys = p(xs)
    return [(a, b) for a, b in zip(xs, ys)]
|
| 189 |
+
|
| 190 |
+
# Here img is the 2D numpy array
# xy_pts is a list of (x,y) tuples
def generate_cropped_region_from_polypts(img, xy_pts):
    """Crop the axis-aligned bounding box of *xy_pts* from *img*.

    Returns (PIL image of the crop, top-left (x, y), bottom-right (x, y)).
    """
    pts = np.array(xy_pts)
    # Round the box inward so indices stay inside the polygon extent.
    min_x, min_y = np.ceil(pts.min(axis=0)).astype(int)
    max_x, max_y = np.floor(pts.max(axis=0)).astype(int)
    crop = img[min_y:max_y + 1, min_x:max_x + 1, :]
    return Image.fromarray(crop), (min_x, min_y), (max_x, max_y)
|
| 201 |
+
|
| 202 |
+
# Will generate a line image by using xy_pts as a mask
def generate_line_image(img, xy_pts):
    """Cut the polygon *xy_pts* out of *img* onto a white background.

    Returns (PIL image: the polygon's bounding-box crop with pixels
    outside the polygon set to white, polygon shifted into the crop's
    local coordinate frame).
    """
    pts = np.array(xy_pts)
    # Bounding box, rounded inward so indices stay within the extent.
    [min_x, min_y] = np.ceil(np.min(pts, axis=0)).astype(int)
    [max_x, max_y] = np.floor(np.max(pts, axis=0)).astype(int)
    (width, ht) = (max_x-min_x+1, max_y-min_y+1)
    # Express the polygon relative to the crop origin.
    xy_pts = add_offset_to_polygon(xy_pts, (-min_x, -min_y))
    img = img[min_y:max_y+1, min_x:max_x+1, :]
    #print('min_y:max_y+1, min_x:max_x+1, :', min_y, max_y+1, min_x, max_x+1)
    img_obj = Image.fromarray(img)

    draw_img = ImageDraw.Draw(img_obj)
    # Create a polygonal mask
    mask = Image.fromarray(np.zeros((img.shape[0], img.shape[1])))
    draw_mask = ImageDraw.Draw(mask)
    draw_mask.polygon(xy_pts, fill='white')
    mask = np.array(mask).astype(bool)
    # Choose the polygonal area from image

    # Start from a white canvas and copy in only the masked pixels.
    output_img = np.zeros_like(img)+255
    output_img[mask] = img[mask]
    output_img = Image.fromarray(output_img)
    return output_img, xy_pts
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# restrict coordinates to lie between 0 and max (included)
def restrict_pts(pts, max_p):
    """Clamp every (x, y) in *pts* into [0, max_p[0]] x [0, max_p[1]]."""
    clamped = []
    for (x, y) in pts:
        # Lower bound first, then upper bound (matches original order).
        clamped.append((min(max_p[0], max(0, x)),
                        min(max_p[1], max(0, y))))
    return clamped
|
| 233 |
+
|
| 234 |
+
# assuming poly_pts is a list of (x,y) tuples
# This will add more points by interpolating between two points
def expand_poly(poly_pts, min_x_increment=10):
    """Densify a polygon by linearly interpolating extra vertices along
    edges whose endpoints are at least *min_x_increment* apart in x.

    Every edge (wrapping around to the first point) emits both its
    endpoints, so original vertices appear twice in the result.
    """
    poly_pts = np.array(poly_pts).astype(int)
    new_poly = []
    # for ind, (curr, nxt) in enumerate(zip(poly_pts[:-1], poly_pts[1:])):

    for ind, curr in enumerate(poly_pts):
        # Wrap to the first vertex after the last one (closed polygon).
        nxt = poly_pts[(ind+1)%len(poly_pts)]
        if np.abs(nxt[0] - curr[0]) < min_x_increment:
            new_poly.append(curr)
            new_poly.append(nxt)
            continue
        x1, x2 = curr[0], nxt[0]
        y1, y2 = curr[1], nxt[1]
        new_poly.append(curr)
        # Step toward nxt; negative step for right-to-left edges.
        increment = -1*min_x_increment if nxt[0] < curr[0] else min_x_increment
        for x in range(curr[0]+1, nxt[0], increment):
            # Linear interpolation of y along the edge.
            slope = float(y2-y1)/(x2-x1)
            y = slope*(x - x2) + y2
            new_poly.append((x, y))
        new_poly.append(nxt)
    return new_poly
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Chunks_len ignored if chunk_len_auto is True
def get_baseline_chunks(poly_pts, chunks_len=300, chunk_len_auto=True):
    """Estimate a text-line baseline from its polygon.

    The polygon is densified, split into vertical chunks along x, a
    regression baseline is fitted per chunk, and the chunk baselines are
    stitched together (overlapping a few points to smooth the joins) and
    deduplicated.
    """
    baseline = []
    poly_pts = expand_poly(poly_pts)


    p = np.array(poly_pts)

    max_x, max_y = np.max(p, 0)
    min_x, min_y = np.min(p, 0)

    # Decide chunks_len
    if chunk_len_auto:
        # Roughly one chunk per 50 densified points, capped at 5.
        if (len(poly_pts) >= 250):
            total_chunks = 5
        else:
            total_chunks = int(np.ceil(len(poly_pts)/50))
        chunks_len = int((max_x-min_x)/total_chunks)
    else:
        total_chunks = int((max_x-min_x)/chunks_len)

    #print('expanded len', len(poly_pts), 'total chunks', total_chunks)
    for i in range(1, total_chunks+1):
        # Points whose x falls in this chunk's half-open window.
        p1 = [pt for pt in p if (pt[0]-min_x)>=(i-1)*chunks_len and (pt[0]-min_x)<i*chunks_len]
        #print(p1)
        if i == total_chunks:
            # Last chunk absorbs everything to the right edge.
            p1 = [pt for pt in p if (pt[0] - min_x)>=(i-1)*chunks_len]
        b = get_baseline_regression(p1, num_pts=12)

        # Points are in ascending order (increasing x - left to right)
        if len(baseline) != 0:
            # This will smooth out the line
            # Get rid of last 4 points and connect the point with next 4 points
            baseline = baseline[:-4]
            baseline.extend(b[4:])
            if i == total_chunks:
                baseline.extend(b[-1:])
        else:
            baseline = b

    # Make sure baseline does not repeat

    prev_pt = baseline[0]
    baseline_clean = [prev_pt]
    for b in baseline[1:]:
        if b != prev_pt:
            baseline_clean.append(b)
            prev_pt = b
    return baseline_clean
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Making sure a value is not outside a boundary or has negative value
def correct_pt(value, max_value):
    """Clamp *value* into the closed interval [0, max_value]."""
    if value < 0:
        return 0
    return value if value <= max_value else max_value
|
| 317 |
+
|
| 318 |
+
# Assume poly is list of (x, y) tuples or [x, y] list
# Will retrieve a vertically oriented baseline
# Will return top to bottom and bottom to top if reversed is true
def get_vertical_baseline(poly, reversed=False):
    """Fit a baseline to a vertically oriented polygon by swapping axes,
    fitting the usual horizontal baseline, and swapping back."""
    swapped = [(pt[1], pt[0]) for pt in poly]
    fitted = get_baseline_chunks(swapped)
    # Flip back to (x, y) order.
    baseline = [(pt[1], pt[0]) for pt in fitted]
    if reversed:
        baseline.sort(key=lambda p: p[1], reverse=True)
    return baseline
|
| 329 |
+
|
| 330 |
+
# Check the polygon is valid
# If x coord or y coord don't change, its not valid
def valid_poly(poly_pts):
    """Return False for degenerate polygons: fewer than 3 points, or an
    x or y extent that is (nearly) zero."""
    if len(poly_pts) <= 2:
        return False
    xs = [pt[0] for pt in poly_pts]
    ys = [pt[1] for pt in poly_pts]

    # Either axis collapsing to a point/line makes the polygon unusable.
    if np.max(xs) - np.min(xs) <= 1e-2:
        return False
    if np.max(ys) - np.min(ys) <= 1e-2:
        return False
    return True
|
coords/poly_routines.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import points
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
import os
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib.patches import Rectangle
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def correct_pt(value, max_value):
    """Clamp *value* into [0, max_value - 1].

    Returns [clamped_value, hit_boundary] where hit_boundary is True when
    the input lay outside the valid index range.
    """
    if value < 0:
        return [0, True]
    if value >= max_value:
        return [max_value - 1, True]
    return [value, False]
|
| 23 |
+
|
| 24 |
+
# Each polygon is a list of (x,y) tuples
def get_polygon_list_tuples(out):
    """Extract one polygon per detected text line from a decoder output
    dict *out*.

    For each line, walks the line-follower steps from 'beginning' to
    'ending' along one edge and back along the other, clamping points to
    the image bounds and skipping consecutive duplicates.

    NOTE(review): out['lf'][step][line] is indexed as [axis][side];
    assumed to hold the two polygon edge points per step — confirm
    against the line-follower's output format.
    """
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    polygon_list = []
    # prev persists across lines, so a point equal to the previous
    # line's last point is also skipped.
    prev = [-1, -1]
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        begin_ind = out['beginning'][line_ind]
        end_ind = out['ending'][line_ind]
        # Fractional indices: widen outward, then clamp to valid steps.
        begin_ind = int(np.floor(begin_ind))
        end_ind = int(np.ceil(end_ind))
        end_ind = min(end_ind, len(out['lf'])-1)
        # Forward pass along one edge of the line.
        for pt_ind in range(begin_ind, end_ind+1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][0])
            pt_y = float(out['lf'][pt_ind][line_ind][1][0])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]
        # Backward pass along the opposite edge to close the polygon.
        for pt_ind in range(end_ind, begin_ind-1, -1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][1])
            pt_y = float(out['lf'][pt_ind][line_ind][1][1])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        polygon_list.append(polygon)
        if len(polygon) < 3:
            print('WARNING: DEGENERATE POLYGON AT INDEX', len(polygon_list))
    return polygon_list
|
| 58 |
+
|
| 59 |
+
# Each polygon is a list of (x,y) tuples
def get_polygon_list_without_trim(out):
    """Extract one polygon per text line from *out*, walking every
    line-follower step (no 'beginning'/'ending' trimming).

    Polygons with fewer than 3 distinct points are silently dropped.
    """
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    polygon_list = []
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        begin_ind = 0
        end_ind = len(out['lf'])-1
        # Duplicate tracker is reset per line here.
        prev = [-1, -1]

        # Forward pass along one edge of the line.
        for pt_ind in range(begin_ind, end_ind+1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][0])
            pt_y = float(out['lf'][pt_ind][line_ind][1][0])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]
        # Backward pass along the opposite edge to close the polygon.
        for pt_ind in range(end_ind, begin_ind-1, -1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][1])
            pt_y = float(out['lf'][pt_ind][line_ind][1][1])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        if len(polygon) >= 3:
            polygon_list.append(polygon)
    return polygon_list
|
| 90 |
+
|
| 91 |
+
# Each polygon passed as input is a list of (x,y) tuples
# Same for output
def percent_intersection(size, poly1, poly2):
    """Rasterize both polygons onto a *size* canvas and measure overlap.

    Returns (intersection pixel count, overlap fraction of poly1's area,
    overlap fraction of poly2's area).
    """
    def rasterize(poly):
        canvas = Image.new(mode="1", size=size)
        ImageDraw.Draw(canvas).polygon(poly, fill=1)
        return np.asarray(canvas, dtype=bool)

    mask1 = rasterize(poly1)
    mask2 = rasterize(poly2)
    #plt.imshow(intersection)
    intersection_area = (mask1 & mask2).sum()
    return (intersection_area,
            intersection_area / mask1.sum(),
            intersection_area / mask2.sum())
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_poly_no_overlap(img_name, poly_list, threshold=0.6):
    """Remove polygons that heavily overlap a neighbour.

    Walks the list pairwise; whenever two polygons overlap by more than
    *threshold* (as a fraction of either polygon's area), the one with
    the larger overlap fraction is marked for deletion.  Degenerate
    polygons (fewer than 3 points) are always dropped.

    Returns (list of deleted indices, surviving polygons in their
    original list order).
    """
    # The image is only opened to learn the rasterization canvas size.
    img = Image.open(img_name)
    size = img.size
    #polygons = [points.list_to_xy(p) for p in poly_list]
    polygons = poly_list
    del_list = []
    current = 0
    next_ind = current + 1
    last_deleted = -1
    while next_ind < len(polygons):
        # Degenerate polygons cannot be rasterized; drop them outright.
        if len(polygons[current]) < 3:
            del_list.append(current)
            current, next_ind = (current + 1, next_ind + 1)
            continue
        if len(polygons[next_ind]) < 3:
            del_list.append(next_ind)
            next_ind += 1
            continue
        # End check
        overlap_area, percent1, percent2 = percent_intersection(
            size, polygons[current], polygons[next_ind])

        if percent1 > threshold or percent2 > threshold:
            # Delete whichever polygon has more of itself inside the other.
            to_del = current if percent1 > percent2 else next_ind
            current, next_ind = (current, next_ind + 1) if percent1 < percent2 \
                else (next_ind, next_ind + 1)
            del_list.append(to_del)
            last_deleted = to_del
        else:  # when no overlap is found
            current, next_ind = (current + 1, next_ind + 1)
            # Never re-examine a polygon that was already deleted.
            if current <= last_deleted:
                current = last_deleted + 1
                next_ind = current + 1
    # BUG FIX: iterate surviving indices in sorted order.  The original
    # iterated a raw set, so the order of the returned polygons was
    # unspecified rather than guaranteed to match the input order.
    good_ind = sorted(set(range(len(poly_list))).difference(del_list))
    poly_non_overlapping = [poly_list[i] for i in good_ind]
    return del_list, poly_non_overlapping
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def dump_polygons_json(out, polygons=None, filename=None):
    """Write the line polygons of a page to a JSON file.

    Keys are "line_1", "line_2", ... and values are flat
    [x0, y0, x1, y1, ...] coordinate lists.  *filename* defaults to the
    image path with its 3-character extension replaced by "json".
    """
    if filename is None:
        filename = out["image_path"][:-3] + "json"
    if polygons is None:
        # BUG FIX: the original called get_polygon_list(), a name that is
        # not defined or imported anywhere in this module (NameError);
        # get_polygon_list_tuples() is the extractor defined above.
        polygons = get_polygon_list_tuples(out)
    lf_dict = {}
    for ind, poly in enumerate(polygons):
        lf_dict['line_' + str(ind + 1)] = points.xy_to_list(poly)

    with open(filename, 'w') as fout:
        json_dumps_str = json.dumps(lf_dict, indent=2)
        #print('....json_dumps_str', json_dumps_str)
        print(json_dumps_str, file=fout)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# won't flip the polygons...only the image
def draw_image_with_poly(directory, image, poly, convert=True, flip=False):
    """Show *image* with each polygon outlined and numbered at its last
    vertex.

    If *convert* is True each polygon is a flat coordinate list and is
    converted to (x, y) tuples first.  *flip* mirrors only the image.
    """
    img = cv2.imread(os.path.join(directory, image))
    if flip:
        img = cv2.flip(img, 1)
    plt.imshow(img)
    palette = ['red', 'green', 'blue']

    for idx, polygon in enumerate(poly):
        if convert:
            polygon = points.list_to_xy(polygon)
        points.draw_poly(plt, polygon, palette[idx % 3])
        plt.text(polygon[-1][0], polygon[-1][1], str(idx))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def flip_polygon(img_file, poly_list):
    """Mirror every polygon horizontally about the image width.

    NOTE(review): uses (w - x), not (w - 1 - x); x=0 maps to x=w, one
    past the last pixel column — confirm this matches cv2.flip output.
    """
    h, w = cv2.imread(img_file).shape[:2]
    return [[(w - x, y) for (x, y) in poly] for poly in poly_list]
|
| 194 |
+
|
| 195 |
+
|
coords/text_cleaning_routines.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from bidi.algorithm import get_display
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
def correct_brackets(text):
    """Swap each opening bracket with its closing partner (and vice
    versa), compensating for a visual-order (bidi) reversal."""
    for left, right in (('{', '}'), ('(', ')'), ('[', ']'), ('«', '»')):
        text = switch_chars(text, left, right)
    return text
|
| 11 |
+
|
| 12 |
+
def switch_chars(text, x, y):
    """Return *text* with every occurrence of *x* and *y* swapped.

    Uses a single str.translate pass instead of the original's two
    index-collection scans plus per-index assignment.
    """
    return text.translate(str.maketrans({x: y, y: x}))
|
| 21 |
+
|
| 22 |
+
def clean_text(input_text):
    """Normalize whitespace, symbol, digit, and punctuation variants to
    the canonical forms expected by the HTR character set."""
    # One-for-one (and one-to-many) character substitutions; the
    # substitutions are disjoint, so a single translate pass is
    # equivalent to the original chained replaces.
    table = str.maketrans({
        '\u0009': ' ',        # tab -> space
        '\u000A': ' ',        # newline -> space
        '\u00D7': 'x',        # multiplication sign
        '\u066A': '%',        # Arabic percent sign
        '\u06f3': '\u0663',   # extended Arabic-Indic digits ->
        '\u06f7': '\u0667',   #   Arabic-Indic digits
        '\u06f9': '\u0669',
        '\u2018': "'",        # curly quotes -> straight quotes
        '\u2019': "'",
        '\u201C': '"',
        '\u201D': '"',
        '…': '...',           # ellipsis -> three dots
        '\u2033': "\u064b",   # double prime -> Arabic fathatan
        '\u2044': '/',        # fraction slash
        '\u2e17': '\u201e',   # double oblique hyphen -> low double quote
    })
    cleaned_text = input_text.translate(table)
    # En/em dashes become a plain hyphen.
    cleaned_text = re.sub(r'[\u2013\u2014]', '-', cleaned_text)
    # Bullet variants become a period.
    cleaned_text = re.sub(r'[●•\xb7]', '.', cleaned_text)
    return cleaned_text
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_char_sets():
    """Build named character sets used to decide whether a line mixes
    scripts: English letters, Latin digits, the Arabic Unicode block,
    and Arabic-Indic digits."""
    english_alphabet = {chr(c) for c in range(ord('a'), ord('z') + 1)}
    english_alphabet |= {chr(c) for c in range(ord('A'), ord('Z') + 1)}
    latin_digits = {chr(c) for c in range(ord('0'), ord('9') + 1)}

    # The Arabic block U+0600..U+06FF (includes Arabic-Indic digits).
    arabic_chars = {chr(c) for c in range(ord("\u0600"), ord("\u06ff") + 1)}
    arabic_digits = {chr(c) for c in range(ord("\u0660"), ord("\u0669") + 1)}
    return {'english_alphabet': english_alphabet,
            'arabic_unicodes': arabic_chars,
            'latin_digits': latin_digits,
            'arabic_digits': arabic_digits}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_clean_visual_order(text):
    """Return *text* reordered for visual display and cleaned.

    Lines containing Latin letters or digits (Latin or Arabic-Indic)
    are run through the bidi algorithm with a right-to-left base
    direction and then reversed; brackets are mirrored to compensate.
    Pure-Arabic lines are left in logical order.  clean_text() is
    applied in all cases.
    """
    charset_dict = get_char_sets()
    text_set = set(text)
    has_english_alphabet = len(text_set.intersection(charset_dict['english_alphabet'])) > 0
    has_latin_digits = len(text_set.intersection(charset_dict['latin_digits'])) > 0
    has_arabic_digits = len(text_set.intersection(charset_dict['arabic_digits'])) > 0



    if has_arabic_digits or has_english_alphabet or has_latin_digits:
        # Reorder mixed-direction runs, then reverse the whole string.
        text_visual_order = get_display(text, base_dir='R')[::-1]
        # The reversal mirrors bracket glyphs; swap them back.
        text_visual_order = correct_brackets(text_visual_order)
    else:
        text_visual_order = text
    clean_visual_order = clean_text(text_visual_order)
    return clean_visual_order
|
coords/text_gt.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Extract text from json file
|
| 2 |
+
# As all lines need sorting etc. this is being added to coords folder
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import points
|
| 6 |
+
import json
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import os
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sort_lines(lines_list):
    """Sort text lines into reading order using each line's first
    baseline point: top-to-bottom, and right-to-left (descending x)
    among lines at the same height."""
    starts = [line['baseline'][0] for line in lines_list]
    order = sorted(range(len(starts)),
                   key=lambda i: (starts[i][1], -starts[i][0]))
    return [lines_list[i] for i in order]
|
| 20 |
+
|
| 21 |
+
def is_valid_key(key, json_obj):
    """Return True when *key* names a usable, non-deleted text line.

    Side effect: tabs in the line's text are replaced with spaces in
    *json_obj* itself.
    """
    if not key.lower().startswith('line_'):
        return False
    entry = json_obj[key]
    # Skip lines explicitly flagged as deleted.
    if "deleted" in entry.keys() and entry["deleted"] != "0":
        return False
    # No text field
    if 'text' not in entry:
        return False
    # Normalize tabs, then reject effectively empty text.
    entry['text'] = entry['text'].replace('\t', ' ')
    if entry['text'].strip() == "":
        return False
    return True
|
| 34 |
+
|
| 35 |
+
def get_text(json_file, return_list=False):
    """Read a page's annotation JSON and return its text in reading order.

    Lines whose polygon is degenerate or whose baseline fit raises are
    dropped.  Returns a list of line strings if *return_list* is True,
    otherwise one newline-joined string.
    """
    with open(json_file) as fin:
        json_obj = json.load(fin)
    to_remove_ind = []
    # Get keys in json_obj
    keys = list(json_obj.keys())
    # Get list of line objects
    lines = [json_obj[k] for k in keys if is_valid_key(k, json_obj)]
    # Get baseline of each line
    for ind, line in enumerate(lines):
        poly_pts = line["coord"]
        poly_pts = points.list_to_xy(poly_pts)
        if len(poly_pts) <= 2:
            print(json_file, len(poly_pts))
            to_remove_ind.append(ind)
        if not points.valid_poly(poly_pts):
            to_remove_ind.append(ind)
            continue
        try:
            # Baseline points sorted right-to-left (descending x).
            baseline = points.get_baseline_chunks(poly_pts)
            baseline.sort(key=lambda x: x[0], reverse=True)
            line['baseline'] = baseline
        except Exception as e:
            #print(len(poly_pts))
            #print(poly_pts)
            #print(json_file)
            to_remove_ind.append(ind)

    # Remove the lines flagged above (degenerate polygon or failed fit).
    cleaned_lines = [lines[ind] for ind in range(len(lines)) if not ind in to_remove_ind]
    # Sort the lines
    sorted_lines = sort_lines(cleaned_lines)

    text = []
    for l in sorted_lines:
        text.append(l["text"])
    if return_list:
        return text
    return '\n'.join(text)
|
| 74 |
+
|
| 75 |
+
def get_json_file(img_fullname):
    """Locate the annotation JSON belonging to an image file.

    Scans the image's directory for files named
    "<image-stem>_annotate_<rest>" where <rest> contains exactly one
    dot (a timestamp-free name plus one extension).  Returns the full
    path of the first match, or None when nothing is found.
    """
    directory, img_name = os.path.split(img_fullname)
    json_files = []
    #annotators = []
    prefix = img_name[:-4] + '_annotate_'
    for fname in os.listdir(directory):
        if not fname.startswith(prefix):
            continue
        # Check if its a timestamp in filename: keep the candidate only
        # when the remainder after the prefix contains a single '.'.
        remainder = fname[len(prefix):]
        if remainder.rfind('.') == remainder.find('.'):
            json_files.append(fname)

    if len(json_files) > 1:
        print('More than one json found...returning 0th one', json_files)
    if len(json_files) == 0:
        print('No json found')
        return None

    return os.path.join(directory, json_files[0])
|
| 101 |
+
|
model/trial_26_A/muharaf_charset.json
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"char_to_idx": {
|
| 3 |
+
" ": 1,
|
| 4 |
+
"!": 2,
|
| 5 |
+
"\"": 3,
|
| 6 |
+
"#": 4,
|
| 7 |
+
"$": 5,
|
| 8 |
+
"%": 6,
|
| 9 |
+
"&": 7,
|
| 10 |
+
"'": 8,
|
| 11 |
+
"(": 9,
|
| 12 |
+
")": 10,
|
| 13 |
+
"*": 11,
|
| 14 |
+
"+": 12,
|
| 15 |
+
",": 13,
|
| 16 |
+
"-": 14,
|
| 17 |
+
".": 15,
|
| 18 |
+
"/": 16,
|
| 19 |
+
"0": 17,
|
| 20 |
+
"1": 18,
|
| 21 |
+
"2": 19,
|
| 22 |
+
"3": 20,
|
| 23 |
+
"4": 21,
|
| 24 |
+
"5": 22,
|
| 25 |
+
"6": 23,
|
| 26 |
+
"7": 24,
|
| 27 |
+
"8": 25,
|
| 28 |
+
"9": 26,
|
| 29 |
+
":": 27,
|
| 30 |
+
"=": 28,
|
| 31 |
+
"A": 29,
|
| 32 |
+
"B": 30,
|
| 33 |
+
"C": 31,
|
| 34 |
+
"D": 32,
|
| 35 |
+
"E": 33,
|
| 36 |
+
"F": 34,
|
| 37 |
+
"G": 35,
|
| 38 |
+
"H": 36,
|
| 39 |
+
"I": 37,
|
| 40 |
+
"J": 38,
|
| 41 |
+
"K": 39,
|
| 42 |
+
"L": 40,
|
| 43 |
+
"M": 41,
|
| 44 |
+
"N": 42,
|
| 45 |
+
"O": 43,
|
| 46 |
+
"P": 44,
|
| 47 |
+
"Q": 45,
|
| 48 |
+
"R": 46,
|
| 49 |
+
"S": 47,
|
| 50 |
+
"T": 48,
|
| 51 |
+
"U": 49,
|
| 52 |
+
"V": 50,
|
| 53 |
+
"W": 51,
|
| 54 |
+
"X": 52,
|
| 55 |
+
"Y": 53,
|
| 56 |
+
"Z": 54,
|
| 57 |
+
"[": 55,
|
| 58 |
+
"\\": 56,
|
| 59 |
+
"]": 57,
|
| 60 |
+
"_": 58,
|
| 61 |
+
"a": 59,
|
| 62 |
+
"b": 60,
|
| 63 |
+
"c": 61,
|
| 64 |
+
"d": 62,
|
| 65 |
+
"e": 63,
|
| 66 |
+
"f": 64,
|
| 67 |
+
"g": 65,
|
| 68 |
+
"h": 66,
|
| 69 |
+
"i": 67,
|
| 70 |
+
"j": 68,
|
| 71 |
+
"k": 69,
|
| 72 |
+
"l": 70,
|
| 73 |
+
"m": 71,
|
| 74 |
+
"n": 72,
|
| 75 |
+
"o": 73,
|
| 76 |
+
"p": 74,
|
| 77 |
+
"q": 75,
|
| 78 |
+
"r": 76,
|
| 79 |
+
"s": 77,
|
| 80 |
+
"t": 78,
|
| 81 |
+
"u": 79,
|
| 82 |
+
"v": 80,
|
| 83 |
+
"x": 81,
|
| 84 |
+
"y": 82,
|
| 85 |
+
"z": 83,
|
| 86 |
+
"|": 84,
|
| 87 |
+
"\u00ba": 85,
|
| 88 |
+
"\u00c3": 86,
|
| 89 |
+
"\u00c8": 87,
|
| 90 |
+
"\u00c9": 88,
|
| 91 |
+
"\u00ca": 89,
|
| 92 |
+
"\u00e0": 90,
|
| 93 |
+
"\u00e7": 91,
|
| 94 |
+
"\u00e8": 92,
|
| 95 |
+
"\u00e9": 93,
|
| 96 |
+
"\u00ea": 94,
|
| 97 |
+
"\u060c": 95,
|
| 98 |
+
"\u061b": 96,
|
| 99 |
+
"\u061f": 97,
|
| 100 |
+
"\u0621": 98,
|
| 101 |
+
"\u0622": 99,
|
| 102 |
+
"\u0623": 100,
|
| 103 |
+
"\u0624": 101,
|
| 104 |
+
"\u0625": 102,
|
| 105 |
+
"\u0626": 103,
|
| 106 |
+
"\u0627": 104,
|
| 107 |
+
"\u0628": 105,
|
| 108 |
+
"\u0629": 106,
|
| 109 |
+
"\u062a": 107,
|
| 110 |
+
"\u062b": 108,
|
| 111 |
+
"\u062c": 109,
|
| 112 |
+
"\u062d": 110,
|
| 113 |
+
"\u062e": 111,
|
| 114 |
+
"\u062f": 112,
|
| 115 |
+
"\u0630": 113,
|
| 116 |
+
"\u0631": 114,
|
| 117 |
+
"\u0632": 115,
|
| 118 |
+
"\u0633": 116,
|
| 119 |
+
"\u0634": 117,
|
| 120 |
+
"\u0635": 118,
|
| 121 |
+
"\u0636": 119,
|
| 122 |
+
"\u0637": 120,
|
| 123 |
+
"\u0638": 121,
|
| 124 |
+
"\u0639": 122,
|
| 125 |
+
"\u063a": 123,
|
| 126 |
+
"\u0640": 124,
|
| 127 |
+
"\u0641": 125,
|
| 128 |
+
"\u0642": 126,
|
| 129 |
+
"\u0643": 127,
|
| 130 |
+
"\u0644": 128,
|
| 131 |
+
"\u0645": 129,
|
| 132 |
+
"\u0646": 130,
|
| 133 |
+
"\u0647": 131,
|
| 134 |
+
"\u0648": 132,
|
| 135 |
+
"\u0649": 133,
|
| 136 |
+
"\u064a": 134,
|
| 137 |
+
"\u064b": 135,
|
| 138 |
+
"\u064c": 136,
|
| 139 |
+
"\u064d": 137,
|
| 140 |
+
"\u064e": 138,
|
| 141 |
+
"\u064f": 139,
|
| 142 |
+
"\u0650": 140,
|
| 143 |
+
"\u0651": 141,
|
| 144 |
+
"\u0652": 142,
|
| 145 |
+
"\u0660": 143,
|
| 146 |
+
"\u0661": 144,
|
| 147 |
+
"\u0662": 145,
|
| 148 |
+
"\u0663": 146,
|
| 149 |
+
"\u0664": 147,
|
| 150 |
+
"\u0665": 148,
|
| 151 |
+
"\u0666": 149,
|
| 152 |
+
"\u0667": 150,
|
| 153 |
+
"\u0668": 151,
|
| 154 |
+
"\u0669": 152,
|
| 155 |
+
"\u06a4": 153,
|
| 156 |
+
"\u06a8": 154,
|
| 157 |
+
"\u201e": 155,
|
| 158 |
+
"\ufb6c": 156,
|
| 159 |
+
"\ufc63": 157
|
| 160 |
+
},
|
| 161 |
+
"idx_to_char": {
|
| 162 |
+
"1": " ",
|
| 163 |
+
"2": "!",
|
| 164 |
+
"3": "\"",
|
| 165 |
+
"4": "#",
|
| 166 |
+
"5": "$",
|
| 167 |
+
"6": "%",
|
| 168 |
+
"7": "&",
|
| 169 |
+
"8": "'",
|
| 170 |
+
"9": "(",
|
| 171 |
+
"10": ")",
|
| 172 |
+
"11": "*",
|
| 173 |
+
"12": "+",
|
| 174 |
+
"13": ",",
|
| 175 |
+
"14": "-",
|
| 176 |
+
"15": ".",
|
| 177 |
+
"16": "/",
|
| 178 |
+
"17": "0",
|
| 179 |
+
"18": "1",
|
| 180 |
+
"19": "2",
|
| 181 |
+
"20": "3",
|
| 182 |
+
"21": "4",
|
| 183 |
+
"22": "5",
|
| 184 |
+
"23": "6",
|
| 185 |
+
"24": "7",
|
| 186 |
+
"25": "8",
|
| 187 |
+
"26": "9",
|
| 188 |
+
"27": ":",
|
| 189 |
+
"28": "=",
|
| 190 |
+
"29": "A",
|
| 191 |
+
"30": "B",
|
| 192 |
+
"31": "C",
|
| 193 |
+
"32": "D",
|
| 194 |
+
"33": "E",
|
| 195 |
+
"34": "F",
|
| 196 |
+
"35": "G",
|
| 197 |
+
"36": "H",
|
| 198 |
+
"37": "I",
|
| 199 |
+
"38": "J",
|
| 200 |
+
"39": "K",
|
| 201 |
+
"40": "L",
|
| 202 |
+
"41": "M",
|
| 203 |
+
"42": "N",
|
| 204 |
+
"43": "O",
|
| 205 |
+
"44": "P",
|
| 206 |
+
"45": "Q",
|
| 207 |
+
"46": "R",
|
| 208 |
+
"47": "S",
|
| 209 |
+
"48": "T",
|
| 210 |
+
"49": "U",
|
| 211 |
+
"50": "V",
|
| 212 |
+
"51": "W",
|
| 213 |
+
"52": "X",
|
| 214 |
+
"53": "Y",
|
| 215 |
+
"54": "Z",
|
| 216 |
+
"55": "[",
|
| 217 |
+
"56": "\\",
|
| 218 |
+
"57": "]",
|
| 219 |
+
"58": "_",
|
| 220 |
+
"59": "a",
|
| 221 |
+
"60": "b",
|
| 222 |
+
"61": "c",
|
| 223 |
+
"62": "d",
|
| 224 |
+
"63": "e",
|
| 225 |
+
"64": "f",
|
| 226 |
+
"65": "g",
|
| 227 |
+
"66": "h",
|
| 228 |
+
"67": "i",
|
| 229 |
+
"68": "j",
|
| 230 |
+
"69": "k",
|
| 231 |
+
"70": "l",
|
| 232 |
+
"71": "m",
|
| 233 |
+
"72": "n",
|
| 234 |
+
"73": "o",
|
| 235 |
+
"74": "p",
|
| 236 |
+
"75": "q",
|
| 237 |
+
"76": "r",
|
| 238 |
+
"77": "s",
|
| 239 |
+
"78": "t",
|
| 240 |
+
"79": "u",
|
| 241 |
+
"80": "v",
|
| 242 |
+
"81": "x",
|
| 243 |
+
"82": "y",
|
| 244 |
+
"83": "z",
|
| 245 |
+
"84": "|",
|
| 246 |
+
"85": "\u00ba",
|
| 247 |
+
"86": "\u00c3",
|
| 248 |
+
"87": "\u00c8",
|
| 249 |
+
"88": "\u00c9",
|
| 250 |
+
"89": "\u00ca",
|
| 251 |
+
"90": "\u00e0",
|
| 252 |
+
"91": "\u00e7",
|
| 253 |
+
"92": "\u00e8",
|
| 254 |
+
"93": "\u00e9",
|
| 255 |
+
"94": "\u00ea",
|
| 256 |
+
"95": "\u060c",
|
| 257 |
+
"96": "\u061b",
|
| 258 |
+
"97": "\u061f",
|
| 259 |
+
"98": "\u0621",
|
| 260 |
+
"99": "\u0622",
|
| 261 |
+
"100": "\u0623",
|
| 262 |
+
"101": "\u0624",
|
| 263 |
+
"102": "\u0625",
|
| 264 |
+
"103": "\u0626",
|
| 265 |
+
"104": "\u0627",
|
| 266 |
+
"105": "\u0628",
|
| 267 |
+
"106": "\u0629",
|
| 268 |
+
"107": "\u062a",
|
| 269 |
+
"108": "\u062b",
|
| 270 |
+
"109": "\u062c",
|
| 271 |
+
"110": "\u062d",
|
| 272 |
+
"111": "\u062e",
|
| 273 |
+
"112": "\u062f",
|
| 274 |
+
"113": "\u0630",
|
| 275 |
+
"114": "\u0631",
|
| 276 |
+
"115": "\u0632",
|
| 277 |
+
"116": "\u0633",
|
| 278 |
+
"117": "\u0634",
|
| 279 |
+
"118": "\u0635",
|
| 280 |
+
"119": "\u0636",
|
| 281 |
+
"120": "\u0637",
|
| 282 |
+
"121": "\u0638",
|
| 283 |
+
"122": "\u0639",
|
| 284 |
+
"123": "\u063a",
|
| 285 |
+
"124": "\u0640",
|
| 286 |
+
"125": "\u0641",
|
| 287 |
+
"126": "\u0642",
|
| 288 |
+
"127": "\u0643",
|
| 289 |
+
"128": "\u0644",
|
| 290 |
+
"129": "\u0645",
|
| 291 |
+
"130": "\u0646",
|
| 292 |
+
"131": "\u0647",
|
| 293 |
+
"132": "\u0648",
|
| 294 |
+
"133": "\u0649",
|
| 295 |
+
"134": "\u064a",
|
| 296 |
+
"135": "\u064b",
|
| 297 |
+
"136": "\u064c",
|
| 298 |
+
"137": "\u064d",
|
| 299 |
+
"138": "\u064e",
|
| 300 |
+
"139": "\u064f",
|
| 301 |
+
"140": "\u0650",
|
| 302 |
+
"141": "\u0651",
|
| 303 |
+
"142": "\u0652",
|
| 304 |
+
"143": "\u0660",
|
| 305 |
+
"144": "\u0661",
|
| 306 |
+
"145": "\u0662",
|
| 307 |
+
"146": "\u0663",
|
| 308 |
+
"147": "\u0664",
|
| 309 |
+
"148": "\u0665",
|
| 310 |
+
"149": "\u0666",
|
| 311 |
+
"150": "\u0667",
|
| 312 |
+
"151": "\u0668",
|
| 313 |
+
"152": "\u0669",
|
| 314 |
+
"153": "\u06a4",
|
| 315 |
+
"154": "\u06a8",
|
| 316 |
+
"155": "\u201e",
|
| 317 |
+
"156": "\ufb6c",
|
| 318 |
+
"157": "\ufc63"
|
| 319 |
+
}
|
| 320 |
+
}
|
model/trial_26_A/set0/config_2600.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
snapshot_path: model/trial_26_A/set0/pretrain/
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
network:
|
| 5 |
+
hw:
|
| 6 |
+
char_set_path: model/trial_26_A/muharaf_charset.json
|
| 7 |
+
cnn_out_size: 1024
|
| 8 |
+
input_height: 60
|
| 9 |
+
num_of_channels: 3
|
| 10 |
+
num_of_outputs: 158
|
| 11 |
+
use_instance_norm: true
|
| 12 |
+
lf:
|
| 13 |
+
look_ahead_matrix: null
|
| 14 |
+
step_bias: null
|
| 15 |
+
|
| 16 |
+
sol:
|
| 17 |
+
base0: 16
|
| 18 |
+
base1: 16
|
| 19 |
+
post_processing:
|
| 20 |
+
lf_nms_range:
|
| 21 |
+
- 0
|
| 22 |
+
- 6
|
| 23 |
+
lf_nms_threshold: 0.5
|
| 24 |
+
sol_threshold: 0.1
|
| 25 |
+
|
model/trial_26_A/set0/pretrain/hw.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5e1f8a27a6fc5e64c1c93b72b18200db7b6c240053d3e016ad138b4cd3521b08
|
| 3 |
+
size 73251954
|
model/trial_26_A/set0/pretrain/lf.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b7082953c1269b67ba0801311604884255a4fe4bdef821dcf1ce169ab582d6f
|
| 3 |
+
size 22228762
|
model/trial_26_A/set0/pretrain/sol.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e94e56959c5175067734e00d7e2e02ff05018f37eb91d726f2319504df60c832
|
| 3 |
+
size 36979951
|
py3/e2e/__init__.py
ADDED
|
File without changes
|
py3/e2e/alignment_dataset.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import random
|
| 8 |
+
from utils import safe_load
|
| 9 |
+
|
| 10 |
+
def collate(batch):
|
| 11 |
+
return batch
|
| 12 |
+
|
| 13 |
+
class AlignmentDataset(Dataset):
|
| 14 |
+
|
| 15 |
+
def __init__(self, set_list, data_range=None, ignore_json=False, resize_width=512):
|
| 16 |
+
|
| 17 |
+
self.ignore_json = ignore_json
|
| 18 |
+
|
| 19 |
+
self.resize_width = resize_width
|
| 20 |
+
|
| 21 |
+
self.ids = set_list
|
| 22 |
+
self.ids.sort()
|
| 23 |
+
|
| 24 |
+
if data_range is not None:
|
| 25 |
+
self.ids = random.sample(self.ids, data_range)
|
| 26 |
+
|
| 27 |
+
print("Alignment Ids Count:", len(self.ids))
|
| 28 |
+
|
| 29 |
+
def __len__(self):
|
| 30 |
+
return len(self.ids)
|
| 31 |
+
|
| 32 |
+
def __getitem__(self, idx):
|
| 33 |
+
|
| 34 |
+
gt_json_path, img_path = self.ids[idx]
|
| 35 |
+
|
| 36 |
+
gt_json = []
|
| 37 |
+
if not self.ignore_json:
|
| 38 |
+
gt_json = safe_load.json_state(gt_json_path)
|
| 39 |
+
if gt_json is None:
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
org_img = cv2.imread(img_path)
|
| 43 |
+
|
| 44 |
+
full_img = org_img.astype(np.float32)
|
| 45 |
+
full_img = full_img.transpose([2,1,0])[None,...]
|
| 46 |
+
full_img = torch.from_numpy(full_img)
|
| 47 |
+
full_img = full_img / 128 - 1
|
| 48 |
+
|
| 49 |
+
target_dim1 = self.resize_width
|
| 50 |
+
s = target_dim1 / float(org_img.shape[1])
|
| 51 |
+
target_dim0 = int(org_img.shape[0]/float(org_img.shape[1]) * target_dim1)
|
| 52 |
+
|
| 53 |
+
img = cv2.resize(org_img,(target_dim1, target_dim0), interpolation = cv2.INTER_CUBIC)
|
| 54 |
+
img = img.astype(np.float32)
|
| 55 |
+
img = img.transpose([2,1,0])[None,...]
|
| 56 |
+
img = torch.from_numpy(img)
|
| 57 |
+
img = img / 128 - 1
|
| 58 |
+
|
| 59 |
+
image_key = gt_json_path[:-len('.json')]
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
"resized_img": img,
|
| 63 |
+
"full_img": full_img,
|
| 64 |
+
"resize_scale": 1.0/s,
|
| 65 |
+
"gt_lines": [x['gt'] for x in gt_json],
|
| 66 |
+
"img_key": image_key,
|
| 67 |
+
"json_path": gt_json_path,
|
| 68 |
+
"gt_json": gt_json
|
| 69 |
+
}
|
py3/e2e/e2e_model.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from utils import string_utils, error_rates
|
| 8 |
+
from utils import transformation_utils
|
| 9 |
+
from . import handwriting_alignment_loss
|
| 10 |
+
|
| 11 |
+
from . import e2e_postprocessing
|
| 12 |
+
|
| 13 |
+
import copy
|
| 14 |
+
from scipy.optimize import linear_sum_assignment
|
| 15 |
+
import math
|
| 16 |
+
#from pynvml import *
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# max_lines_per_image is the max lines in a batch for HW to process
|
| 21 |
+
class E2EModel(nn.Module):
|
| 22 |
+
def __init__(self, sol, lf, hw, dtype=torch.cuda.FloatTensor, max_lines_per_image=8, device="cuda"):
|
| 23 |
+
super(E2EModel, self).__init__()
|
| 24 |
+
|
| 25 |
+
self.dtype = dtype
|
| 26 |
+
|
| 27 |
+
self.sol = sol
|
| 28 |
+
self.lf = lf
|
| 29 |
+
self.hw = hw
|
| 30 |
+
self.line = None
|
| 31 |
+
self.max_lines_per_image = max_lines_per_image
|
| 32 |
+
|
| 33 |
+
self.device=device
|
| 34 |
+
|
| 35 |
+
def train(self):
|
| 36 |
+
self.sol.train()
|
| 37 |
+
self.lf.train()
|
| 38 |
+
self.hw.train()
|
| 39 |
+
|
| 40 |
+
def eval(self):
|
| 41 |
+
self.sol.eval()
|
| 42 |
+
self.lf.eval()
|
| 43 |
+
self.hw.eval()
|
| 44 |
+
|
| 45 |
+
def forward(self, x, use_full_img=True, accpet_threshold=0.1, volatile=True, gt_lines=None,
|
| 46 |
+
idx_to_char=None, HW_cuda=0, device="cuda"):
|
| 47 |
+
|
| 48 |
+
if device != self.device:
|
| 49 |
+
print('Wrong device is set', 'param', device, 'self', self.device)
|
| 50 |
+
asldjfdkfj
|
| 51 |
+
|
| 52 |
+
sol_img = Variable(x['resized_img'].type(self.dtype), requires_grad=False)
|
| 53 |
+
|
| 54 |
+
if use_full_img:
|
| 55 |
+
img = Variable(x['full_img'].type(self.dtype), requires_grad=False)
|
| 56 |
+
scale = x['resize_scale']
|
| 57 |
+
results_scale = 1.0
|
| 58 |
+
else:
|
| 59 |
+
img = sol_img
|
| 60 |
+
scale = 1.0
|
| 61 |
+
results_scale = x['resize_scale']
|
| 62 |
+
|
| 63 |
+
original_starts = self.sol(sol_img)
|
| 64 |
+
|
| 65 |
+
start = original_starts
|
| 66 |
+
|
| 67 |
+
#Take at least one point
|
| 68 |
+
sorted_start, sorted_indices = torch.sort(start[...,0:1], dim=1, descending=True)
|
| 69 |
+
#print("sorted_start size", sorted_start.size())
|
| 70 |
+
#print("sorted_start", sorted_start)
|
| 71 |
+
min_threshold = sorted_start[0,1,0].data
|
| 72 |
+
accpet_threshold = min(accpet_threshold, min_threshold)
|
| 73 |
+
# There should not be more than 56 points to avoid out of memory
|
| 74 |
+
if sorted_start.size()[1] > 56:
|
| 75 |
+
accpet_threshold = max(accpet_threshold, sorted_start[0,55,0].data)
|
| 76 |
+
#print('using accept_threshold', accpet_threshold, sorted_start[0,55,0].data)
|
| 77 |
+
select = original_starts[...,0:1] >= accpet_threshold
|
| 78 |
+
|
| 79 |
+
select_idx = np.where(select.data.cpu().numpy())[1]
|
| 80 |
+
|
| 81 |
+
select = select.expand(select.size(0), select.size(1), start.size(2))
|
| 82 |
+
select = select.detach()
|
| 83 |
+
start = start[select].view(start.size(0), -1, start.size(2))
|
| 84 |
+
|
| 85 |
+
perform_forward = len(start.size()) == 3
|
| 86 |
+
|
| 87 |
+
if not perform_forward:
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
forward_img = img
|
| 91 |
+
|
| 92 |
+
start = start.transpose(0,1)
|
| 93 |
+
|
| 94 |
+
positions = torch.cat([
|
| 95 |
+
start[...,1:3] * scale,
|
| 96 |
+
start[...,3:4],
|
| 97 |
+
start[...,4:5] * scale,
|
| 98 |
+
start[...,0:1]
|
| 99 |
+
], 2)
|
| 100 |
+
|
| 101 |
+
#print('positions size', positions.size())
|
| 102 |
+
hw_out = []
|
| 103 |
+
p_interval = positions.size(0)
|
| 104 |
+
lf_xy_positions = None
|
| 105 |
+
line_imgs = []
|
| 106 |
+
# show_mem_status(1, "before for in FORWARD")
|
| 107 |
+
for p in range(0,min(positions.size(0), np.inf), p_interval):
|
| 108 |
+
sub_positions = positions[p:p+p_interval,0,:]
|
| 109 |
+
sub_select_idx = select_idx[p:p+p_interval]
|
| 110 |
+
|
| 111 |
+
batch_size = sub_positions.size(0)
|
| 112 |
+
sub_positions = [sub_positions]
|
| 113 |
+
# print(sub_positions)
|
| 114 |
+
# sys.exit()
|
| 115 |
+
|
| 116 |
+
expand_img = forward_img.expand(sub_positions[0].size(0), img.size(1), img.size(2), img.size(3))
|
| 117 |
+
|
| 118 |
+
step_size = 8 #5
|
| 119 |
+
extra_bw = 1 #1
|
| 120 |
+
forward_steps = 30 #40
|
| 121 |
+
|
| 122 |
+
grid_line, _, out_positions, xy_positions = self.lf(expand_img, sub_positions, steps=step_size)
|
| 123 |
+
grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size]], steps=step_size+extra_bw, negate_lw=True)
|
| 124 |
+
grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size+extra_bw]], steps=forward_steps, allow_end_early=True)
|
| 125 |
+
|
| 126 |
+
#show_mem_status(1, 'after lf')
|
| 127 |
+
|
| 128 |
+
if lf_xy_positions is None:
|
| 129 |
+
lf_xy_positions = xy_positions
|
| 130 |
+
else:
|
| 131 |
+
for i in range(len(lf_xy_positions)):
|
| 132 |
+
lf_xy_positions[i] = torch.cat([
|
| 133 |
+
lf_xy_positions[i],
|
| 134 |
+
xy_positions[i]
|
| 135 |
+
])
|
| 136 |
+
expand_img = expand_img.transpose(2,3)
|
| 137 |
+
|
| 138 |
+
hw_interval = p_interval
|
| 139 |
+
for h in range(0,min(grid_line.size(0), np.inf), hw_interval):
|
| 140 |
+
sub_out_positions = [o[h:h+hw_interval] for o in out_positions]
|
| 141 |
+
sub_xy_positions = [o[h:h+hw_interval] for o in xy_positions]
|
| 142 |
+
sub_sub_select_idx = sub_select_idx[h:h+hw_interval]
|
| 143 |
+
|
| 144 |
+
line = torch.nn.functional.grid_sample(expand_img[h:h+hw_interval].detach(), grid_line[h:h+hw_interval], align_corners=True)
|
| 145 |
+
line = line.transpose(2,3)
|
| 146 |
+
|
| 147 |
+
for l in line:
|
| 148 |
+
l = l.transpose(0,1).transpose(1,2)
|
| 149 |
+
l = (l + 1)*128
|
| 150 |
+
l_np = l.data.cpu().numpy()
|
| 151 |
+
line_imgs.append(l_np)
|
| 152 |
+
# cv2.imwrite("example_line_out.png", l_np)
|
| 153 |
+
# print "Saved!"
|
| 154 |
+
# raw_input()
|
| 155 |
+
|
| 156 |
+
# REsize to 60 ht
|
| 157 |
+
|
| 158 |
+
# Mehreen add: To avoid out of memory errors. A large batch has to be split up for HW network to process
|
| 159 |
+
# This case will arise when SOL finds too many lines on a page
|
| 160 |
+
batch, channels, old_ht, old_width = line.size()
|
| 161 |
+
line = line.detach().cpu()
|
| 162 |
+
total_todo = batch
|
| 163 |
+
#show_mem_status(0, '.... Before hw line')
|
| 164 |
+
|
| 165 |
+
start_index = 0
|
| 166 |
+
while total_todo > 0:
|
| 167 |
+
mini_batch_size = min(self.max_lines_per_image, total_todo)
|
| 168 |
+
partial_lines = line[start_index:start_index+mini_batch_size, :, :, :]
|
| 169 |
+
#print('start_index, end_index', start_index, start_index+mini_batch_size)
|
| 170 |
+
start_index += mini_batch_size
|
| 171 |
+
total_todo = total_todo - mini_batch_size
|
| 172 |
+
#print('partial_line size', partial_lines.size())
|
| 173 |
+
partial_lines = partial_lines.to(self.device)
|
| 174 |
+
out = self.hw(partial_lines)
|
| 175 |
+
if "cuda" in device:
|
| 176 |
+
torch.cuda.empty_cache()
|
| 177 |
+
out = out.transpose(0, 1)
|
| 178 |
+
hw_out.append(out)
|
| 179 |
+
|
| 180 |
+
#print('batch size: ', batch)
|
| 181 |
+
# new_ht = 60
|
| 182 |
+
# new_width = int(old_width/old_ht*new_ht)
|
| 183 |
+
#print('line type', type(line), line.size())
|
| 184 |
+
# self.line = nn.functional.interpolate(line, size=(new_ht, new_width),
|
| 185 |
+
# mode='bilinear', align_corners=True)
|
| 186 |
+
# Mehreen commented out for processing entire batch in one go
|
| 187 |
+
|
| 188 |
+
# out = self.hw(line)
|
| 189 |
+
# out = out.transpose(0,1)
|
| 190 |
+
|
| 191 |
+
# hw_out.append(out)
|
| 192 |
+
#show_mem_status(0, '.... After hw line')
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
hw_out = torch.cat(hw_out, 0)
|
| 197 |
+
# print(original_starts,positions,lf_xy_positions,hw_out,results_scale,line_imgs)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
return {
|
| 201 |
+
"original_sol": original_starts,
|
| 202 |
+
"sol": positions,
|
| 203 |
+
"lf": lf_xy_positions,
|
| 204 |
+
"hw": hw_out,
|
| 205 |
+
"results_scale": results_scale,
|
| 206 |
+
"line_imgs": line_imgs
|
| 207 |
+
}
|
py3/e2e/e2e_postprocessing.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import string_utils, error_rates
|
| 2 |
+
import numpy as np
|
| 3 |
+
from . import nms
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
def get_trimmed_polygons(out):
|
| 7 |
+
all_polygons = []
|
| 8 |
+
for j in range(out['lf'][0].shape[0]):
|
| 9 |
+
begin = out['beginning'][j]
|
| 10 |
+
end = out['ending'][j]
|
| 11 |
+
last_xy = None
|
| 12 |
+
begin_f = int(np.floor(begin))
|
| 13 |
+
end_f = int(np.ceil(end))
|
| 14 |
+
points = []
|
| 15 |
+
for i in range(begin_f, end_f+1):
|
| 16 |
+
|
| 17 |
+
if i == begin_f:
|
| 18 |
+
p0 = out['lf'][i][j]
|
| 19 |
+
p1 = out['lf'][i+1][j]
|
| 20 |
+
t = begin - np.floor(begin)
|
| 21 |
+
p = p0 * (1 - t) + p1 * t
|
| 22 |
+
|
| 23 |
+
elif i == end_f:
|
| 24 |
+
|
| 25 |
+
p0 = out['lf'][i-1][j]
|
| 26 |
+
if i != len(out['lf']):
|
| 27 |
+
p1 = out['lf'][i][j]
|
| 28 |
+
t = end - np.floor(end)
|
| 29 |
+
p = p0 * (1 - t) + p1 * t
|
| 30 |
+
else:
|
| 31 |
+
p = p0
|
| 32 |
+
else:
|
| 33 |
+
p = out['lf'][i][j]
|
| 34 |
+
|
| 35 |
+
points.append(p)
|
| 36 |
+
points = np.array(points)
|
| 37 |
+
all_polygons.append(points)
|
| 38 |
+
return all_polygons
|
| 39 |
+
|
| 40 |
+
def trim_ends(out):
|
| 41 |
+
|
| 42 |
+
lf_length = len(out['lf'])
|
| 43 |
+
hw = out['hw']
|
| 44 |
+
# Mehreen: hw is (14, 361, 197) a 14x361 matrix for each character. selected is (14, 361)
|
| 45 |
+
selected = hw.argmax(axis=-1)
|
| 46 |
+
beginning = np.argmax(selected != 0, axis=1)
|
| 47 |
+
ending = selected.shape[1] - 1 - np.argmax(selected[:,::-1] != 0, axis=1)
|
| 48 |
+
|
| 49 |
+
beginning_percent = (beginning+0.5) / float(selected.shape[1])
|
| 50 |
+
ending_percent = (ending+0.5) / float(selected.shape[1])
|
| 51 |
+
|
| 52 |
+
lf_beginning = lf_length * beginning_percent
|
| 53 |
+
lf_ending = lf_length * ending_percent
|
| 54 |
+
|
| 55 |
+
out['beginning'] = lf_beginning
|
| 56 |
+
out['ending'] = lf_ending
|
| 57 |
+
return out
|
| 58 |
+
|
| 59 |
+
def filter_on_pick(out, pick):
|
| 60 |
+
out['sol'] = out['sol'][pick]
|
| 61 |
+
out['lf'] = [l[pick] for l in out['lf']]
|
| 62 |
+
out['hw'] = out['hw'][pick]
|
| 63 |
+
|
| 64 |
+
if 'idx' in out:
|
| 65 |
+
out['idx'] = out['idx'][pick]
|
| 66 |
+
if 'beginning' in out:
|
| 67 |
+
out['beginning'] = out['beginning'][pick]
|
| 68 |
+
if 'ending' in out:
|
| 69 |
+
out['ending'] = out['ending'][pick]
|
| 70 |
+
|
| 71 |
+
def filter_on_pick_no_copy(out, pick):
|
| 72 |
+
output = {}
|
| 73 |
+
output['sol'] = out['sol'][pick]
|
| 74 |
+
output['lf'] = [l[pick] for l in out['lf']]
|
| 75 |
+
output['hw'] = out['hw'][pick]
|
| 76 |
+
##Mehreen
|
| 77 |
+
#print(pick)
|
| 78 |
+
#out['line_imgs'] = out['line_imgs'][pick]
|
| 79 |
+
## End mehreen
|
| 80 |
+
if 'idx' in out:
|
| 81 |
+
output['idx'] = out['idx'][pick]
|
| 82 |
+
if 'beginning' in out:
|
| 83 |
+
output['beginning'] = out['beginning'][pick]
|
| 84 |
+
if 'ending' in out:
|
| 85 |
+
output['ending'] = out['ending'][pick]
|
| 86 |
+
return output
|
| 87 |
+
|
| 88 |
+
def select_non_empty_string(out):
|
| 89 |
+
selected = out['hw'].argmax(axis=-1)
|
| 90 |
+
return np.where(selected.sum(axis=1) != 0)
|
| 91 |
+
|
| 92 |
+
def postprocess(out, **kwargs):
|
| 93 |
+
out = copy.copy(out)
|
| 94 |
+
|
| 95 |
+
# postprocessing should be done with numpy data
|
| 96 |
+
sol_threshold = kwargs.get("sol_threshold", None)
|
| 97 |
+
sol_nms_threshold = kwargs.get("sol_nms_threshold", None)
|
| 98 |
+
lf_nms_params = kwargs.get('lf_nms_params', None)
|
| 99 |
+
lf_nms_2_params = kwargs.get('lf_nms_2_params', None)
|
| 100 |
+
|
| 101 |
+
if sol_threshold is not None:
|
| 102 |
+
pick = np.where(out['sol'][:,-1] > sol_threshold)
|
| 103 |
+
filter_on_pick(out, pick)
|
| 104 |
+
|
| 105 |
+
#Mehreen: this is passed as None from run_hwr from decode_one_img_with_info
|
| 106 |
+
if sol_nms_threshold is not None:
|
| 107 |
+
raise Exception("This is not correct")
|
| 108 |
+
pick = nms.sol_nms_single(out['sol'], sol_nms_threshold)
|
| 109 |
+
out['sol'] = out['sol'][pick]
|
| 110 |
+
|
| 111 |
+
#Mehreen: When post-processing this part is done. sample_config lf_nms_range: [0,6] lf_nms_threshold: 0.5
|
| 112 |
+
if lf_nms_params is not None:
|
| 113 |
+
confidences = out['sol'][:,-1]
|
| 114 |
+
overlap_range = lf_nms_params['overlap_range']
|
| 115 |
+
overlap_thresh = lf_nms_params['overlap_threshold']
|
| 116 |
+
|
| 117 |
+
lf_setup = np.concatenate([l[None,...] for l in out['lf']])
|
| 118 |
+
lf_setup = [lf_setup[:,i] for i in range(lf_setup.shape[1])]
|
| 119 |
+
|
| 120 |
+
pick = nms.lf_non_max_suppression_area(lf_setup, confidences, overlap_range, overlap_thresh)
|
| 121 |
+
filter_on_pick(out, pick)
|
| 122 |
+
|
| 123 |
+
#Mehreen: When post-processing this part is None from decode_one_img_with_info
|
| 124 |
+
if lf_nms_2_params is not None:
|
| 125 |
+
confidences = out['sol'][:,-1]
|
| 126 |
+
overlap_thresh = lf_nms_2_params['overlap_threshold']
|
| 127 |
+
refined_lf = get_trimmed_polygons(out)
|
| 128 |
+
pick = nms.lf_non_max_suppression_area(refined_lf, confidences, None, overlap_thresh)
|
| 129 |
+
filter_on_pick(out, pick)
|
| 130 |
+
|
| 131 |
+
return out
|
| 132 |
+
|
| 133 |
+
def read_order(out):
|
| 134 |
+
first_pt = out['lf'][0][:,:2,0]
|
| 135 |
+
|
| 136 |
+
first_pt = first_pt[:,::-1]
|
| 137 |
+
first_pt = np.concatenate([first_pt, np.arange(first_pt.shape[0])[:,None]], axis=1)
|
| 138 |
+
first_pt = first_pt.tolist()
|
| 139 |
+
|
| 140 |
+
first_pt.sort()
|
| 141 |
+
|
| 142 |
+
return [int(p[2]) for p in first_pt]
|
| 143 |
+
|
| 144 |
+
def decode_handwriting(out, idx_to_char):
|
| 145 |
+
hw_out = out['hw']
|
| 146 |
+
list_of_pred = []
|
| 147 |
+
list_of_raw_pred = []
|
| 148 |
+
for i in range(hw_out.shape[0]):
|
| 149 |
+
logits = hw_out[i,...]
|
| 150 |
+
pred, raw_pred = string_utils.naive_decode(logits)
|
| 151 |
+
pred_str = string_utils.label2str_single(pred, idx_to_char, False)
|
| 152 |
+
raw_pred_str = string_utils.label2str_single(raw_pred, idx_to_char, True)
|
| 153 |
+
list_of_pred.append(pred_str)
|
| 154 |
+
list_of_raw_pred.append(raw_pred_str)
|
| 155 |
+
|
| 156 |
+
return list_of_pred, list_of_raw_pred
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def results_to_numpy(out):
|
| 160 |
+
return {
|
| 161 |
+
"sol": out['sol'].data.cpu().numpy()[:,0,:],
|
| 162 |
+
"lf": [l.data.cpu().numpy() for l in out['lf']] if out['lf'] is not None else None,
|
| 163 |
+
"hw": out['hw'].data.cpu().numpy(),
|
| 164 |
+
"results_scale": out['results_scale'],
|
| 165 |
+
"line_imgs": out['line_imgs'],
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
def align_to_gt_lines(decoded_hw, gt_lines):
|
| 169 |
+
costs = []
|
| 170 |
+
for i in range(len(decoded_hw)):
|
| 171 |
+
costs.append([])
|
| 172 |
+
for j in range(len(gt_lines)):
|
| 173 |
+
pred = decoded_hw[i]
|
| 174 |
+
gt = gt_lines[j]
|
| 175 |
+
cer = error_rates.cer(gt, pred)
|
| 176 |
+
costs[i].append(cer)
|
| 177 |
+
|
| 178 |
+
costs = np.array(costs)
|
| 179 |
+
min_idx = costs.argmin(axis=0)
|
| 180 |
+
min_val = costs.min(axis=0)
|
| 181 |
+
|
| 182 |
+
return min_idx, min_val
|
py3/e2e/forward_pass.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from e2e import e2e_model
|
| 2 |
+
from e2e.e2e_model import E2EModel
|
| 3 |
+
|
| 4 |
+
from . import validation_utils
|
| 5 |
+
|
| 6 |
+
from utils import error_rates
|
| 7 |
+
|
| 8 |
+
import itertools
|
| 9 |
+
import copy
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
def forward_pass(x, e2e, config, thresholds, idx_to_char, update_json=False):
    """Run the end-to-end model on one sample and sweep post-processing thresholds.

    thresholds is (sol_thresholds, lf_nms_ranges, lf_nms_thresholds); every
    combination is evaluated and its CER is stored in the returned dict keyed
    by the (i, j, k) threshold indices.

    Returns (results, ideal_result, most_ideal_result), or None when the model
    produced no output for this sample.
    """
    gt_lines = x['gt_lines']
    gt = "\n".join(gt_lines)

    out_original = e2e(x)
    if out_original is None:
        # BUGFIX: the original had a bare `None` no-op here and then crashed
        # downstream; bail out early instead.
        return None

    out_original = E2EModel.results_to_numpy(out_original)
    out_original['idx'] = np.arange(out_original['sol'].shape[0])

    decoded_hw, decoded_raw_hw = E2EModel.decode_handwriting(out_original, idx_to_char)
    pick, costs = E2EModel.align_to_gt_lines(decoded_hw, gt_lines)

    most_ideal_pred_lines, improved_idxs = validation_utils.update_ideal_results(
        pick, costs, decoded_hw, x['gt_json'])

    sol_thresholds = thresholds[0]
    lf_nms_ranges = thresholds[1]
    lf_nms_thresholds = thresholds[2]

    most_ideal_pred_lines = "\n".join(most_ideal_pred_lines)
    ideal_pred_lines = "\n".join(decoded_hw[i] for i in pick)

    # "ideal" = best alignment for this pass; "most ideal" = best seen so far.
    ideal_result = error_rates.cer(gt, ideal_pred_lines)
    most_ideal_result = error_rates.cer(gt, most_ideal_pred_lines)

    results = {}
    for key in itertools.product(range(len(sol_thresholds)),
                                 range(len(lf_nms_ranges)),
                                 range(len(lf_nms_thresholds))):
        i, j, k = key

        # Shallow copy: postprocess filters rows, the underlying arrays are shared.
        out = copy.copy(out_original)
        out = E2EModel.postprocess(out,
                                   sol_threshold=sol_thresholds[i],
                                   lf_nms_params={
                                       "overlap_range": lf_nms_ranges[j],
                                       "overlap_threshold": lf_nms_thresholds[k]
                                   })
        order = E2EModel.read_order(out)
        E2EModel.filter_on_pick(out, order)

        pred = "\n".join(decoded_hw[idx] for idx in out['idx'])
        results[key] = error_rates.cer(gt, pred)

    return results, ideal_result, most_ideal_result
|
py3/e2e/handwriting_alignment_loss.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import string_utils, error_rates
|
| 2 |
+
import torch
|
| 3 |
+
from scipy.optimize import linear_sum_assignment
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.autograd import Variable
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def accumulate_scores(out, out_positions, xy_positions, gt_state, idx_to_char):
    """Decode each line and append its CER/prediction/path to every gt object.

    Each dict in gt_state accumulates parallel lists ('errors', 'pred',
    'pred_full', 'path', 'path_xy') with one element per decoded line.
    Removed the dead `preds`/`preds_size` computation from the original.
    """
    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char, False)
        pred_str_full = string_utils.label2str_single(raw_decode_full, idx_to_char, True)

        sub_out_positions = [o[i].data.cpu().numpy().tolist() for o in out_positions]
        sub_xy_positions = [o[i].data.cpu().numpy().tolist() for o in xy_positions]

        for gt_obj in gt_state:
            cer = error_rates.cer(gt_obj['gt'], pred_str)

            # Create the parallel lists on first use, then append in lockstep.
            gt_obj.setdefault('errors', []).append(cer)
            gt_obj.setdefault('pred', []).append(pred_str)
            gt_obj.setdefault('pred_full', []).append(pred_str_full)
            gt_obj.setdefault('path', []).append(sub_out_positions)
            gt_obj.setdefault('path_xy', []).append(sub_xy_positions)
| 42 |
+
def update_alignment(out, gt_lines, alignments, idx_to_char, idx_mapping, sol_positions):
    """Greedily update the best (lowest-cost) prediction for every gt line.

    alignments[j] is mutated in place as [cost, global index, logits, pred].
    Removed the dead `preds`/`preds_size` computation and hoisted the
    per-prediction invariants out of the inner loop.
    """
    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, _ = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char, False)

        global_i = idx_mapping[i]
        c = sol_positions[i, 0, -1].data[0]

        for j, gt in enumerate(gt_lines):
            cer = error_rates.cer(gt, pred_str)
            # Penalize low SOL confidence so uncertain detections lose ties.
            alignment_error = cer + 0.1 * (1.0 - c)

            if alignment_error < alignments[j][0]:
                alignments[j][0] = alignment_error
                alignments[j][1] = global_i
                alignments[j][2] = None
                alignments[j][3] = pred_str
def alignment(predictions, hw_scores, alpha_alignment=0.1, alpha_backprop=0.1):
    """Hungarian-match predicted lines to gt lines per batch element.

    The cost mixes the HW recognition score with the detection confidence
    (predictions[:, :, 4]) so that confident detections are preferred.
    Returns a list of (pred_indices, gt_indices) pairs from scipy's
    linear_sum_assignment.  Removed the unused `X = np.zeros_like(C)`.
    """
    confidences = predictions[:, :, 4]

    log_confidences = torch.log(confidences + 1e-10)
    log_one_minus_confidences = torch.log(1.0 - confidences + 1e-10)

    expand_shape = (confidences.size(0), confidences.size(1), hw_scores.size(2))
    expanded_log_confidences = log_confidences[:, :, None].expand(*expand_shape)
    expanded_log_one_minus_confidences = log_one_minus_confidences[:, :, None].expand(*expand_shape)

    C = alpha_alignment * hw_scores - expanded_log_confidences + expanded_log_one_minus_confidences
    C = C.data.cpu().numpy()

    idxs = []
    for b in range(C.shape[0]):
        row_ind, col_ind = linear_sum_assignment(C[b].T)
        idxs.append((col_ind, row_ind))
    return idxs
+
def loss(preds, non_hw_sol, hw_sol, gt_lines, char_to_idx, criterion):
    """Scaled CTC loss over the decoded lines plus a SOL-confidence term."""
    encoded = [string_utils.str2label_single(gt_str, char_to_idx) for gt_str in gt_lines]
    label_lengths = np.array([len(l) for l in encoded])
    all_labels = np.concatenate(encoded)

    labels = Variable(torch.from_numpy(all_labels.astype(np.int32)), requires_grad=False)
    label_lengths = Variable(torch.from_numpy(label_lengths.astype(np.int32)), requires_grad=False)

    batch_size = preds.size(0)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    ctc_loss = 1e-2 * criterion(preds.cpu(), labels, preds_size, label_lengths)

    # Push selected (HW) SOL confidences toward 1, unselected toward 0.
    log_one_minus_confidences = torch.log(1.0 - non_hw_sol[:, :, 0] + 1e-10)
    log_confidences = torch.log(hw_sol[:, :, 0] + 1e-10)

    confidence_loss = -log_confidences.sum() - log_one_minus_confidences.sum()

    return ctc_loss + confidence_loss.cpu()
py3/e2e/nms.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pyclipper
|
| 4 |
+
|
| 5 |
+
def sol_non_max_suppression(start_torch, overlap_thresh):
    """NMS over start-of-line predictions.

    TODO: make this work with batches — only batch element 0 is processed.
    Rotation is ignored: each SOL is treated as an axis-aligned box.
    """
    start_np = start_torch.data.cpu().numpy()
    keep = sol_nms_single(start_np[0], overlap_thresh)
    batch_idx = [0] * len(keep)
    return start_torch[(batch_idx, keep)][None, ...]
| 19 |
+
def sol_nms_single(start, overlap_thresh):
    """Greedy non-max suppression for one image's SOL boxes.

    start columns are (confidence, x, y, scale, ...); each SOL becomes the
    square box centred at (x, y) with half-side equal to its scale.  Based on
    pyimagesearch's "faster non-maximum suppression" recipe; suppression uses
    intersection over the *lower-confidence* box's area, not IoU.
    """
    conf = start[:, 0]
    half = start[:, 3]
    x1 = start[:, 1] - half
    y1 = start[:, 2] - half
    x2 = start[:, 1] + half
    y2 = start[:, 2] + half

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = np.argsort(conf)  # ascending; best candidate is at the end

    keep = []
    while order.size > 0:
        last = order.size - 1
        best = order[last]
        keep.append(best)

        ix1 = np.maximum(x1[best], x1[order[:last]])
        iy1 = np.maximum(y1[best], y1[order[:last]])
        ix2 = np.minimum(x2[best], x2[order[:last]])
        iy2 = np.minimum(y2[best], y2[order[:last]])

        inter = np.maximum(0, ix2 - ix1 + 1) * np.maximum(0, iy2 - iy1 + 1)
        overlap = inter / area[order[:last]]

        # Drop the picked box and everything it suppresses.
        order = np.delete(order, np.concatenate(([last],
                                                 np.where(overlap > overlap_thresh)[0])))
    return keep
+
def lf_non_max_suppression_area(lf_xy_positions, confidences, overlap_range, overlap_thresh):
    """Suppress duplicate line-follower paths by polygon-area overlap.

    A cheap bounding-box intersection gates the expensive polygon test: only
    pairs whose boxes overlap by at least 10% get the exact area computation.
    Returns the indices of the paths to keep, highest confidence first.
    """
    paths = [l[:, :2, :2] for l in lf_xy_positions]

    boxes = []
    center_lines = []
    scales = []
    for pts in paths:
        if overlap_range is not None:
            pts = pts[overlap_range[0]: overlap_range[1]]

        # Path scale: distance between the two points of the first step.
        first = pts[0]
        diff = first[:, 0] - first[:, 1]
        scales.append(np.sqrt((diff ** 2).sum()))

        # Centre line: midpoint of the two edge points at every step.
        center_lines.append((pts[:, :, 0] + pts[:, :, 1]) / 2.0)

        boxes.append((pts[:, 0].min(), pts[:, 1].min(),
                      pts[:, 0].max(), pts[:, 1].max()))

    boxes = np.array(boxes)
    if len(boxes.shape) < 2:
        return []

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = np.argsort(confidences)

    keep = []
    while len(order) > 0:
        last = len(order) - 1
        best = order[last]
        keep.append(best)

        ix1 = np.maximum(x1[best], x1[order[:last]])
        iy1 = np.maximum(y1[best], y1[order[:last]])
        ix2 = np.minimum(x2[best], x2[order[:last]])
        iy2 = np.minimum(y2[best], y2[order[:last]])

        w = np.maximum(0, ix2 - ix1 + 1)
        h = np.maximum(0, iy2 - iy1 + 1)
        overlap_bb = (w * h) / area[order[:last]]

        overlap = []
        for step, other in enumerate(order[:last]):
            # Bounding boxes barely touch: skip the polygon computation.
            if overlap_bb[step] < 0.1:
                overlap.append(0)
                continue

            path0 = center_lines[best]
            path1 = center_lines[other]

            # Ring formed by one centre line forward, the other backward; its
            # simplified area measures how far apart the two paths are.
            ring = np.concatenate([path0, path1[::-1]])
            ring = [[int(p[0]), int(p[1])] for p in ring]

            expected_scale = (scales[best] + scales[other]) / 2.0
            one_off_area = expected_scale ** 2 * (path0.shape[0] + path1.shape[0]) / 2.0

            inter_area = 0
            for part in pyclipper.SimplifyPolygon(ring, pyclipper.PFT_NONZERO):
                inter_area += abs(pyclipper.Area(part))

            # Small ring area => near-identical paths => high overlap score.
            overlap.append(1.0 - inter_area / one_off_area)

        overlap = np.array(overlap)
        to_delete = np.concatenate(([last], np.where(overlap > overlap_thresh)[0]))
        order = np.delete(order, to_delete)

    return keep
py3/e2e/validation_utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import error_rates
|
| 2 |
+
import copy
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
def interpolate(key1, key2, lf, lf_idx, step_percent):
    """Linearly blend (key1, key2) between lf[lf_idx] and lf[lf_idx + 1].

    step_percent is the fraction of the way toward lf[lf_idx + 1].
    """
    a = lf[lf_idx]
    b = lf[lf_idx + 1]

    x = b[key1] * step_percent + a[key1] * (1.0 - step_percent)
    y = b[key2] * step_percent + a[key2] * (1.0 - step_percent)
    return x, y
+
def get_subdivide_pt(i, pred_full, lf):
    """Map character slot i of pred_full onto an interpolated LF segment.

    Returns the two interpolated edge points (x0, y0, x1, y1).
    """
    frac = (float(i) + 0.5) / float(len(pred_full))
    lf_pos = (len(lf) - 1) * frac

    seg = int(np.floor(lf_pos))
    t = lf_pos - seg

    x0, y0 = interpolate("x0", "y0", lf, seg, t)
    x1, y1 = interpolate("x1", "y1", lf, seg, t)
    return x0, y0, x1, y1
+
def save_improved_idxs(improved_idxs, decoded_hw, decoded_raw_hw, out, x, json_folder):
    """Write per-line predictions, paths and crop images back to the sample's json.

    Lines whose index is not in improved_idxs keep their previous gt_json
    entry untouched.  BUGFIX: `s` (results_scale) was assigned only inside
    the per-step loop, raising NameError for an empty LF path; it is now
    hoisted before any use.
    """
    output_lines = [{"gt": gt['gt']} for gt in x['gt_json']]

    # results_scale maps network coordinates back to image coordinates.
    s = out['results_scale']

    for i in range(len(output_lines)):
        if i not in improved_idxs:
            output_lines[i] = x['gt_json'][i]
            continue

        k = improved_idxs[i]

        # Keep the full-length LF path (split at 'ending') so line images can
        # be regenerated later at a different resolution.
        line_points = []
        after_line_points = []
        end = out['ending'][k]
        for j, step in enumerate(out['lf']):
            p = step[k]
            point = {
                "x0": p[0][1] * s,
                "x1": p[0][0] * s,
                "y0": p[1][1] * s,
                "y1": p[1][0] * s,
            }
            (after_line_points if j > end else line_points).append(point)

        # Interpolate the fractional 'beginning' position for the SOL point.
        begin = out['beginning'][k]
        begin_f = int(np.floor(begin))
        p0 = out['lf'][begin_f][k]
        if begin_f + 1 >= len(out['lf']):
            p = p0
        else:
            p1 = out['lf'][begin_f + 1][k]
            t = begin - np.floor(begin)
            p = p0 * (1 - t) + p1 * t

        sol_point = {
            "x0": p[0][1] * s,
            "x1": p[0][0] * s,
            "y0": p[1][1] * s,
            "y1": p[1][0] * s,
        }

        img_file_name = "{}_{}.png".format(x['img_key'], i)

        output_lines[i]['pred'] = decoded_hw[k]
        output_lines[i]['pred_full'] = decoded_raw_hw[k]
        output_lines[i]['sol'] = sol_point
        output_lines[i]['lf'] = line_points
        output_lines[i]['after_lf'] = after_line_points
        output_lines[i]['start_idx'] = 1  # TODO: update to backward idx
        output_lines[i]['hw_path'] = img_file_name

        cv2.imwrite(os.path.join(json_folder, img_file_name), out['line_imgs'][k])

    with open(x['json_path'], 'w') as f:
        json.dump(output_lines, f)
+
def update_ideal_results(pick, costs, decoded_hw, gt_json):
    """Keep, per gt line, whichever prediction (previous or current) has lower CER.

    Returns (best prediction string per gt line, mapping {gt line index:
    prediction index} for the lines that actually improved this pass).
    """
    most_ideal_pred = []
    improved_idxs = {}

    for i, gt_obj in enumerate(gt_json):
        prev_pred = gt_obj.get('pred', '')
        prev_cer = error_rates.cer(gt_obj['gt'], prev_pred)

        pred = decoded_hw[pick[i]]
        cer = costs[i]

        # Empty predictions never replace an existing one.
        if cer > prev_cer or len(pred) == 0:
            most_ideal_pred.append(prev_pred)
        else:
            most_ideal_pred.append(pred)
            improved_idxs[i] = pick[i]

    return most_ideal_pred, improved_idxs
py3/e2e/visualization.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
def draw_output(out, img):
    """Draw SOL markers and the matching LF centre path for every detected line."""
    img = img.copy()

    for j in range(out['sol'].shape[0]):
        sol = out['sol'][j]

        x = sol[0]
        y = sol[1]
        theta = sol[2]
        s = sol[3]

        # Arrow head placed two scales along the SOL direction.
        tip_x = int(x + s * np.cos(theta) * 2)
        tip_y = int(y + s * -np.sin(theta) * 2)

        x = int(x)
        y = int(y)
        radius = abs(int(s))

        color = (0, 0, 255)
        cv2.circle(img, (x, y), int(radius), color, 2)
        cv2.circle(img, (x, y), 4, color, -1)
        cv2.arrowedLine(img, (x, y), (tip_x, tip_y), color, 2, tipLength=0.25)
        cv2.putText(img, str(j), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 255, 0), 2, cv2.LINE_AA)

        # Trace this line's follower path between its fractional begin/end.
        begin = out['beginning'][j]
        end = out['ending'][j]
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))

        last_xy = None
        for i in range(begin_f, end_f + 1):
            if i == begin_f:
                # Interpolate the fractional starting position.
                p0 = out['lf'][i][j].mean(axis=1)
                p1 = out['lf'][i + 1][j].mean(axis=1)
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t
            elif i == end_f:
                # Interpolate the fractional ending position (guard past-the-end).
                p0 = out['lf'][i - 1][j].mean(axis=1)
                if i != len(out['lf']):
                    p1 = out['lf'][i][j].mean(axis=1)
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                p = out['lf'][i][j].mean(axis=1)

            px = int(p[0])
            py = int(p[1])

            path_color = (0, 150, 0)
            cv2.circle(img, (px, py), 4, path_color, -1)
            if last_xy is not None:
                # Path thickness follows the SOL scale of this line.
                cv2.line(img, (px, py), last_xy, path_color, int(s))
            last_xy = (px, py)

    return img
+
def draw_output_original(out, img):
    """Draw all LF centre paths in black first, then SOL markers coloured by confidence."""
    img = img.copy()

    for j in range(out['lf'][0].shape[0]):
        begin = out['beginning'][j]
        end = out['ending'][j]
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))

        last_xy = None
        for i in range(begin_f, end_f + 1):
            if i == begin_f:
                # Interpolate the fractional starting position.
                p0 = out['lf'][i][j].mean(axis=1)
                p1 = out['lf'][i + 1][j].mean(axis=1)
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t
            elif i == end_f:
                # Interpolate the fractional ending position (guard past-the-end).
                p0 = out['lf'][i - 1][j].mean(axis=1)
                if i != len(out['lf']):
                    p1 = out['lf'][i][j].mean(axis=1)
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                p = out['lf'][i][j].mean(axis=1)

            px = int(p[0])
            py = int(p[1])

            cv2.circle(img, (px, py), 4, (0, 0, 0), -1)
            if last_xy is not None:
                cv2.line(img, (px, py), last_xy, (0, 0, 0), 2)
            last_xy = (px, py)

    for i in range(out['sol'].shape[0]):
        sol = out['sol'][i]

        # Red -> blue as SOL confidence drops.
        c = int(255 * sol[-1])
        color = (c, 0, 255 - c)

        x = sol[0]
        y = sol[1]
        theta = sol[2]
        s = sol[3]

        tip_x = int(x + s * np.cos(theta) * 2)
        tip_y = int(y + s * -np.sin(theta) * 2)

        x = int(x)
        y = int(y)
        radius = abs(int(s))

        cv2.circle(img, (x, y), int(radius), color, 2)
        cv2.circle(img, (x, y), 4, color, -1)
        cv2.arrowedLine(img, (x, y), (tip_x, tip_y), color, 2, tipLength=0.25)
        cv2.putText(img, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (0, 255, 0), 2, cv2.LINE_AA)

    return img
py3/hw/__init__.py
ADDED
|
File without changes
|
py3/hw/cnn_lstm.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.patches as patches
|
| 6 |
+
|
| 7 |
+
class BidirectionalLSTM(nn.Module):
    """Two stacked bidirectional LSTM layers followed by a linear projection."""

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        # Submodule names ('rnn', 'embedding') must stay fixed: saved
        # checkpoints address parameters by these names.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, dropout=0.5, num_layers=2)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        # input: (T, b, nIn); the LSTM output concatenates both directions.
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()

        # Project every timestep at once, then restore (T, b, nOut).
        projected = self.embedding(recurrent.view(T * b, h))
        return projected.view(T, b, -1)
+
class CRNN(nn.Module):
    """VGG-style CNN feature extractor followed by a BLSTM + log-softmax head."""

    def __init__(self, cnnOutSize, nc, nclass, nh, n_rnn=2, leakyRelu=False, use_instance_norm=False):
        super(CRNN, self).__init__()

        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        paddings = [1, 1, 1, 1, 1, 1, 0]
        strides = [1, 1, 1, 1, 1, 1, 1]
        channels = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        # Module names ('conv{i}', 'pooling{i}', ...) must stay fixed: saved
        # checkpoints address parameters by these names.
        def add_conv(i, with_norm=False):
            in_ch = nc if i == 0 else channels[i - 1]
            out_ch = channels[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(in_ch, out_ch, kernel_sizes[i], strides[i], paddings[i]))
            if with_norm:
                if use_instance_norm:
                    cnn.add_module(f'instancenorm{i}', nn.InstanceNorm2d(out_ch))
                else:
                    cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(out_ch))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        add_conv(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))   # 64 x H/2 x W/2
        add_conv(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))   # 128 x H/4 x W/4
        add_conv(2, True)
        add_conv(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))         # halves height only
        add_conv(4, True)
        add_conv(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        add_conv(6, True)                                            # final 2x2 conv, no pad

        self.cnn = cnn
        # nclass is the size of the output alphabet; nh is the LSTM hidden size.
        self.rnn = BidirectionalLSTM(cnnOutSize, nh, nclass)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()

        # Debug traps for numerical blow-ups (kept from the original).
        if torch.any(torch.isnan(conv)):
            print("CONV IS NAN (b,c,h,w) = ", b, c, h, w)

        # Fold channels and height into one feature axis; width becomes time.
        conv = torch.reshape(conv, (b, c * h, w))
        conv = conv.permute(2, 0, 1)  # (w, b, c*h)

        output = self.rnn(conv)
        if torch.any(torch.isnan(output)):
            print("OUTPUT FROM RNN IS NAN")
        output = self.softmax(output)
        if torch.any(torch.isnan(output)):
            print("OUTPUT FROM SOFTMAX IS NAN")
        if torch.any(torch.isinf(output)):
            print("OUTPUT FROM SOFTMAX IS INF")
        return output
+
def create_model(config):
    """Build a CRNN from a config dict.

    Expected keys: 'cnn_out_size', 'num_of_channels', 'num_of_outputs';
    optional truthy flag 'use_instance_norm'. The LSTM hidden size is
    fixed at 512.
    """
    instance_norm = bool(config.get('use_instance_norm', False))
    return CRNN(config['cnn_out_size'], config['num_of_channels'],
                config['num_of_outputs'], 512, use_instance_norm=instance_norm)
|
| 117 |
+
|
py3/lf/__init__.py
ADDED
|
File without changes
|
py3/lf/fast_patch_view.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.autograd import Variable
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from utils import transformation_utils
|
| 6 |
+
|
| 7 |
+
def get_patches(image, crop_window, grid_gen, allow_end_early=False, device="cuda"):
    """Sample one square patch per batch element under a 3x3 crop transform.

    For each batch item the axis-aligned bounding box of the transformed unit
    square is copied (with edge clamping) into a shared scratch buffer, then
    resampled through `grid_gen` with bilinear grid_sample.

    Returns the resampled patches, or None when every crop fell entirely
    outside the image and `allow_end_early` is set.
    """
    dtype = torch.cuda.FloatTensor if 'cuda' in device else torch.FloatTensor

    # Homogeneous coordinates of the four corners of the unit square.
    corners = Variable(torch.FloatTensor([
        [-1.0, -1.0, 1.0, 1.0],
        [-1.0, 1.0, -1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0],
    ]).type_as(image.data), requires_grad=False)[None, ...]

    bounds = crop_window.matmul(corners)
    lo, _ = bounds.min(dim=-1)
    hi, _ = bounds.max(dim=-1)
    extents = hi - lo
    offsets = torch.floor(lo[:, :2].data).long()
    # One shared crop size: the largest extent over the whole batch.
    largest_extent = extents.max(dim=0)[0].max(dim=0)[0]
    crop_size = torch.ceil(largest_extent).long()
    if image.is_cuda:
        crop_size = crop_size.cuda()
    w = crop_size.item()

    memory_space = Variable(torch.zeros(extents.size(0), 3, w, w).type_as(image.data), requires_grad=False)
    translations = []
    N = transformation_utils.compute_renorm_matrix(memory_space)
    copied_any = False

    for b_i in range(memory_space.size(0)):
        o = offsets[b_i]

        # Shift so the patch's top-left corner maps to the buffer origin.
        shift = Variable(dtype([
            [1, 0, -o[0]],
            [0, 1, -o[1]],
            [0, 0, 1],
        ]), requires_grad=False).expand(3, 3)
        translations.append(N.mm(shift)[None, ...])

        # Source (image) and destination (buffer) windows, clamped to bounds.
        src_x, dst_x = (o[0], o[0] + w), (0, w)
        src_y, dst_y = (o[1], o[1] + w), (0, w)
        if o[0] < 0:
            src_x = (0, w + o[0])
            dst_x = (-o[0], w)
        if o[1] < 0:
            src_y = (0, w + o[1])
            dst_y = (-o[1], w)
        if o[0] + w >= image.size(2):
            src_x = (src_x[0], image.size(2))
            dst_x = (dst_x[0], image.size(2) - src_x[0])
        if o[1] + w >= image.size(3):
            # NOTE(review): this branch clamps with index [1] while the x-axis
            # branch uses [0]; kept byte-for-byte to preserve the original
            # behavior -- confirm whether [0] was intended.
            src_y = (src_y[1], image.size(3))
            dst_y = (dst_y[1], image.size(3) - src_y[1])

        empty_slice = (src_x[0] >= src_x[1] or dst_x[0] >= dst_x[1]
                       or src_y[0] >= src_y[1] or dst_y[0] >= dst_y[1])

        if not empty_slice:
            copied_any = True
            patch = image[b_i:b_i + 1, :, src_x[0]:src_x[1], src_y[0]:src_y[1]]
            memory_space[b_i:b_i + 1, :, dst_x[0]:dst_x[1], dst_y[0]:dst_y[1]] = patch

    if not copied_any and allow_end_early:
        return None

    translations = torch.cat(translations, 0)
    grid = grid_gen(translations.bmm(crop_window))
    # Perspective divide: homogeneous -> cartesian sampling coordinates.
    grid = grid[:, :, :, 0:2] / grid[:, :, :, 2:3]

    resampled = torch.nn.functional.grid_sample(memory_space.transpose(2, 3), grid, mode='bilinear',
                                                align_corners=True)

    return resampled
|
py3/lf/lf_cnn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
def convRelu(i, batchNormalization=False, leakyRelu=False):
    """Return conv layer `i` of the line-follower backbone as nn.Sequential.

    The sequential holds Conv2d, an optional InstanceNorm2d, and a ReLU /
    LeakyReLU. Module names ('conv{i}', 'batchnorm{i}', 'relu{i}') match the
    original so pretrained state_dicts load unchanged.
    """
    in_channels = 3
    kernels = [3, 3, 3, 3, 3, 3, 2]
    pads = [1, 1, 1, 1, 1, 1, 1]
    strides = [1, 1, 1, 1, 1, 1, 1]
    widths = [64, 128, 256, 256, 512, 512, 512]

    block = nn.Sequential()
    nIn = in_channels if i == 0 else widths[i - 1]
    nOut = widths[i]
    block.add_module('conv{0}'.format(i),
                     nn.Conv2d(nIn, nOut, kernels[i], strides[i], pads[i]))
    if batchNormalization:
        # InstanceNorm with track_running_stats=True (rather than BatchNorm)
        # so the author's pretrained state_dict keys/buffers line up.
        block.add_module('batchnorm{0}'.format(i), nn.InstanceNorm2d(nOut, track_running_stats=True))
    if leakyRelu:
        block.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
    else:
        block.add_module('relu{0}'.format(i), nn.ReLU(True))
    return block
|
| 28 |
+
|
| 29 |
+
def makeCnn():
    """Assemble the line-follower CNN: seven conv blocks and five max-pools.

    Module names ('convRelu{i}', 'pooling{j}') are preserved exactly so
    existing state_dicts remain loadable.
    """
    # (conv block index, use normalization, add 2x2 max-pool after)
    spec = [
        (0, False, True),
        (1, False, True),
        (2, True, False),
        (3, False, True),
        (4, True, False),
        (5, False, True),
        (6, True, True),
    ]
    cnn = nn.Sequential()
    pool_idx = 0
    for conv_idx, use_norm, pool_after in spec:
        cnn.add_module('convRelu{0}'.format(conv_idx), convRelu(conv_idx, use_norm))
        if pool_after:
            cnn.add_module('pooling{0}'.format(pool_idx), nn.MaxPool2d(2, 2))
            pool_idx += 1
    return cnn
|
py3/lf/line_follower.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
from .stn.gridgen import AffineGridGen, PerspectiveGridGen, GridGen
|
| 5 |
+
import numpy as np
|
| 6 |
+
from utils import transformation_utils
|
| 7 |
+
from .lf_cnn import makeCnn
|
| 8 |
+
from .fast_patch_view import get_patches
|
| 9 |
+
|
| 10 |
+
class LineFollower(nn.Module):
    """Recurrent text-line follower.

    Repeatedly crops a square view window from the page image, runs a small
    CNN on the crop, and regresses the transform to the next window, tracing
    a text line. `forward` returns the dewarping grid, (always-empty) debug
    images, the sequence of window transforms, and per-step xy positions.
    """

    def __init__(self, output_grid_size=32, dtype=torch.cuda.FloatTensor, device="cuda"):
        super(LineFollower, self).__init__()
        cnn = makeCnn()
        position_linear = nn.Linear(512, 5)
        # Start from an identity-like step: zero weights and zero x/y/angle bias
        # so early training does not jump off the line.
        position_linear.weight.data.zero_()
        position_linear.bias.data[0] = 0
        position_linear.bias.data[1] = 0
        position_linear.bias.data[2] = 0

        self.output_grid_size = output_grid_size

        self.dtype = dtype
        self.cnn = cnn
        self.position_linear = position_linear
        self.device = device

    def forward(self, image, positions, steps=None, all_positions=None, reset_interval=-1,
                randomize=False, negate_lw=False, skip_grid=False, allow_end_early=False):
        """Follow a line for `steps` iterations starting from `positions[-1]`.

        positions[-1] is either a (B,3,3) window transform or a start-point
        parameterization consumed by get_init_matrix — TODO confirm exact
        format against callers. `all_positions` supplies ground-truth resets
        every `reset_interval` steps during training; `negate_lw` follows the
        line right-to-left; `skip_grid` skips building the dewarping grid.
        """
        # BUG FIX: `all_positions` previously defaulted to a mutable `[]`,
        # which is shared across every call of forward(); use a None sentinel.
        if all_positions is None:
            all_positions = []

        batch_size = image.size(0)
        renorm_matrix = transformation_utils.compute_renorm_matrix(image)
        expanded_renorm_matrix = renorm_matrix.expand(batch_size, 3, 3)

        # Bilinear interpolation weights over an output_grid_size^2 lattice.
        t = ((np.arange(self.output_grid_size) + 0.5) / float(self.output_grid_size))[:, None].astype(np.float32)
        t = np.repeat(t, axis=1, repeats=self.output_grid_size)
        t = Variable(torch.from_numpy(t), requires_grad=False)
        t = t.to(self.device)
        s = t.t()

        t = t[:, :, None]
        s = s[:, :, None]

        interpolations = torch.cat([
            (1 - t) * s,
            (1 - t) * (1 - s),
            t * s,
            t * (1 - s),
        ], dim=-1)

        # Doubled, forward-shifted view around the current window.
        view_window = Variable(self.dtype([
            [2, 0, 2],
            [0, 2, 0],
            [0, 0, 1]
        ])).expand(batch_size, 3, 3)

        # Advance one window-width per step.
        step_bias = Variable(self.dtype([
            [1, 0, 2],
            [0, 1, 0],
            [0, 0, 1]
        ])).expand(batch_size, 3, 3)

        invert = Variable(self.dtype([
            [-1, 0, 0],
            [0, -1, 0],
            [0, 0, 1]
        ])).expand(batch_size, 3, 3)

        if negate_lw:
            view_window = invert.bmm(view_window)

        grid_gen = GridGen(32, 32, device=self.device)

        view_window_imgs = []
        next_windows = []
        reset_windows = True
        for i in range(steps):

            # With reset_interval == -1 we never reset; note i % -1 == 0 in
            # Python, hence the explicit == -1 test.
            if i % reset_interval != 0 or reset_interval == -1:
                p_0 = positions[-1]

                if i == 0 and len(p_0.size()) == 3 and p_0.size()[1] == 3 and p_0.size()[2] == 3:
                    # positions[-1] already is a full (B,3,3) window transform.
                    current_window = p_0
                    reset_windows = False
                    next_windows.append(p_0)

            else:
                # Training-time reset to the ground-truth position.
                p_0 = all_positions[i].type(self.dtype)
                reset_windows = True
                if randomize:
                    # Jitter x/y by +-2 px and angle by +-0.1 rad.
                    add_noise = p_0.clone()
                    add_noise.data.zero_()
                    mul_moise = p_0.clone()
                    mul_moise.data.fill_(1.0)

                    add_noise[:, 0].data.uniform_(-2, 2)
                    add_noise[:, 1].data.uniform_(-2, 2)
                    add_noise[:, 2].data.uniform_(-.1, .1)

                    p_0 = p_0 * mul_moise + add_noise

            if reset_windows:
                reset_windows = False

                current_window = transformation_utils.get_init_matrix(p_0)

                if len(next_windows) == 0:
                    next_windows.append(current_window)
            else:
                # Detach so gradients do not flow through the whole trajectory.
                current_window = next_windows[-1].detach()

            crop_window = current_window.bmm(view_window)

            resampled = get_patches(image, crop_window, grid_gen, allow_end_early, device=self.device)

            if resampled is None and i > 0:
                # get_patches checks whether stopping early is allowed.
                break

            if resampled is None and i == 0:
                # Odd case where it starts completely off of the edge.
                # This happens rarely, but maybe should be more elegantly
                # handled in the future.
                resampled = Variable(torch.zeros(crop_window.size(0), 3, 32, 32).type_as(image.data), requires_grad=False)

            # Process window with the CNN and predict the step transform.
            cnn_out = self.cnn(resampled)
            cnn_out = torch.squeeze(cnn_out, dim=2)
            cnn_out = torch.squeeze(cnn_out, dim=2)
            delta = self.position_linear(cnn_out)

            next_window = transformation_utils.get_step_matrix(delta)
            next_window = next_window.bmm(step_bias)
            if negate_lw:
                # Conjugate by the inversion so the step moves leftward.
                next_window = invert.bmm(next_window).bmm(invert)

            next_windows.append(current_window.bmm(next_window))

        grid_line = []
        xy_positions = []

        # Upper/lower edge midpoints of a window, homogeneous coordinates.
        a_pt = Variable(torch.Tensor(
            [
                [0, 1, 1],
                [0, -1, 1]
            ]
        )).to(self.device)
        a_pt = a_pt.transpose(1, 0)
        a_pt = a_pt.expand(batch_size, a_pt.size(0), a_pt.size(1))

        for i in range(0, len(next_windows) - 1):

            w_0 = next_windows[i]
            w_1 = next_windows[i + 1]

            pts_0 = w_0.bmm(a_pt)
            pts_1 = w_1.bmm(a_pt)
            xy_positions.append(pts_0)

            if skip_grid:
                continue

            pts = torch.cat([pts_0, pts_1], dim=2)

            grid_pts = expanded_renorm_matrix.bmm(pts)

            # Bilinear blend of the four corner points into a sampling grid.
            grid = interpolations[None, :, :, None, :] * grid_pts[:, None, None, :, :]
            grid = grid.sum(dim=-1)[..., :2]

            grid_line.append(grid)

        # NOTE(review): pts_1 here is the one from the last loop iteration;
        # with fewer than two windows this would raise NameError — TODO confirm
        # callers always take at least one step.
        xy_positions.append(pts_1)

        if skip_grid:
            grid_line = None
        else:
            grid_line = torch.cat(grid_line, dim=1)

        return grid_line, view_window_imgs, next_windows, xy_positions
|
py3/lf/models/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
#from utils1 import coerce_to_path_and_check_exist
|
| 5 |
+
|
| 6 |
+
from .res_unet import ResUNet
|
| 7 |
+
from .tools import safe_model_state_dict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_model(name=None):
    """Resolve a model name to its ResUNet factory (default 'res_unet18').

    Raises KeyError for unknown names.
    """
    factories = {
        'res_unet18': partial(ResUNet, encoder_name='resnet18'),
        'res_unet34': partial(ResUNet, encoder_name='resnet34'),
        'res_unet50': partial(ResUNet, encoder_name='resnet50'),
        'res_unet101': partial(ResUNet, encoder_name='resnet101'),
        'res_unet152': partial(ResUNet, encoder_name='resnet152'),
    }
    return factories['res_unet18' if name is None else name]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_model_from_path(model_path, device=None, attributes_to_return=None, eval_mode=True):
    """Load a ResUNet checkpoint from `model_path`.

    Args:
        model_path: path to a torch checkpoint containing 'model_name',
            'n_classes', 'model_kwargs' and 'model_state'.
        device: torch.device; defaults to cuda when available, else cpu.
        attributes_to_return: optional key or list of keys to also read from
            the checkpoint; when given, returns (model, [values...]).
        eval_mode: put the model in eval() mode after loading.

    Raises:
        FileNotFoundError / OSError from torch.load if the file is missing.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # BUG FIX: the original wrapped model_path in coerce_to_path_and_check_exist(),
    # whose import is commented out above, so every call raised NameError.
    # torch.load accepts the path directly and raises if it does not exist.
    checkpoint = torch.load(model_path, map_location=device.type)
    # The encoder weights are already in the checkpoint; never re-download.
    checkpoint['model_kwargs']['pretrained_encoder'] = False
    model = get_model(checkpoint['model_name'])(checkpoint['n_classes'], **checkpoint['model_kwargs']).to(device)
    model.load_state_dict(safe_model_state_dict(checkpoint['model_state']))
    if eval_mode:
        model.eval()
    if attributes_to_return is not None:
        if isinstance(attributes_to_return, str):
            attributes_to_return = [attributes_to_return]
        return model, [checkpoint.get(key) for key in attributes_to_return]
    else:
        return model
|
py3/lf/models/res_unet.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from .resnet import get_resnet_model
|
| 6 |
+
from .tools import conv1x1, conv3x3, DecoderModule, get_norm_layer, UpsampleCatConv
|
| 7 |
+
#from utils1.logger import print_info, print_warning
|
| 8 |
+
|
| 9 |
+
INPUT_CHANNELS = 3
|
| 10 |
+
FINAL_LAYER_CHANNELS = 32
|
| 11 |
+
LAYER1_REDUCED_CHANNELS = 128
|
| 12 |
+
LAYER2_REDUCED_CHANNELS = 256
|
| 13 |
+
LAYER3_REDUCED_CHANNELS = 512
|
| 14 |
+
LAYER4_REDUCED_CHANNELS = 1024
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ResUNet(nn.Module):
    """U-Net segmentation network with a residual (ResNet) encoder backbone.

    The encoder is a torchvision-style ResNet split into layer0..layer4;
    symmetric decoder stages upsample and fuse skip connections. In this
    variant layer4 is deliberately bypassed in forward() (see the disabled
    lines there), so the decoder starts from layer3's output.
    """

    @property
    def name(self):
        # e.g. 'resnet18' -> 'res_unet18'
        return self.enc_name.replace('res', 'res_u')

    def __init__(self, n_classes, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.norm_layer_kwargs = kwargs.pop('norm_layer', dict())
        self.norm_layer = get_norm_layer(**self.norm_layer_kwargs)
        self.no_maxpool = kwargs.get('no_maxpool', False)
        self.conv_as_maxpool = kwargs.get('conv_as_maxpool', True)
        self.use_upcatconv = kwargs.get('use_upcatconv', False)
        self.use_deconv = kwargs.get('use_deconv', True)
        assert not (self.use_deconv and self.use_upcatconv)
        self.same_up_channels = kwargs.get('same_up_channels', False)
        self.use_conv1x1 = kwargs.get('use_conv1x1', False)
        assert not (self.conv_as_maxpool and self.no_maxpool)
        self.enc_name = kwargs.get('encoder_name', 'resnet18')
        # BUG FIX: the original wrote ['resnet18, resnet34'] (one string), so
        # the membership test never matched and shallow encoders could be
        # "reduced" too.
        self.reduced_layers = kwargs.get('reduced_layers', False) and self.enc_name not in ['resnet18', 'resnet34']

        pretrained = kwargs.get('pretrained_encoder', False)
        replace_with_dilation = kwargs.get('replace_with_dilation')
        strides = kwargs.get('strides', 2)
        resnet = get_resnet_model(self.enc_name)(pretrained, progress=False, norm_layer=self.norm_layer_kwargs,
                                                 strides=strides, replace_with_dilation=replace_with_dilation)

        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        # XXX: maxpool creates high amplitude high freq activations, removing it leads to better results
        if self.conv_as_maxpool:
            layer0_out_channels = self.get_nb_out_channels(self.layer0)
            self.layer1 = nn.Sequential(*[conv3x3(layer0_out_channels, layer0_out_channels, stride=2),
                                          self.norm_layer(layer0_out_channels),
                                          nn.ReLU()] + list(resnet.layer1.children()))
        elif self.no_maxpool:
            self.layer1 = nn.Sequential(*list(resnet.layer1.children()))
        else:
            self.layer1 = nn.Sequential(*[resnet.maxpool] + list(resnet.layer1.children()))
        self.layer2, self.layer3, self.layer4 = resnet.layer2, resnet.layer3, resnet.layer4

        layer0_out_channels = self.get_nb_out_channels(self.layer0)
        layer1_out_channels = self.get_nb_out_channels(self.layer1)
        layer2_out_channels = self.get_nb_out_channels(self.layer2)
        layer3_out_channels = self.get_nb_out_channels(self.layer3)
        layer4_out_channels = self.get_nb_out_channels(self.layer4)
        if self.reduced_layers:
            # 1x1 bottlenecks that shrink skip-connection widths before fusion.
            self.layer1_red = self._reducing_layer(layer1_out_channels, LAYER1_REDUCED_CHANNELS)
            self.layer2_red = self._reducing_layer(layer2_out_channels, LAYER2_REDUCED_CHANNELS)
            self.layer3_red = self._reducing_layer(layer3_out_channels, LAYER3_REDUCED_CHANNELS)
            self.layer4_red = self._reducing_layer(layer4_out_channels, LAYER4_REDUCED_CHANNELS)
            layer1_out_channels, layer2_out_channels = LAYER1_REDUCED_CHANNELS, LAYER2_REDUCED_CHANNELS
            layer3_out_channels, layer4_out_channels = LAYER3_REDUCED_CHANNELS, LAYER4_REDUCED_CHANNELS

        # layer4_up is built (for state_dict compatibility) but not used in
        # forward(), where layer4 is bypassed.
        self.layer4_up = self._upsampling_layer(layer4_out_channels, layer3_out_channels, layer3_out_channels)
        self.layer3_up = self._upsampling_layer(layer3_out_channels, layer2_out_channels, layer2_out_channels)
        self.layer2_up = self._upsampling_layer(layer2_out_channels, layer1_out_channels, layer1_out_channels)
        self.layer1_up = self._upsampling_layer(layer1_out_channels, layer0_out_channels, layer0_out_channels)
        self.layer0_up = self._upsampling_layer(layer0_out_channels, FINAL_LAYER_CHANNELS, INPUT_CHANNELS)
        self.final_layer = self._final_layer(FINAL_LAYER_CHANNELS)

        if not pretrained:
            self._init_conv_weights()

        # NOTE(review): assumes self.norm_layer is a functools.partial
        # (it has .func/.keywords) — get_norm_layer appears to return one.
        print("Model {} initialisated with norm_layer={}({}) and kwargs {}"
              .format(self.name, self.norm_layer.func.__name__, self.norm_layer.keywords, kwargs))

    def _reducing_layer(self, in_channels, out_channels):
        """1x1 conv + norm + ReLU channel reducer."""
        return nn.Sequential(OrderedDict([
            ('conv', conv1x1(in_channels, out_channels)),
            ('bn', self.norm_layer(out_channels)),
            ('relu', nn.ReLU()),
        ]))

    def get_nb_out_channels(self, layer):
        """Out-channel count of the last Conv2d inside `layer`."""
        return list(filter(lambda e: isinstance(e, nn.Conv2d), layer.modules()))[-1].out_channels

    def _upsampling_layer(self, in_channels, out_channels, cat_channels):
        """Build one decoder stage (upsample + skip fusion)."""
        if self.use_upcatconv:
            return UpsampleCatConv(in_channels + cat_channels, out_channels, norm_layer=self.norm_layer,
                                   use_conv1x1=self.use_conv1x1)
        else:
            up_channels = in_channels if self.same_up_channels else None
            return DecoderModule(in_channels, out_channels, cat_channels, up_channels=up_channels,
                                 norm_layer=self.norm_layer, n_conv=1, use_deconv=self.use_deconv,
                                 use_conv1x1=self.use_conv1x1)

    def _final_layer(self, in_channels):
        """1x1 conv projecting to n_classes logits."""
        return nn.Sequential(OrderedDict([('conv', conv1x1(in_channels, self.n_classes))]))

    def _init_conv_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

    def load_state_dict_for_unet(self, state_dict):
        """Copy matching-shape parameters from `state_dict`; report the rest."""
        unloaded_params = []
        state = self.state_dict()
        for name, param in state_dict.items():
            if name in state and state[name].shape == param.shape:
                if isinstance(param, nn.Parameter):
                    param = param.data
                state[name].copy_(param)
            else:
                unloaded_params.append(name)

        if len(unloaded_params) > 0:
            print('load_state_dict: {} not found'.format(unloaded_params))

    def forward(self, x):
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        # layer4 (and layer4_up below) deliberately bypassed.
        # x4 = self.layer4(x3)

        if self.reduced_layers:
            # BUG FIX: the original also ran `x4 = self.layer4_red(x4)` here,
            # but x4 is never computed (layer4 is bypassed above), which
            # raised NameError whenever reduced_layers was active.
            x3 = self.layer3_red(x3)
            x2 = self.layer2_red(x2)
            x1 = self.layer1_red(x1)

        # x3 = self.layer4_up(x4, other=x3)
        x2 = self.layer3_up(x3, other=x2)
        x1 = self.layer2_up(x2, other=x1)
        x0 = self.layer1_up(x1, other=x0)
        x = self.layer0_up(x0, other=x)
        x = self.final_layer(x)

        return x
|
py3/lf/models/resnet.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from toolz import keyfilter
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
| 6 |
+
|
| 7 |
+
from .tools import conv3x3, conv1x1, get_norm_layer
|
| 8 |
+
|
| 9 |
+
# Download URLs for the official torchvision pretrained ResNet-family
# weights, keyed by architecture name (consumed by the resnetXX builders).
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_resnet_model(name):
    """Resolve a ResNet-family constructor by name (default 'resnet18').

    Raises KeyError for unknown names.
    """
    constructors = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        'resnext50_32x4d': resnext50_32x4d,
        'resnext101_32x8d': resnext101_32x8d,
        'wide_resnet50_2': wide_resnet50_2,
        'wide_resnet101_2': wide_resnet101_2,
    }
    return constructors['resnet18' if name is None else name]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs with a residual skip connection."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and the optional downsample branch) carry the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Skip path: identity, or projected/downsampled when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        return self.relu(out + shortcut)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x4 channels)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (and the optional downsample branch) carry the stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Skip path: identity, or projected/downsampled when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + shortcut)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ResNet(nn.Module):
|
| 122 |
+
|
| 123 |
+
def __init__(self, block, layers, n_classes=1000, zero_init_residual=False,
|
| 124 |
+
groups=1, width_per_group=64, strides=2, replace_with_dilation=None, **kwargs):
|
| 125 |
+
super(ResNet, self).__init__()
|
| 126 |
+
self.norm_layer_kwargs = kwargs.get('norm_layer', dict())
|
| 127 |
+
norm_layer = get_norm_layer(**self.norm_layer_kwargs)
|
| 128 |
+
self._norm_layer = norm_layer
|
| 129 |
+
self.inplanes = 64
|
| 130 |
+
self.groups = groups
|
| 131 |
+
self.base_width = width_per_group
|
| 132 |
+
self.dilation = 1
|
| 133 |
+
if replace_with_dilation is None:
|
| 134 |
+
# each element in the tuple indicates if we should replace
|
| 135 |
+
# the 2x2 stride with a dilated convolution instead
|
| 136 |
+
replace_with_dilation = [False, False, False]
|
| 137 |
+
elif isinstance(replace_with_dilation, bool):
|
| 138 |
+
replace_with_dilation = [replace_with_dilation] * 3
|
| 139 |
+
assert len(replace_with_dilation) == 3
|
| 140 |
+
self.strides = strides if not isinstance(strides, int) else [strides] * 5
|
| 141 |
+
assert len(self.strides) == 5
|
| 142 |
+
|
| 143 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, bias=False)
|
| 144 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 145 |
+
self.relu = nn.ReLU(inplace=True)
|
| 146 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
|
| 147 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 148 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2], dilate=replace_with_dilation[0])
|
| 149 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3], dilate=replace_with_dilation[1])
|
| 150 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4], dilate=replace_with_dilation[2])
|
| 151 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 152 |
+
self.fc = nn.Linear(512 * block.expansion, n_classes)
|
| 153 |
+
|
| 154 |
+
for m in self.modules():
|
| 155 |
+
if isinstance(m, nn.Conv2d):
|
| 156 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 157 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 158 |
+
nn.init.constant_(m.weight, 1)
|
| 159 |
+
nn.init.constant_(m.bias, 0)
|
| 160 |
+
|
| 161 |
+
# Zero-initialize the last BN in each residual branch,
|
| 162 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 163 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 164 |
+
if zero_init_residual:
|
| 165 |
+
for m in self.modules():
|
| 166 |
+
if isinstance(m, Bottleneck):
|
| 167 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 168 |
+
elif isinstance(m, BasicBlock):
|
| 169 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 170 |
+
|
| 171 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
| 172 |
+
norm_layer = self._norm_layer
|
| 173 |
+
downsample = None
|
| 174 |
+
previous_dilation = self.dilation
|
| 175 |
+
if dilate:
|
| 176 |
+
self.dilation *= stride
|
| 177 |
+
stride = 1
|
| 178 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 179 |
+
downsample = nn.Sequential(
|
| 180 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 181 |
+
norm_layer(planes * block.expansion),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
layers = []
|
| 185 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
| 186 |
+
self.base_width, previous_dilation, norm_layer))
|
| 187 |
+
self.inplanes = planes * block.expansion
|
| 188 |
+
for _ in range(1, blocks):
|
| 189 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
| 190 |
+
base_width=self.base_width, dilation=self.dilation,
|
| 191 |
+
norm_layer=norm_layer))
|
| 192 |
+
|
| 193 |
+
return nn.Sequential(*layers)
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
x = self.conv1(x)
|
| 197 |
+
x = self.bn1(x)
|
| 198 |
+
x = self.relu(x)
|
| 199 |
+
x = self.maxpool(x)
|
| 200 |
+
|
| 201 |
+
x = self.layer1(x)
|
| 202 |
+
x = self.layer2(x)
|
| 203 |
+
x = self.layer3(x)
|
| 204 |
+
x = self.layer4(x)
|
| 205 |
+
|
| 206 |
+
x = self.avgpool(x)
|
| 207 |
+
x = torch.flatten(x, 1)
|
| 208 |
+
x = self.fc(x)
|
| 209 |
+
|
| 210 |
+
return x
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
| 214 |
+
model = ResNet(block, layers, **kwargs)
|
| 215 |
+
if pretrained:
|
| 216 |
+
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
| 217 |
+
if not model.norm_layer_kwargs.get('track_running_stats', True):
|
| 218 |
+
state_dict = keyfilter(lambda k: 'running' not in k, state_dict)
|
| 219 |
+
model.load_state_dict(state_dict)
|
| 220 |
+
return model
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def resnet18(pretrained=False, progress=True, **kwargs):
|
| 224 |
+
r"""ResNet-18 model from
|
| 225 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 229 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 230 |
+
"""
|
| 231 |
+
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def resnet34(pretrained=False, progress=True, **kwargs):
|
| 235 |
+
r"""ResNet-34 model from
|
| 236 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 240 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 241 |
+
"""
|
| 242 |
+
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def resnet50(pretrained=False, progress=True, **kwargs):
|
| 246 |
+
r"""ResNet-50 model from
|
| 247 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 251 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 252 |
+
"""
|
| 253 |
+
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def resnet101(pretrained=False, progress=True, **kwargs):
|
| 257 |
+
r"""ResNet-101 model from
|
| 258 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 262 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 263 |
+
"""
|
| 264 |
+
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def resnet152(pretrained=False, progress=True, **kwargs):
|
| 268 |
+
r"""ResNet-152 model from
|
| 269 |
+
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>'_
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 273 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 274 |
+
"""
|
| 275 |
+
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
|
| 279 |
+
r"""ResNeXt-50 32x4d model from
|
| 280 |
+
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 284 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 285 |
+
"""
|
| 286 |
+
kwargs['groups'] = 32
|
| 287 |
+
kwargs['width_per_group'] = 4
|
| 288 |
+
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
|
| 292 |
+
r"""ResNeXt-101 32x8d model from
|
| 293 |
+
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
| 294 |
+
|
| 295 |
+
Args:
|
| 296 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 297 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 298 |
+
"""
|
| 299 |
+
kwargs['groups'] = 32
|
| 300 |
+
kwargs['width_per_group'] = 8
|
| 301 |
+
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
|
| 305 |
+
r"""Wide ResNet-50-2 model from
|
| 306 |
+
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
| 307 |
+
|
| 308 |
+
The model is the same as ResNet except for the bottleneck number of channels
|
| 309 |
+
which is twice larger in every block. The number of channels in outer 1x1
|
| 310 |
+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
| 311 |
+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 315 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 316 |
+
"""
|
| 317 |
+
kwargs['width_per_group'] = 64 * 2
|
| 318 |
+
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
|
| 322 |
+
r"""Wide ResNet-101-2 model from
|
| 323 |
+
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
| 324 |
+
|
| 325 |
+
The model is the same as ResNet except for the bottleneck number of channels
|
| 326 |
+
which is twice larger in every block. The number of channels in outer 1x1
|
| 327 |
+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
| 328 |
+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 332 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 333 |
+
"""
|
| 334 |
+
kwargs['width_per_group'] = 64 * 2
|
| 335 |
+
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
|
py3/lf/models/tools.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
from torch import nn, cat
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_norm_layer(**kwargs):
|
| 9 |
+
name = kwargs.get('name', 'instance_norm')
|
| 10 |
+
momentum = kwargs.get('momentum', 0.1)
|
| 11 |
+
affine = kwargs.get('affine', True)
|
| 12 |
+
track_stats = kwargs.get('track_running_stats', False)
|
| 13 |
+
num_groups = kwargs.get('num_groups', 32)
|
| 14 |
+
|
| 15 |
+
norm_layer = {
|
| 16 |
+
'batch_norm': partial(nn.BatchNorm2d, momentum=momentum, affine=affine, track_running_stats=track_stats),
|
| 17 |
+
'group_norm': partial(nn.GroupNorm, num_groups=num_groups, affine=affine),
|
| 18 |
+
'instance_norm': partial(nn.InstanceNorm2d, momentum=momentum, affine=affine, track_running_stats=track_stats),
|
| 19 |
+
}[name]
|
| 20 |
+
if norm_layer.func == nn.GroupNorm:
|
| 21 |
+
return lambda num_channels: norm_layer(num_channels=num_channels)
|
| 22 |
+
else:
|
| 23 |
+
return norm_layer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def initialize_weights(*models):
|
| 27 |
+
for model in models:
|
| 28 |
+
for module in model.modules():
|
| 29 |
+
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
|
| 30 |
+
nn.init.kaiming_normal_(module.weight)
|
| 31 |
+
if module.bias is not None:
|
| 32 |
+
module.bias.data.zero_()
|
| 33 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 34 |
+
module.weight.data.fill_(1)
|
| 35 |
+
module.bias.data.zero_()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def count_parameters(model):
|
| 39 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def safe_model_state_dict(state_dict):
|
| 43 |
+
"""Convert a state dict saved from a DataParallel module to normal module state_dict."""
|
| 44 |
+
if not next(iter(state_dict)).startswith("module."):
|
| 45 |
+
return state_dict # abort if dict is not a DataParallel model_state
|
| 46 |
+
new_state_dict = OrderedDict()
|
| 47 |
+
for k, v in state_dict.items():
|
| 48 |
+
new_state_dict[k[7:]] = v # remove 'module.' prefix
|
| 49 |
+
return new_state_dict
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 53 |
+
"""3x3 convolution with padding"""
|
| 54 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 55 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 59 |
+
"""1x1 convolution"""
|
| 60 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class UpsampleCatConv(nn.Module):
|
| 64 |
+
def __init__(self, in_channels, out_channels, norm_layer=None, mode='bilinear', use_conv1x1=False):
|
| 65 |
+
super().__init__()
|
| 66 |
+
norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
|
| 67 |
+
conv_layer = conv1x1 if use_conv1x1 else conv3x3
|
| 68 |
+
self.mode = mode
|
| 69 |
+
self.conv = conv_layer(in_channels, out_channels)
|
| 70 |
+
self.norm = norm_layer(out_channels)
|
| 71 |
+
self.act = nn.ReLU()
|
| 72 |
+
|
| 73 |
+
def forward(self, x, other):
|
| 74 |
+
x = nn.functional.interpolate(x, size=(other.size(2), other.size(3)), mode=self.mode, align_corners=False)
|
| 75 |
+
x = cat((x, other), dim=1)
|
| 76 |
+
x = self.conv(x)
|
| 77 |
+
x = self.norm(x)
|
| 78 |
+
x = self.act(x)
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class UpsampleConv(nn.Module):
|
| 83 |
+
def __init__(self, in_channels, out_channels, norm_layer=None, mode='bilinear', use_conv1x1=False):
|
| 84 |
+
super().__init__()
|
| 85 |
+
norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
|
| 86 |
+
conv_layer = conv1x1 if use_conv1x1 else conv3x3
|
| 87 |
+
self.mode = mode
|
| 88 |
+
self.conv = conv_layer(in_channels, out_channels)
|
| 89 |
+
self.norm = norm_layer(out_channels)
|
| 90 |
+
self.act = nn.ReLU()
|
| 91 |
+
|
| 92 |
+
def forward(self, x, output_size):
|
| 93 |
+
x = nn.functional.interpolate(x, size=output_size[2:], mode=self.mode, align_corners=False)
|
| 94 |
+
x = self.conv(x)
|
| 95 |
+
x = self.norm(x)
|
| 96 |
+
x = self.act(x)
|
| 97 |
+
return x
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class DeconvModule(nn.Module):
|
| 101 |
+
def __init__(self, in_channels, out_channels, norm_layer=None):
|
| 102 |
+
super().__init__()
|
| 103 |
+
norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
|
| 104 |
+
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
| 105 |
+
self.norm = norm_layer(out_channels)
|
| 106 |
+
self.act = nn.ReLU()
|
| 107 |
+
|
| 108 |
+
def forward(self, x, output_size):
|
| 109 |
+
x = self.deconv(x, output_size=output_size)
|
| 110 |
+
x = self.norm(x)
|
| 111 |
+
x = self.act(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class DecoderModule(nn.Module):
|
| 116 |
+
def __init__(self, in_channels, out_channels, cat_channels=None, up_channels=None,
|
| 117 |
+
norm_layer=None, n_conv=2, use_deconv=False, use_conv1x1=False):
|
| 118 |
+
super().__init__()
|
| 119 |
+
cat_channels = cat_channels or in_channels // 2
|
| 120 |
+
up_channels = up_channels or in_channels // 2
|
| 121 |
+
norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
|
| 122 |
+
self.use_deconv = use_deconv
|
| 123 |
+
if use_deconv:
|
| 124 |
+
self.decode = DeconvModule(in_channels, up_channels, norm_layer)
|
| 125 |
+
else:
|
| 126 |
+
self.decode = UpsampleConv(in_channels, up_channels, norm_layer, 'bilinear', use_conv1x1)
|
| 127 |
+
self.conv_block = nn.Sequential(OrderedDict(sum([[
|
| 128 |
+
('conv{}'.format(k + 1), conv3x3(up_channels + cat_channels if k == 0 else out_channels, out_channels)),
|
| 129 |
+
('bn{}'.format(k + 1), norm_layer(out_channels)),
|
| 130 |
+
('relu{}'.format(k + 1), nn.ReLU())]
|
| 131 |
+
for k in range(n_conv)], [])))
|
| 132 |
+
|
| 133 |
+
def forward(self, x, other):
|
| 134 |
+
try:
|
| 135 |
+
x = self.decode(x, output_size=other.size())
|
| 136 |
+
except ValueError:
|
| 137 |
+
# XXX a size adjustement is needed for odd sizes
|
| 138 |
+
B, C, H, W = other.size()
|
| 139 |
+
h, w = H // 2 * 2, W // 2 * 2
|
| 140 |
+
x = self.decode(x, output_size=(B, C, h, w))
|
| 141 |
+
x = F.pad(x, (W - w, 0, H - h, 0))
|
| 142 |
+
x = cat((x, other), dim=1)
|
| 143 |
+
x = self.conv_block(x)
|
| 144 |
+
return x
|
py3/lf/stn/__init__.py
ADDED
|
File without changes
|
py3/lf/stn/gridgen.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.autograd import Function
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
from torch.nn.modules.module import Module
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AffineGridGenFunction(Function):
|
| 10 |
+
def __init__(self, height, width):
|
| 11 |
+
super(AffineGridGenFunction, self).__init__()
|
| 12 |
+
self.height, self.width = height, width
|
| 13 |
+
self.grid = np.zeros( [self.height, self.width, 3], dtype=np.float32)
|
| 14 |
+
self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.height), 0), repeats = self.width, axis = 0).T, 0)
|
| 15 |
+
self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.width), 0), repeats = self.height, axis = 0), 0)
|
| 16 |
+
self.grid[:,:,2] = np.ones([self.height, width])
|
| 17 |
+
self.grid = torch.from_numpy(self.grid.astype(np.float32))
|
| 18 |
+
#print(self.grid)
|
| 19 |
+
|
| 20 |
+
def forward(self, input1):
|
| 21 |
+
self.input1 = input1
|
| 22 |
+
output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
|
| 23 |
+
self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
|
| 24 |
+
for i in range(input1.size(0)):
|
| 25 |
+
self.batchgrid[i] = self.grid
|
| 26 |
+
|
| 27 |
+
if input1.is_cuda:
|
| 28 |
+
self.batchgrid = self.batchgrid.cuda()
|
| 29 |
+
output = output.cuda()
|
| 30 |
+
|
| 31 |
+
for i in range(input1.size(0)):
|
| 32 |
+
output = torch.bmm(self.batchgrid.view(-1, self.height*self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 2)
|
| 33 |
+
|
| 34 |
+
return output
|
| 35 |
+
|
| 36 |
+
def backward(self, grad_output):
|
| 37 |
+
|
| 38 |
+
grad_input1 = torch.zeros(self.input1.size())
|
| 39 |
+
|
| 40 |
+
if grad_output.is_cuda:
|
| 41 |
+
self.batchgrid = self.batchgrid.cuda()
|
| 42 |
+
grad_input1 = grad_input1.cuda()
|
| 43 |
+
grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
|
| 44 |
+
|
| 45 |
+
return grad_input1
|
| 46 |
+
|
| 47 |
+
class AffineGridGen(Module):
|
| 48 |
+
def __init__(self, height, width):
|
| 49 |
+
super(AffineGridGen, self).__init__()
|
| 50 |
+
self.height, self.width = height, width
|
| 51 |
+
self.f = AffineGridGenFunction(self.height, self.width)
|
| 52 |
+
|
| 53 |
+
def forward(self, input):
|
| 54 |
+
return self.f(input)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PerspectiveGridGenFunction(Function):
|
| 58 |
+
def __init__(self, height, width):
|
| 59 |
+
super(PerspectiveGridGenFunction, self).__init__()
|
| 60 |
+
self.height, self.width = height, width
|
| 61 |
+
self.grid = np.zeros( [self.height, self.width, 3], dtype=np.float32)
|
| 62 |
+
self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.height), 0), repeats = self.width, axis = 0).T, 0)
|
| 63 |
+
self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.width), 0), repeats = self.height, axis = 0), 0)
|
| 64 |
+
self.grid[:,:,2] = np.ones([self.height, width])
|
| 65 |
+
self.grid = torch.from_numpy(self.grid.astype(np.float32))
|
| 66 |
+
|
| 67 |
+
def forward(self, input1):
|
| 68 |
+
self.input1 = input1
|
| 69 |
+
output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
|
| 70 |
+
self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
|
| 71 |
+
for i in range(input1.size(0)):
|
| 72 |
+
self.batchgrid[i] = self.grid
|
| 73 |
+
|
| 74 |
+
if input1.is_cuda:
|
| 75 |
+
self.batchgrid = self.batchgrid.cuda()
|
| 76 |
+
output = output.cuda()
|
| 77 |
+
|
| 78 |
+
for i in range(input1.size(0)):
|
| 79 |
+
output = torch.bmm(self.batchgrid.view(-1, self.height*self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 3)
|
| 80 |
+
|
| 81 |
+
return output
|
| 82 |
+
|
| 83 |
+
def backward(self, grad_output):
|
| 84 |
+
|
| 85 |
+
grad_input1 = torch.zeros(self.input1.size())
|
| 86 |
+
|
| 87 |
+
if grad_output.is_cuda:
|
| 88 |
+
self.batchgrid = self.batchgrid.cuda()
|
| 89 |
+
grad_input1 = grad_input1.cuda()
|
| 90 |
+
grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 3), 1,2), self.batchgrid.view(-1, self.height*self.width, 3))
|
| 91 |
+
|
| 92 |
+
return grad_input1
|
| 93 |
+
|
| 94 |
+
class PerspectiveGridGen(Module):
|
| 95 |
+
def __init__(self, height, width):
|
| 96 |
+
super(PerspectiveGridGen, self).__init__()
|
| 97 |
+
self.height, self.width = height, width
|
| 98 |
+
self.f = PerspectiveGridGenFunction(self.height, self.width)
|
| 99 |
+
|
| 100 |
+
def forward(self, input):
|
| 101 |
+
return self.f(input)
|
| 102 |
+
|
| 103 |
+
class GridGen(Module):
|
| 104 |
+
def __init__(self, height, width, device="cuda"):
|
| 105 |
+
super(GridGen, self).__init__()
|
| 106 |
+
self.device = device
|
| 107 |
+
self.height, self.width = height, width
|
| 108 |
+
self.grid = np.zeros( [self.height, self.width, 3], dtype=np.float32)
|
| 109 |
+
|
| 110 |
+
grid_space_h = (np.arange(self.height) + 0.5) / float(self.height)
|
| 111 |
+
grid_space_w = (np.arange(self.width) + 0.5) / float(self.width)
|
| 112 |
+
|
| 113 |
+
grid_space_h = 2 * grid_space_h - 1
|
| 114 |
+
grid_space_w = 2 * grid_space_w - 1
|
| 115 |
+
|
| 116 |
+
self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(grid_space_h, 0), repeats = self.width, axis = 0).T, 0)
|
| 117 |
+
self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(grid_space_w, 0), repeats = self.height, axis = 0), 0)
|
| 118 |
+
# self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.height), 0), repeats = self.width, axis = 0).T, 0)
|
| 119 |
+
# self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.width), 0), repeats = self.height, axis = 0), 0)
|
| 120 |
+
self.grid[:,:,2] = np.ones([self.height, width])
|
| 121 |
+
self.grid = Variable(torch.from_numpy(self.grid.astype(np.float32)), requires_grad=False).to(self.device)
|
| 122 |
+
|
| 123 |
+
def forward(self, input):
|
| 124 |
+
out = torch.matmul(input[:,None,None,:,:], self.grid[None,:,:,:,None])
|
| 125 |
+
out = out.squeeze(-1)
|
| 126 |
+
return out
|
py3/sol/__init__.py
ADDED
|
File without changes
|
py3/sol/crop_transform.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from sol import crop_utils
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
class CropTransform(object):
|
| 6 |
+
def __init__(self, crop_params):
|
| 7 |
+
crop_size = crop_params['crop_size']
|
| 8 |
+
self.random_crop_params = crop_params
|
| 9 |
+
self.pad_params = ((crop_size,crop_size),(crop_size,crop_size),(0,0))
|
| 10 |
+
|
| 11 |
+
def __call__(self, sample):
|
| 12 |
+
org_img = sample['img']
|
| 13 |
+
gt = sample['sol_gt']
|
| 14 |
+
|
| 15 |
+
org_img = np.pad(org_img, self.pad_params, 'mean')
|
| 16 |
+
|
| 17 |
+
gt[:,:,0] = gt[:,:,0] + self.pad_params[0][0]
|
| 18 |
+
gt[:,:,1] = gt[:,:,1] + self.pad_params[1][0]
|
| 19 |
+
|
| 20 |
+
gt[:,:,2] = gt[:,:,2] + self.pad_params[0][0]
|
| 21 |
+
gt[:,:,3] = gt[:,:,3] + self.pad_params[1][0]
|
| 22 |
+
|
| 23 |
+
crop_params, org_img, gt_match = crop_utils.generate_random_crop(org_img, gt, self.random_crop_params)
|
| 24 |
+
|
| 25 |
+
gt = gt[gt_match][None,...]
|
| 26 |
+
gt[...,0] = gt[...,0] - crop_params['dim1'][0]
|
| 27 |
+
gt[...,1] = gt[...,1] - crop_params['dim0'][0]
|
| 28 |
+
|
| 29 |
+
gt[...,2] = gt[...,2] - crop_params['dim1'][0]
|
| 30 |
+
gt[...,3] = gt[...,3] - crop_params['dim0'][0]
|
| 31 |
+
|
| 32 |
+
return {
|
| 33 |
+
"img": org_img,
|
| 34 |
+
"sol_gt": gt
|
| 35 |
+
}
|
py3/sol/crop_utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
def perform_crop(img, crop):
|
| 6 |
+
cs = crop['crop_size']
|
| 7 |
+
cropped_gt_img = img[crop['dim0'][0]:crop['dim0'][1], crop['dim1'][0]:crop['dim1'][1]]
|
| 8 |
+
scaled_gt_img = cv2.resize(cropped_gt_img, (cs, cs), interpolation = cv2.INTER_CUBIC)
|
| 9 |
+
return scaled_gt_img
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def generate_random_crop(img, gt, params):
|
| 13 |
+
|
| 14 |
+
contains_label = np.random.random() < params['prob_label']
|
| 15 |
+
cs = params['crop_size']
|
| 16 |
+
|
| 17 |
+
cnt = 0
|
| 18 |
+
while True:
|
| 19 |
+
|
| 20 |
+
dim0 = np.random.randint(0,img.shape[0]-cs)
|
| 21 |
+
dim1 = np.random.randint(0,img.shape[1]-cs)
|
| 22 |
+
|
| 23 |
+
crop = {
|
| 24 |
+
"dim0": [dim0, dim0+cs],
|
| 25 |
+
"dim1": [dim1, dim1+cs],
|
| 26 |
+
"crop_size": cs
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
#TODO: this only works for the center points
|
| 30 |
+
gt_match = np.zeros_like(gt[...,0:2])
|
| 31 |
+
gt_match[...,0][gt[...,0] < dim1] = 1
|
| 32 |
+
gt_match[...,0][gt[...,0] > dim1+cs] = 1
|
| 33 |
+
|
| 34 |
+
gt_match[...,1][gt[...,1] < dim0] = 1
|
| 35 |
+
gt_match[...,1][gt[...,1] > dim0+cs] = 1
|
| 36 |
+
|
| 37 |
+
gt_match = 1-gt_match
|
| 38 |
+
gt_match = np.logical_and(gt_match[...,0], gt_match[...,1])
|
| 39 |
+
|
| 40 |
+
if gt_match.sum() > 0 and contains_label or cnt > 100:
|
| 41 |
+
cropped_gt_img = perform_crop(img, crop)
|
| 42 |
+
return crop, cropped_gt_img, np.where(gt_match != 0)
|
| 43 |
+
|
| 44 |
+
if gt_match.sum() == 0 and not contains_label:
|
| 45 |
+
cropped_gt_img = perform_crop(img, crop)
|
| 46 |
+
return crop, cropped_gt_img, np.where(gt_match != 0)
|
| 47 |
+
|
| 48 |
+
cnt += 1
|
py3/sol/start_of_line_finder.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.autograd import Variable
|
| 3 |
+
from torch import nn
|
| 4 |
+
from . import vgg
|
| 5 |
+
|
| 6 |
+
class StartOfLineFinder(nn.Module):
|
| 7 |
+
def __init__(self, base_0, base_1):
|
| 8 |
+
super(StartOfLineFinder, self).__init__()
|
| 9 |
+
|
| 10 |
+
self.cnn = vgg.vgg11()
|
| 11 |
+
self.base_0 = base_0
|
| 12 |
+
self.base_1 = base_1
|
| 13 |
+
|
| 14 |
+
def forward(self, img):
|
| 15 |
+
y = self.cnn(img)
|
| 16 |
+
#print('sol input is image of size', img.size())
|
| 17 |
+
#print('sol forward is output of size', y.size())
|
| 18 |
+
|
| 19 |
+
priors_0 = Variable(torch.arange(0,y.size(2)).type_as(img.data), requires_grad=False)[None,:,None]
|
| 20 |
+
priors_0 = (priors_0 + 0.5) * self.base_0
|
| 21 |
+
priors_0 = priors_0.expand(y.size(0), priors_0.size(1), y.size(3))
|
| 22 |
+
priors_0 = priors_0[:,None,:,:]
|
| 23 |
+
|
| 24 |
+
priors_1 = Variable(torch.arange(0,y.size(3)).type_as(img.data), requires_grad=False)[None,None,:]
|
| 25 |
+
priors_1 = (priors_1 + 0.5) * self.base_1
|
| 26 |
+
priors_1 = priors_1.expand(y.size(0), y.size(2), priors_1.size(2))
|
| 27 |
+
priors_1 = priors_1[:,None,:,:]
|
| 28 |
+
|
| 29 |
+
predictions = torch.cat([
|
| 30 |
+
torch.sigmoid(y[:,0:1,:,:]),
|
| 31 |
+
y[:,1:2,:,:] + priors_0,
|
| 32 |
+
y[:,2:3,:,:] + priors_1,
|
| 33 |
+
y[:,3:4,:,:],
|
| 34 |
+
y[:,4:5,:,:]
|
| 35 |
+
], dim=1)
|
| 36 |
+
|
| 37 |
+
predictions = predictions.transpose(1,3).contiguous()
|
| 38 |
+
predictions = predictions.view(predictions.size(0),-1,5)
|
| 39 |
+
|
| 40 |
+
#print('priors_0', priors_0)
|
| 41 |
+
#print('sol final prediction is size', predictions.size())
|
| 42 |
+
return predictions
|
py3/sol/vgg.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
|
| 8 |
+
'vgg19_bn', 'vgg19',
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VGG(nn.Module):
    """Truncated VGG: the convolutional ``features`` stack only.

    Unlike the reference torchvision VGG there is no classifier head, so
    the forward pass returns convolutional feature maps rather than class
    scores.
    """

    def __init__(self, features, num_classes=1000):
        # ``num_classes`` is kept for signature compatibility with the
        # reference implementation but is unused (no classifier head).
        super(VGG, self).__init__()
        self.features = features

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return its output."""
        return self.features(x)

    def _initialize_weights(self):
        # NOTE(review): never called inside this class — callers must invoke
        # it explicitly if they want this (He-style) initialization.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def make_layers(cfg, batch_norm=False):
    """Assemble a VGG-style convolutional stack from a config list.

    Each entry of ``cfg`` is either ``'M'`` (a 2x2 max-pool) or an int (a
    3x3 same-padding conv with that many output channels).  The final
    entry produces a bare conv with no activation and no batch norm, so it
    can act as the network's output layer.
    """
    layers = []
    channels = 3  # network input is assumed to have 3 channels
    last_pos = len(cfg) - 1
    for pos, spec in enumerate(cfg):
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv = nn.Conv2d(channels, spec, kernel_size=3, padding=1)
        if pos == last_pos:
            # Output layer: no nonlinearity, no normalization.
            layers.append(conv)
            break
        if batch_norm:
            layers.extend([conv, nn.BatchNorm2d(spec), nn.ReLU(inplace=True)])
        else:
            layers.extend([conv, nn.ReLU(inplace=True)])
        channels = spec
    return nn.Sequential(*layers)

# Number of output channels of the final conv layer in every configuration.
OUTPUT_FEATURES = 5
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, OUTPUT_FEATURES],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, OUTPUT_FEATURES],
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _no_pretrained(arch):
    """Raise a clear error for unsupported pretrained-weight requests.

    BUG FIX: the original code referenced undefined names (``model_zoo``,
    ``model_urls``) and crashed with NameError when ``pretrained=True``.
    This VGG variant's final layer differs from the torchvision reference
    (it emits OUTPUT_FEATURES channels), so ImageNet weights could not be
    loaded into it anyway; fail with an explicit message instead.
    """
    raise ValueError(
        "pretrained=True is not supported for %s: this VGG variant has a "
        "custom final layer and no pretrained weights are available" % arch)


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A").

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg11')
    return VGG(make_layers(cfg['A']), **kwargs)


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization.

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg11_bn')
    return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)


def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B").

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg13')
    return VGG(make_layers(cfg['B']), **kwargs)


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization.

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg13_bn')
    return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D").

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg16')
    return VGG(make_layers(cfg['D']), **kwargs)


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization.

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg16_bn')
    return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)


def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E").

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg19')
    return VGG(make_layers(cfg['E']), **kwargs)


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E") with batch normalization.

    Args:
        pretrained (bool): must be False; pretrained weights are unavailable.
    """
    if pretrained:
        _no_pretrained('vgg19_bn')
    return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
|
py3/utils/__init__.py
ADDED
|
File without changes
|
py3/utils/character_set.ipynb
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "a0f742b0",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/home/msaeed3/mehreen/source/start_follow_read/py3\n",
|
| 14 |
+
"....c ب\n",
|
| 15 |
+
"....c أ\n",
|
| 16 |
+
"....c ن\n",
|
| 17 |
+
"....c \n",
|
| 18 |
+
"....c ا\n",
|
| 19 |
+
"....c ل\n",
|
| 20 |
+
"....c ت\n",
|
| 21 |
+
"....c ع\n",
|
| 22 |
+
"....c و\n",
|
| 23 |
+
"....c ي\n",
|
| 24 |
+
"....c ٢\n",
|
| 25 |
+
"....c س\n",
|
| 26 |
+
"....c ط\n",
|
| 27 |
+
"....c ج\n",
|
| 28 |
+
"....c ه\n",
|
| 29 |
+
"....c ز\n",
|
| 30 |
+
"....c ف\n",
|
| 31 |
+
"....c ذ\n",
|
| 32 |
+
"....c ٳ\n",
|
| 33 |
+
"....c ك\n",
|
| 34 |
+
"....c ٪\n",
|
| 35 |
+
"....c ٜ\n",
|
| 36 |
+
"....c ١\n",
|
| 37 |
+
"....c د\n",
|
| 38 |
+
"....c ة\n",
|
| 39 |
+
"....c م\n",
|
| 40 |
+
"....c ٬\n",
|
| 41 |
+
"....c ق\n",
|
| 42 |
+
"....c ر\n",
|
| 43 |
+
"....c ش\n",
|
| 44 |
+
"....c ٚ\n",
|
| 45 |
+
"....c ح\n",
|
| 46 |
+
"....c -\n",
|
| 47 |
+
"....c !\n",
|
| 48 |
+
"....c ص\n",
|
| 49 |
+
"....c ض\n",
|
| 50 |
+
"....c ٝ\n",
|
| 51 |
+
"....c ث\n",
|
| 52 |
+
"....c 7\n",
|
| 53 |
+
"....c ٙ\n",
|
| 54 |
+
"....c ٨\n",
|
| 55 |
+
"....c ک\n",
|
| 56 |
+
"....c ٮ\n",
|
| 57 |
+
"....c ڤ\n",
|
| 58 |
+
"....c ٤\n",
|
| 59 |
+
"....c ں\n",
|
| 60 |
+
"....c 4\n",
|
| 61 |
+
"....c ٰ\n",
|
| 62 |
+
"....c ]\n",
|
| 63 |
+
"....c ڡ\n",
|
| 64 |
+
"....c ى\n",
|
| 65 |
+
"....c \\\n",
|
| 66 |
+
"....c 2\n",
|
| 67 |
+
"....c ٛ\n",
|
| 68 |
+
"....c =\n",
|
| 69 |
+
"....c إ\n",
|
| 70 |
+
"....c غ\n",
|
| 71 |
+
"....c ٲ\n",
|
| 72 |
+
"....c ّ\n",
|
| 73 |
+
"....c >\n",
|
| 74 |
+
"....c .\n",
|
| 75 |
+
"....c )\n",
|
| 76 |
+
"....c <\n",
|
| 77 |
+
"....c ئ\n",
|
| 78 |
+
"....c |\n",
|
| 79 |
+
"....c 0\n",
|
| 80 |
+
"....c +\n",
|
| 81 |
+
"....c x\n",
|
| 82 |
+
"....c ؟\n",
|
| 83 |
+
"....c خ\n",
|
| 84 |
+
"....c }\n",
|
| 85 |
+
"....c &\n",
|
| 86 |
+
"....c %\n",
|
| 87 |
+
"....c ،\n",
|
| 88 |
+
"....c @\n",
|
| 89 |
+
"....c $\n",
|
| 90 |
+
"....c ء\n",
|
| 91 |
+
"....c 5\n",
|
| 92 |
+
"....c 8\n",
|
| 93 |
+
"....c _\n",
|
| 94 |
+
"....c ٌ\n",
|
| 95 |
+
"....c ×\n",
|
| 96 |
+
"....c ^\n",
|
| 97 |
+
"....c ٍ\n",
|
| 98 |
+
"....c `\n",
|
| 99 |
+
"....c [\n",
|
| 100 |
+
"....c آ\n",
|
| 101 |
+
"....c َ\n",
|
| 102 |
+
"....c ;\n",
|
| 103 |
+
"....c ً\n",
|
| 104 |
+
"....c ُ\n",
|
| 105 |
+
"....c /\n",
|
| 106 |
+
"....c ٕ\n",
|
| 107 |
+
"....c ~\n",
|
| 108 |
+
"....c \"\n",
|
| 109 |
+
"....c ٖ\n",
|
| 110 |
+
"....c ظ\n",
|
| 111 |
+
"....c 3\n",
|
| 112 |
+
"....c :\n",
|
| 113 |
+
"....c ۟\n",
|
| 114 |
+
"....c ٥\n",
|
| 115 |
+
"....c چ\n",
|
| 116 |
+
"....c ٣\n",
|
| 117 |
+
"....c ,\n",
|
| 118 |
+
"....c ٧\n",
|
| 119 |
+
"....c ﮐ\n",
|
| 120 |
+
"....c {\n",
|
| 121 |
+
"....c 9\n",
|
| 122 |
+
"....c ?\n",
|
| 123 |
+
"....c '\n",
|
| 124 |
+
"....c ْ\n",
|
| 125 |
+
"....c *\n",
|
| 126 |
+
"....c ـ\n",
|
| 127 |
+
"....c ٔ\n",
|
| 128 |
+
"....c #\n",
|
| 129 |
+
"....c ٓ\n",
|
| 130 |
+
"....c ِ\n",
|
| 131 |
+
"....c 1\n",
|
| 132 |
+
"....c 6\n",
|
| 133 |
+
"....c ‘\n",
|
| 134 |
+
"....c (\n",
|
| 135 |
+
"....c ٠\n",
|
| 136 |
+
"....c ٞ\n",
|
| 137 |
+
"....c ٯ\n",
|
| 138 |
+
"....c ؤ\n",
|
| 139 |
+
"....c ٘\n",
|
| 140 |
+
"....c ٟ\n",
|
| 141 |
+
"....c ٴ\n",
|
| 142 |
+
"....c ݘ\n",
|
| 143 |
+
"....c ٫\n",
|
| 144 |
+
"....c ی\n",
|
| 145 |
+
"....c ٦\n",
|
| 146 |
+
"....c ٩\n",
|
| 147 |
+
"....c ٵ\n",
|
| 148 |
+
"....c ٱ\n",
|
| 149 |
+
"....c –\n",
|
| 150 |
+
"....c ؛\n",
|
| 151 |
+
"....c ٶ\n",
|
| 152 |
+
"....c ٭\n",
|
| 153 |
+
"....c ٗ\n",
|
| 154 |
+
"....c ﭐ\n",
|
| 155 |
+
"....c �\n",
|
| 156 |
+
"....c ﺟ\n",
|
| 157 |
+
"....c ﮞ\n",
|
| 158 |
+
"ﮞ 866\n",
|
| 159 |
+
"� 876\n",
|
| 160 |
+
"ﭐ 888\n",
|
| 161 |
+
"ﮐ 892\n",
|
| 162 |
+
"ﺟ 910\n",
|
| 163 |
+
"۟ 2654\n",
|
| 164 |
+
"‘ 2685\n",
|
| 165 |
+
"ݘ 2718\n",
|
| 166 |
+
"ی 2790\n",
|
| 167 |
+
"– 2802\n",
|
| 168 |
+
"٥ 7794\n",
|
| 169 |
+
"ٴ 7866\n",
|
| 170 |
+
"ں 7885\n",
|
| 171 |
+
"٤ 7898\n",
|
| 172 |
+
"ڡ 7902\n",
|
| 173 |
+
"ٝ 7921\n",
|
| 174 |
+
"ٵ 7933\n",
|
| 175 |
+
"ٲ 7944\n",
|
| 176 |
+
"ٶ 7962\n",
|
| 177 |
+
"٬ 7966\n",
|
| 178 |
+
"ٳ 7971\n",
|
| 179 |
+
"٧ 8011\n",
|
| 180 |
+
"٭ 8012\n",
|
| 181 |
+
"٢ 8019\n",
|
| 182 |
+
"٘ 8045\n",
|
| 183 |
+
"٦ 8049\n",
|
| 184 |
+
"ٰ 8056\n",
|
| 185 |
+
"٠ 8061\n",
|
| 186 |
+
"ٟ 8067\n",
|
| 187 |
+
"ٙ 8080\n",
|
| 188 |
+
"٩ 8081\n",
|
| 189 |
+
"ٯ 8083\n",
|
| 190 |
+
"٪ 8084\n",
|
| 191 |
+
"٫ 8098\n",
|
| 192 |
+
"ک 8102\n",
|
| 193 |
+
"ٱ 8118\n",
|
| 194 |
+
"ٜ 8123\n",
|
| 195 |
+
"ڤ 8126\n",
|
| 196 |
+
"ٮ 8133\n",
|
| 197 |
+
"ٞ 8133\n",
|
| 198 |
+
"١ 8158\n",
|
| 199 |
+
"ٛ 8162\n",
|
| 200 |
+
"٨ 8163\n",
|
| 201 |
+
"ٚ 8194\n",
|
| 202 |
+
"چ 8219\n",
|
| 203 |
+
"ٗ 8222\n",
|
| 204 |
+
"٣ 8378\n",
|
| 205 |
+
"ٍ 12142\n",
|
| 206 |
+
"~ 12159\n",
|
| 207 |
+
"9 12171\n",
|
| 208 |
+
"ِ 12173\n",
|
| 209 |
+
"1 12198\n",
|
| 210 |
+
"ٓ 12228\n",
|
| 211 |
+
"[ 12238\n",
|
| 212 |
+
"{ 12281\n",
|
| 213 |
+
"' 12285\n",
|
| 214 |
+
"! 12317\n",
|
| 215 |
+
"× 12331\n",
|
| 216 |
+
"< 12337\n",
|
| 217 |
+
"2 12344\n",
|
| 218 |
+
"ْ 12345\n",
|
| 219 |
+
"_ 12349\n",
|
| 220 |
+
"- 12350\n",
|
| 221 |
+
"% 12360\n",
|
| 222 |
+
"8 12366\n",
|
| 223 |
+
"5 12373\n",
|
| 224 |
+
"3 12381\n",
|
| 225 |
+
"ٔ 12383\n",
|
| 226 |
+
"} 12384\n",
|
| 227 |
+
"# 12386\n",
|
| 228 |
+
"x 12392\n",
|
| 229 |
+
"ً 12392\n",
|
| 230 |
+
": 12394\n",
|
| 231 |
+
"7 12398\n",
|
| 232 |
+
"* 12402\n",
|
| 233 |
+
"= 12404\n",
|
| 234 |
+
"+ 12431\n",
|
| 235 |
+
"> 12454\n",
|
| 236 |
+
"6 12457\n",
|
| 237 |
+
"ّ 12459\n",
|
| 238 |
+
"\\ 12461\n",
|
| 239 |
+
") 12462\n",
|
| 240 |
+
"؛ 12467\n",
|
| 241 |
+
"` 12477\n",
|
| 242 |
+
"$ 12479\n",
|
| 243 |
+
"0 12486\n",
|
| 244 |
+
"؟ 12487\n",
|
| 245 |
+
"? 12487\n",
|
| 246 |
+
"ـ 12506\n",
|
| 247 |
+
". 12514\n",
|
| 248 |
+
"( 12517\n",
|
| 249 |
+
"ٌ 12534\n",
|
| 250 |
+
"^ 12539\n",
|
| 251 |
+
"\" 12541\n",
|
| 252 |
+
"/ 12544\n",
|
| 253 |
+
"، 12565\n",
|
| 254 |
+
"ٖ 12566\n",
|
| 255 |
+
"َ 12568\n",
|
| 256 |
+
"ٕ 12574\n",
|
| 257 |
+
"; 12588\n",
|
| 258 |
+
"ُ 12604\n",
|
| 259 |
+
"& 12610\n",
|
| 260 |
+
"@ 12610\n",
|
| 261 |
+
"] 12644\n",
|
| 262 |
+
"4 12668\n",
|
| 263 |
+
", 12673\n",
|
| 264 |
+
"| 12690\n",
|
| 265 |
+
"آ 13719\n",
|
| 266 |
+
"ؤ 17881\n",
|
| 267 |
+
"ظ 25831\n",
|
| 268 |
+
"غ 33164\n",
|
| 269 |
+
"ء 41411\n",
|
| 270 |
+
"ئ 58235\n",
|
| 271 |
+
"إ 62171\n",
|
| 272 |
+
"ث 63826\n",
|
| 273 |
+
"ذ 73869\n",
|
| 274 |
+
"ز 77302\n",
|
| 275 |
+
"ض 80181\n",
|
| 276 |
+
"ص 106844\n",
|
| 277 |
+
"ى 107009\n",
|
| 278 |
+
"خ 113804\n",
|
| 279 |
+
"ط 116166\n",
|
| 280 |
+
"ش 120758\n",
|
| 281 |
+
"أ 155544\n",
|
| 282 |
+
"ج 183706\n",
|
| 283 |
+
"ك 213684\n",
|
| 284 |
+
"ح 228902\n",
|
| 285 |
+
"ه 259392\n",
|
| 286 |
+
"ق 272453\n",
|
| 287 |
+
"ف 306028\n",
|
| 288 |
+
"س 308145\n",
|
| 289 |
+
"د 387589\n",
|
| 290 |
+
"ع 403225\n",
|
| 291 |
+
"ب 418369\n",
|
| 292 |
+
"ة 431012\n",
|
| 293 |
+
"ت 556063\n",
|
| 294 |
+
"ر 567321\n",
|
| 295 |
+
"ن 612221\n",
|
| 296 |
+
"و 639292\n",
|
| 297 |
+
"م 800866\n",
|
| 298 |
+
"ي 919298\n",
|
| 299 |
+
"ل 1496107\n",
|
| 300 |
+
"ا 2022160\n",
|
| 301 |
+
" 3392229\n",
|
| 302 |
+
"('Size:', 144)\n"
|
| 303 |
+
]
|
| 304 |
+
}
|
| 305 |
+
],
|
| 306 |
+
"source": [
|
| 307 |
+
"import sys\n",
|
| 308 |
+
"import json\n",
|
| 309 |
+
"import os\n",
|
| 310 |
+
"from collections import defaultdict\n",
|
| 311 |
+
"\n",
|
| 312 |
+
"# These are for RASAM full pages\n",
|
| 313 |
+
"#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/RASAM/sfr/'\n",
|
| 314 |
+
"#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/RASAM/sfr/'\n",
|
| 315 |
+
"#OUT_NAME_c = 'char_set_rasam.json'\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"# These are for regions in RASAM and RASM\n",
|
| 318 |
+
"#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/RASM/regions_sfr/'\n",
|
| 319 |
+
"#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/RASM/regions_sfr/'\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"# These are for RASM, RASAM, MoiseK, KHATT\n",
|
| 322 |
+
"#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
|
| 323 |
+
"#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
|
| 324 |
+
"#OUT_NAME_c = 'char_set_arabic.json'\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"OUT_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
|
| 327 |
+
"OUT_NAME_c = 'char_set_line_images.json'\n",
|
| 328 |
+
"DATA_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
|
| 329 |
+
"\n",
|
| 330 |
+
"os.chdir('/home/msaeed3/mehreen/source/start_follow_read/py3/')\n",
|
| 331 |
+
"print(os.getcwd())\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"def load_char_set(char_set_path):\n",
|
| 334 |
+
" with open(char_set_path) as f:\n",
|
| 335 |
+
" char_set = json.load(f)\n",
|
| 336 |
+
"\n",
|
| 337 |
+
" idx_to_char = {}\n",
|
| 338 |
+
" for k,v in char_set['idx_to_char'].items():\n",
|
| 339 |
+
" idx_to_char[int(k)] = v\n",
|
| 340 |
+
"\n",
|
| 341 |
+
" return idx_to_char, char_set['char_to_idx']\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"if __name__ == \"__main__\":\n",
|
| 344 |
+
" character_set_path = OUT_PATH_c + OUT_NAME_c\n",
|
| 345 |
+
" out_char_to_idx = {}\n",
|
| 346 |
+
" out_idx_to_char = {}\n",
|
| 347 |
+
" char_freq = defaultdict(int) \n",
|
| 348 |
+
" \n",
|
| 349 |
+
" dirs = [DATA_PATH_c]\n",
|
| 350 |
+
" \n",
|
| 351 |
+
" input_data_files = ['Train.json', 'Valid.json', 'Test.json']\n",
|
| 352 |
+
" data_file = []\n",
|
| 353 |
+
" for path in dirs:\n",
|
| 354 |
+
" for i in range(len(input_data_files)):\n",
|
| 355 |
+
" data_file = path + input_data_files[i]\n",
|
| 356 |
+
" with open(data_file) as f:\n",
|
| 357 |
+
" paths = json.load(f)\n",
|
| 358 |
+
"\n",
|
| 359 |
+
" for json_path, image_path in paths:\n",
|
| 360 |
+
" \n",
|
| 361 |
+
" with open(json_path) as f:\n",
|
| 362 |
+
" data = json.load(f)\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" cnt = 1 # this is important that this starts at 1 not 0\n",
|
| 365 |
+
" for data_item in data:\n",
|
| 366 |
+
" # Mehreen: Cater for Nan gt\n",
|
| 367 |
+
" if 'gt' in data_item and type(data_item['gt']) == float:\n",
|
| 368 |
+
" continue\n",
|
| 369 |
+
" for c in data_item.get('gt', None):\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" if c is None:\n",
|
| 372 |
+
" print(\"There was a None GT\")\n",
|
| 373 |
+
" continue\n",
|
| 374 |
+
" if c not in out_char_to_idx:\n",
|
| 375 |
+
" print('....c',c)\n",
|
| 376 |
+
" out_char_to_idx[c] = cnt\n",
|
| 377 |
+
" out_idx_to_char[cnt] = c\n",
|
| 378 |
+
" cnt += 1\n",
|
| 379 |
+
" char_freq[c] += 1\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"\n",
|
| 382 |
+
" out_char_to_idx2 = {}\n",
|
| 383 |
+
" out_idx_to_char2 = {}\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" for i, c in enumerate(sorted(out_char_to_idx.keys())):\n",
|
| 386 |
+
" out_char_to_idx2[c] = i+1\n",
|
| 387 |
+
" out_idx_to_char2[i+1] = c\n",
|
| 388 |
+
"\n",
|
| 389 |
+
" output_data = {\n",
|
| 390 |
+
" \"char_to_idx\": out_char_to_idx2,\n",
|
| 391 |
+
" \"idx_to_char\": out_idx_to_char2\n",
|
| 392 |
+
" }\n",
|
| 393 |
+
"\n",
|
| 394 |
+
" for k,v in sorted(iter(char_freq.items()), key=lambda x: x[1]):\n",
|
| 395 |
+
" print(k, v)\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" print((\"Size:\", len(output_data['char_to_idx'])))\n",
|
| 398 |
+
" \n",
|
| 399 |
+
" with open(character_set_path, 'w') as outfile:\n",
|
| 400 |
+
" json.dump(output_data, outfile)\n",
|
| 401 |
+
"\n",
|
| 402 |
+
" "
|
| 403 |
+
]
|
| 404 |
+
},
|
| 405 |
+
{
|
| 406 |
+
"cell_type": "code",
|
| 407 |
+
"execution_count": null,
|
| 408 |
+
"id": "6337a429",
|
| 409 |
+
"metadata": {},
|
| 410 |
+
"outputs": [],
|
| 411 |
+
"source": [
|
| 412 |
+
"import os\n",
|
| 413 |
+
"import json\n",
|
| 414 |
+
"DATA_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
|
| 415 |
+
"OUT_NAME_c = 'char_set_arabic.json'\n",
|
| 416 |
+
"char_set_path = os.path.join(DATA_PATH_c, OUT_NAME_c)\n",
|
| 417 |
+
"\n",
|
| 418 |
+
"with open(char_set_path) as f:\n",
|
| 419 |
+
" char_set = json.load(f)\n",
|
| 420 |
+
"\n",
|
| 421 |
+
"larger_set = list(char_set['char_to_idx'].keys())\n",
|
| 422 |
+
"larger_set.sort()\n",
|
| 423 |
+
"for l in larger_set:\n",
|
| 424 |
+
" print(l, ascii(l))\n"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "code",
|
| 429 |
+
"execution_count": null,
|
| 430 |
+
"id": "8c54a65d",
|
| 431 |
+
"metadata": {},
|
| 432 |
+
"outputs": [],
|
| 433 |
+
"source": [
|
| 434 |
+
"DATA_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
|
| 435 |
+
"OUT_NAME_c = 'char_set_line_images.json'\n",
|
| 436 |
+
"char_set_path = os.path.join(DATA_PATH_c, OUT_NAME_c)\n",
|
| 437 |
+
"\n",
|
| 438 |
+
"with open(char_set_path) as f:\n",
|
| 439 |
+
" char_set = json.load(f)\n",
|
| 440 |
+
"\n",
|
| 441 |
+
"smaller_set = list(char_set['char_to_idx'].keys())\n",
|
| 442 |
+
"smaller_set.sort()\n",
|
| 443 |
+
"for s in smaller_set:\n",
|
| 444 |
+
" print(s, ascii(s))"
|
| 445 |
+
]
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"cell_type": "code",
|
| 449 |
+
"execution_count": null,
|
| 450 |
+
"id": "085f5ac4",
|
| 451 |
+
"metadata": {},
|
| 452 |
+
"outputs": [],
|
| 453 |
+
"source": [
|
| 454 |
+
"\n"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"cell_type": "code",
|
| 459 |
+
"execution_count": null,
|
| 460 |
+
"id": "34fc501d",
|
| 461 |
+
"metadata": {},
|
| 462 |
+
"outputs": [],
|
| 463 |
+
"source": [
|
| 464 |
+
"diff = larger_set.difference(smaller_set)"
|
| 465 |
+
]
|
| 466 |
+
},
|
| 467 |
+
{
|
| 468 |
+
"cell_type": "code",
|
| 469 |
+
"execution_count": null,
|
| 470 |
+
"id": "ce121a50",
|
| 471 |
+
"metadata": {},
|
| 472 |
+
"outputs": [],
|
| 473 |
+
"source": [
|
| 474 |
+
"def get_special_ascii_set():\n",
|
| 475 |
+
" char_set = []\n",
|
| 476 |
+
" asccii_range = [[33, 64], [91, 96], [123, 126], \n",
|
| 477 |
+
" # Arabic-Indic digits\n",
|
| 478 |
+
" [ord('\\u0661'), ord('\\u0669')]]\n",
|
| 479 |
+
" for r in asccii_range:\n",
|
| 480 |
+
" for val in range(r[0], r[1]+1):\n",
|
| 481 |
+
" char_set.append(chr(val))\n",
|
| 482 |
+
" return char_set\n",
|
| 483 |
+
"\n",
|
| 484 |
+
"print((ord('\\u0661')))\n",
|
| 485 |
+
"special_ascii = get_special_ascii_set()\n",
|
| 486 |
+
"special_ascii.sort()\n",
|
| 487 |
+
"print(special_ascii)\n",
|
| 488 |
+
"special_ascii = set(special_ascii)"
|
| 489 |
+
]
|
| 490 |
+
},
|
| 491 |
+
{
|
| 492 |
+
"cell_type": "code",
|
| 493 |
+
"execution_count": null,
|
| 494 |
+
"id": "3a89913c",
|
| 495 |
+
"metadata": {},
|
| 496 |
+
"outputs": [],
|
| 497 |
+
"source": [
|
| 498 |
+
"larger_set = set(larger_set)\n",
|
| 499 |
+
"smaller_set = set(smaller_set)\n",
|
| 500 |
+
"l_diff_ss = larger_set.difference(special_ascii.union(smaller_set))\n",
|
| 501 |
+
"l_diff_ss = (list(l_diff_ss))\n",
|
| 502 |
+
"l_diff_ss.sort()\n",
|
| 503 |
+
"for l in l_diff_ss:\n",
|
| 504 |
+
" print(l, ascii(l))"
|
| 505 |
+
]
|
| 506 |
+
},
|
| 507 |
+
{
|
| 508 |
+
"cell_type": "code",
|
| 509 |
+
"execution_count": null,
|
| 510 |
+
"id": "72f97354",
|
| 511 |
+
"metadata": {},
|
| 512 |
+
"outputs": [],
|
| 513 |
+
"source": [
|
| 514 |
+
"larger_set.difference(smaller_set)"
|
| 515 |
+
]
|
| 516 |
+
}
|
| 517 |
+
],
|
| 518 |
+
"metadata": {
|
| 519 |
+
"kernelspec": {
|
| 520 |
+
"display_name": "torch",
|
| 521 |
+
"language": "python",
|
| 522 |
+
"name": "torch"
|
| 523 |
+
},
|
| 524 |
+
"language_info": {
|
| 525 |
+
"codemirror_mode": {
|
| 526 |
+
"name": "ipython",
|
| 527 |
+
"version": 3
|
| 528 |
+
},
|
| 529 |
+
"file_extension": ".py",
|
| 530 |
+
"mimetype": "text/x-python",
|
| 531 |
+
"name": "python",
|
| 532 |
+
"nbconvert_exporter": "python",
|
| 533 |
+
"pygments_lexer": "ipython3",
|
| 534 |
+
"version": "3.9.13"
|
| 535 |
+
}
|
| 536 |
+
},
|
| 537 |
+
"nbformat": 4,
|
| 538 |
+
"nbformat_minor": 5
|
| 539 |
+
}
|
py3/utils/character_set.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
def load_char_set(char_set_path):
    """Read a character-set JSON file and return (idx_to_char, char_to_idx).

    JSON object keys are always strings, so the keys of ``idx_to_char``
    are converted back to ints here; ``char_to_idx`` is returned as-is.
    """
    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {int(idx): ch for idx, ch in char_set['idx_to_char'].items()}

    return idx_to_char, char_set['char_to_idx']
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
    # Usage: python character_set.py <data_list.json> ... <output_charset.json>
    # Every argument but the last is a JSON list of [json_path, image_path]
    # pairs; the last argument is where the character set is written.
    character_set_path = sys.argv[-1]
    out_char_to_idx = {}
    out_idx_to_char = {}
    char_freq = defaultdict(int)
    for i in range(1, len(sys.argv) - 1):
        data_file = sys.argv[i]
        with open(data_file) as f:
            paths = json.load(f)

        for json_path, image_path in paths:
            with open(json_path) as f:
                data = json.load(f)

            # This is important: indices start at 1 not 0 (0 is reserved).
            # NOTE(review): cnt resets for every file, so out_char_to_idx
            # indices can collide across files; harmless because the final
            # mapping below is rebuilt from the sorted character keys.
            cnt = 1
            for data_item in data:
                gt = data_item.get('gt')
                if gt is None:
                    # BUG FIX: the original iterated over None when 'gt'
                    # was absent, raising TypeError; skip such items.
                    print("There was a None GT")
                    continue
                for c in gt:
                    if c not in out_char_to_idx:
                        out_char_to_idx[c] = cnt
                        out_idx_to_char[cnt] = c
                        cnt += 1
                    char_freq[c] += 1

    # Rebuild the mapping deterministically: sorted characters -> 1..N.
    out_char_to_idx2 = {}
    out_idx_to_char2 = {}

    for i, c in enumerate(sorted(out_char_to_idx.keys())):
        out_char_to_idx2[c] = i + 1
        out_idx_to_char2[i + 1] = c

    output_data = {
        "char_to_idx": out_char_to_idx2,
        "idx_to_char": out_idx_to_char2
    }

    # Report character frequencies, rarest first.
    for k, v in sorted(iter(char_freq.items()), key=lambda x: x[1]):
        print(k, v)

    print(("Size:", len(output_data['char_to_idx'])))

    with open(character_set_path, 'w') as outfile:
        json.dump(output_data, outfile)
|
py3/utils/continuous_state.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
import sol
|
| 8 |
+
from sol.start_of_line_finder import StartOfLineFinder
|
| 9 |
+
from lf.line_follower import LineFollower
|
| 10 |
+
from hw import cnn_lstm
|
| 11 |
+
|
| 12 |
+
from utils import safe_load
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
import json
|
| 17 |
+
import sys
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
import random
|
| 21 |
+
|
| 22 |
+
def init_model(config, sol_dir='best_validation', lf_dir='best_validation', hw_dir='best_validation',
               only_load=None, device="cuda"):
    """Build the SOL, LF and HW networks and load their saved weights.

    Args:
        config: parsed experiment config; reads network hyper-parameters
            under config['network'] and weight files under
            config['snapshot_path'] ("sol.pt", "lf.pt", "hw.pt").
        sol_dir, lf_dir, hw_dir: unused, kept for call-site compatibility.
        only_load: None to load all three models, or a string/collection
            naming which of 'sol', 'lf', 'hw' to load; the others are None.
        device: torch device string; any string containing 'cuda' selects
            the CUDA float tensor type for the line follower.

    Returns:
        (sol, lf, hw) tuple of models; entries not requested are None.
    """
    dtype = torch.FloatTensor
    if 'cuda' in device:
        dtype = torch.cuda.FloatTensor

    base_0 = config['network']['sol']['base0']
    base_1 = config['network']['sol']['base1']

    sol = None
    lf = None
    hw = None

    if only_load is None or only_load == 'sol' or 'sol' in only_load:
        sol = StartOfLineFinder(base_0, base_1)
        sol_state = safe_load.torch_state(os.path.join(config['snapshot_path'], "sol.pt"))
        sol.load_state_dict(sol_state)
        sol.to(device)

    if only_load is None or only_load == 'lf' or 'lf' in only_load:
        pt_file = 'lf.pt'

        lf = LineFollower(config['network']['hw']['input_height'], dtype=dtype, device=device)
        lf_state = safe_load.torch_state(os.path.join(config['snapshot_path'], pt_file))

        # Backward compatibility with the previous way LF weights were
        # saved: older snapshots stored nested per-submodule state under
        # top-level keys; flatten them into "submodule.param" keys.
        if 'cnn' in lf_state:
            new_state = {}
            for k, v in lf_state.items():
                if k == 'cnn':
                    for k2, v2 in v.items():
                        # Skip BatchNorm running statistics; only learned
                        # parameters are carried over.
                        if "running" not in k2:
                            new_state[k + "." + k2] = v2
                if k == 'position_linear':
                    # Stored as a module rather than a state dict.
                    for k2, v2 in v.state_dict().items():
                        new_state[k + "." + k2] = v2

            lf_state = new_state

        lf.load_state_dict(lf_state)
        lf.to(device)

    if only_load is None or only_load == 'hw' or 'hw' in only_load:
        hw = cnn_lstm.create_model(config['network']['hw'])
        hw_state = safe_load.torch_state(os.path.join(config['snapshot_path'], "hw.pt"))
        hw.load_state_dict(hw_state)
        hw.to(device)

    return sol, lf, hw
|
py3/utils/dataset_parse.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def load_file_list(config):
    """Read the dataset file list and resolve each entry to full paths.

    The file list is a JSON array of ``[json_name, img_name]`` pairs.  Each
    pair is rewritten in place so that entry 0 points into
    ``config['json_folder']`` and entry 1 into ``config['img_folder']``.
    The (mutated) list is returned.
    """
    with open(config['file_list']) as list_file:
        entries = json.load(list_file)

    for entry in entries:
        entry[0] = os.path.join(config['json_folder'], entry[0])
        entry[1] = os.path.join(config['img_folder'], entry[1])

    return entries
|
py3/utils/dataset_wrapper.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class DatasetWrapper(object):
    """Yield exactly `count` items per pass over `dataset`, restarting the
    underlying iterator whenever it is exhausted.

    `epoch` counts how many times the underlying dataset has been restarted.
    After a pass of `count` items the internal index resets, so the wrapper
    can be iterated again.
    """

    def __init__(self, dataset, count):
        self.count = count
        self.idx = 0
        self.dataset = dataset
        self.iter_dataset = iter(dataset)
        self.epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        # A pass ends after `count` items; reset so the wrapper is reusable.
        if self.idx >= self.count:
            self.idx = 0
            raise StopIteration

        self.idx += 1
        try:
            return next(self.iter_dataset)
        except StopIteration:
            # Underlying dataset ran dry: restart it and bump the epoch count.
            self.iter_dataset = iter(self.dataset)
            self.epoch += 1
            try:
                return next(self.iter_dataset)
            except StopIteration:
                raise Exception("Appears as if dataset is empty")
|
py3/utils/error_rates.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import editdistance
|
| 2 |
+
def cer(r, h):
    """Character error rate between reference `r` and hypothesis `h`.

    Runs of whitespace are collapsed to single spaces (and leading/trailing
    whitespace dropped) before computing the normalized edit distance.
    """
    r = ' '.join(r.split())
    h = ' '.join(h.split())

    return err(r, h)

def err(r, h):
    """Edit distance between sequences `r` and `h`, normalized by len(r).

    If the reference is empty, returns len(h): every hypothesis element
    counts as one insertion error.
    """
    # Check the empty reference first -- the original computed the edit
    # distance before this test (wasted work) and compared len(r) to the
    # float 0.0 instead of testing emptiness.
    if not r:
        return len(h)

    return float(editdistance.eval(r, h)) / float(len(r))

def wer(r, h):
    """Word error rate: `err` computed on whitespace-tokenized sequences."""
    return err(r.split(), h.split())
|
py3/utils/fast_inverse.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from torch.autograd import Variable
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
def adjoint(A):
    """Adjugate-transpose (cofactor matrix) for a stack of 3x3 matrices.

    Row i of the result is the cross product of rows i+1 and i+2 (cyclic),
    so dot(adjoint(A), A).mean(-1) recovers det(A) without any division.
    Accepts ...x3x3 input.
    """
    cof = np.empty_like(A)
    cof[..., 0, :] = np.cross(A[..., 1, :], A[..., 2, :])
    cof[..., 1, :] = np.cross(A[..., 2, :], A[..., 0, :])
    cof[..., 2, :] = np.cross(A[..., 0, :], A[..., 1, :])
    return cof

def inverse_transpose(A):
    """Efficient inverse-transpose for a stack of 3x3 matrices."""
    adj = adjoint(A)
    # Each row-dot of the adjugate with A equals det(A); averaging the three
    # identical values keeps the shape broadcastable.
    dets = dot(adj, A).mean(axis=-1)
    return adj / dets[..., None, None]

def inverse(A):
    """Inverse of a stack of 3x3 matrices via the adjugate (no np.linalg)."""
    inv_T = inverse_transpose(A)
    return np.swapaxes(inv_T, -1, -2)

def dot(A, B):
    """Dot arrays of vectors; contracts over the last index."""
    return np.einsum('...i,...i->...', A, B)
|
| 26 |
+
|
| 27 |
+
def adjoint_torch(A):
    """Adjugate-transpose (cofactor matrix) for a stack of 3x3 tensors.

    Row i of the output is cross(row i+1, row i+2) (cyclic indices),
    mirroring the NumPy `adjoint`; dot_torch(result, A).mean(-1) gives
    det(A) without a division.
    """
    AI = A.clone()
    for i in range(3):
        AI[..., i, :] = torch.cross(A[..., i - 2, :], A[..., i - 1, :])
    return AI

def inverse_transpose_torch(A):
    """Inverse-transpose of a stack of 3x3 tensors (no torch.inverse)."""
    I = adjoint_torch(A)
    det = dot_torch(I, A).mean(dim=-1)
    # Ellipsis instead of the original `det[:, None, None]` so inputs with
    # any number of leading batch dims work (backward compatible for Nx3x3).
    return I / det[..., None, None]

def inverse_torch(A):
    """Inverse of a stack of 3x3 tensors, computed via the adjugate."""
    # transpose(-1, -2) generalizes the original transpose(1, 2): identical
    # for Nx3x3 input, correct for extra leading batch dims as well.
    return inverse_transpose_torch(A).transpose(-1, -2)

def dot_torch(A, B):
    """Batched dot product over the last axis: out[...] = sum_i A[...,i]*B[...,i]."""
    # .contiguous() before .view(): view() raises on non-contiguous tensors,
    # and the original guarded only B.
    A_view = A.contiguous().view(-1, 1, 3)
    B_view = B.contiguous().view(-1, 3, 1)
    out = torch.bmm(A_view, B_view)
    return out.view(A.size()[:-1])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Quick manual sanity check: the NumPy and Torch code paths should print the
# same batch of 3x3 inverses for identical random input.
if __name__ == "__main__":
    A = np.random.rand(2,3,3)
    I = inverse(A)

    # NOTE(review): torch.autograd.Variable is a legacy no-op wrapper in
    # modern PyTorch; a plain tensor would presumably behave the same.
    A_torch = Variable(torch.from_numpy(A))

    I_torch = inverse_torch(A_torch)
    print(I)
    print(I_torch)
|
| 58 |
+
|
py3/utils/safe_load.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
def torch_state(path):
    """Load a torch checkpoint from `path`, retrying on transient failures.

    Retries up to 10 times with a growing back-off (useful on flaky network
    file systems).  Returns the loaded state, or None if every attempt fails.
    """
    for attempt in range(10):
        try:
            return torch.load(path)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; include the error for easier diagnosis.
            print("Failed to load", attempt, path, e)
            time.sleep(attempt)

    print("Failed to load state")
    return None
|
| 17 |
+
|
| 18 |
+
def json_state(path):
    """Load a JSON file from `path`, retrying on transient failures.

    Mirrors `torch_state`: up to 10 attempts with a growing back-off.
    Returns the parsed object, or None if every attempt fails.
    """
    for attempt in range(10):
        try:
            with open(path) as f:
                return json.load(f)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; include the error for easier diagnosis.
            print("Failed to load", attempt, path, e)
            time.sleep(attempt)

    print("Failed to load state")
    return None
|